This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Python3 compatibility; better Unicode support (#22)
* Python3 compatibility fixes

* Removed print statements

* Fixes suggested by @rsepassi

* Made Python3 ByteTextEncoder compatible with Python2

* Python3 compatibility fixes
vthorsteinsson authored and rsepassi committed Jun 24, 2017
1 parent 204b359 commit 3410bea
Showing 4 changed files with 66 additions and 50 deletions.
2 changes: 1 addition & 1 deletion tensor2tensor/data_generators/generator_utils.py
@@ -22,12 +22,12 @@
import io
import os
import tarfile
import urllib

# Dependency imports

import six
from six.moves import xrange # pylint: disable=redefined-builtin
import six.moves.urllib_request

from tensor2tensor.data_generators.text_encoder import SubwordTextEncoder
from tensor2tensor.data_generators.tokenizer import Tokenizer
3 changes: 2 additions & 1 deletion tensor2tensor/data_generators/image.py
@@ -18,7 +18,6 @@
from __future__ import division
from __future__ import print_function

import cPickle
import gzip
import io
import json
@@ -32,6 +31,8 @@
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from six.moves import zip # pylint: disable=redefined-builtin
from six.moves import cPickle

from tensor2tensor.data_generators import generator_utils

import tensorflow as tf
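Both import changes above follow the same pattern: a Python 2-only module is replaced by its six.moves alias so that one import line resolves correctly on either interpreter. A minimal, hypothetical sketch of that pattern (the helper function, URL, and path below are placeholders for illustration, not code from this repository):

    from six.moves import cPickle                  # cPickle on Python 2, pickle on Python 3
    from six.moves.urllib.request import urlopen   # urllib2.urlopen on 2, urllib.request.urlopen on 3

    def fetch_and_cache(url, path):
      # The six.moves names above resolve to the right standard-library
      # module on either interpreter, so no version check is needed here.
      data = urlopen(url).read()
      with open(path, "wb") as f:
        cPickle.dump(data, f)
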
87 changes: 55 additions & 32 deletions tensor2tensor/data_generators/text_encoder.py
100644 → 100755
@@ -27,6 +27,7 @@

import six
from six.moves import xrange # pylint: disable=redefined-builtin
from collections import defaultdict
from tensor2tensor.data_generators import tokenizer

import tensorflow as tf
@@ -35,7 +36,10 @@
PAD = '<pad>'
EOS = '<EOS>'
RESERVED_TOKENS = [PAD, EOS]

if six.PY2:
RESERVED_TOKENS_BYTES = RESERVED_TOKENS
else:
RESERVED_TOKENS_BYTES = [bytes(PAD, 'ascii'), bytes(EOS, 'ascii')]

class TextEncoder(object):
"""Base class for converting from ints to/from human readable strings."""
@@ -87,17 +91,25 @@ class ByteTextEncoder(TextEncoder):
"""Encodes each byte to an id. For 8-bit strings only."""

def encode(self, s):
return [ord(c) + self._num_reserved_ids for c in s]
numres = self._num_reserved_ids
if six.PY2:
return [ord(c) + numres for c in s]
# Python3: explicitly convert to UTF-8
return [c + numres for c in s.encode("utf-8")]

def decode(self, ids):
numres = self._num_reserved_ids
decoded_ids = []
int2byte = six.int2byte
for id_ in ids:
if 0 <= id_ < self._num_reserved_ids:
decoded_ids.append(RESERVED_TOKENS[int(id_)])
if 0 <= id_ < numres:
decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)])
else:
decoded_ids.append(chr(id_))

return ''.join(decoded_ids)
decoded_ids.append(int2byte(id_ - numres))
if six.PY2:
return ''.join(decoded_ids)
# Python3: join byte arrays and then decode string
return b''.join(decoded_ids).decode("utf-8")

@property
def vocab_size(self):
@@ -111,20 +123,16 @@ def __init__(self, vocab_filename, reverse=False, num_reserved_ids=2):
"""Initialize from a file, one token per line."""
super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
self._reverse = reverse
if vocab_filename is not None:
self._load_vocab_from_file(vocab_filename)
self._load_vocab_from_file(vocab_filename)

def encode(self, sentence):
"""Converts a space-separated string of tokens to a list of ids."""
ret = [self._token_to_id[tok] for tok in sentence.strip().split()]
if self._reverse:
ret = ret[::-1]
return ret
return ret[::-1] if self._reverse else ret

def decode(self, ids):
if self._reverse:
ids = ids[::-1]
return ' '.join([self._safe_id_to_token(i) for i in ids])
seq = reversed(ids) if self._reverse else ids
return ' '.join([self._safe_id_to_token(i) for i in seq])

@property
def vocab_size(self):
@@ -243,15 +251,22 @@ def _escaped_token_to_subtokens(self, escaped_token):
"""
ret = []
pos = 0
while pos < len(escaped_token):
end = len(escaped_token)
while True:
lesc = len(escaped_token)
while pos < lesc:
end = lesc
while end > pos:
subtoken = self._subtoken_string_to_id.get(escaped_token[pos:end], -1)
if subtoken != -1:
break
end -= 1
ret.append(subtoken)
pos = end
if end > pos:
pos = end
else:
# This kinda should not happen, but it does. Cop out by skipping the
# nonexistent subtoken from the returned list.
# print("Unable to find subtoken in string '{0}'".format(escaped_token))
pos += 1
return ret

@classmethod
@@ -322,13 +337,13 @@ def build_from_token_counts(self,
# then count the resulting potential subtokens, keeping the ones
# with high enough counts for our new vocabulary.
for i in xrange(num_iterations):
counts = {}
counts = defaultdict(int)
for token, count in six.iteritems(token_counts):
escaped_token = self._escape_token(token)
# we will count all tails of the escaped_token, starting from boundaries
# determined by our current segmentation.
if i == 0:
starts = list(range(len(escaped_token)))
starts = xrange(len(escaped_token))
else:
subtokens = self._escaped_token_to_subtokens(escaped_token)
pos = 0
@@ -337,31 +352,33 @@
starts.append(pos)
pos += len(self.subtoken_to_subtoken_string(subtoken))
for start in starts:
for end in xrange(start + 1, len(escaped_token) + 1):
for end in xrange(start + 1, len(escaped_token)):
subtoken_string = escaped_token[start:end]
counts[subtoken_string] = counts.get(subtoken_string, 0) + count
counts[subtoken_string] += count
# array of lists of candidate subtoken strings, by length
len_to_subtoken_strings = []
for subtoken_string, count in six.iteritems(counts):
if count < min_count or len(subtoken_string) <= 1:
lsub = len(subtoken_string)
# all subtoken strings of length 1 are included regardless of count
if count < min_count and lsub != 1:
continue
while len(len_to_subtoken_strings) <= len(subtoken_string):
while len(len_to_subtoken_strings) <= lsub:
len_to_subtoken_strings.append([])
len_to_subtoken_strings[len(subtoken_string)].append(subtoken_string)
len_to_subtoken_strings[lsub].append(subtoken_string)
new_subtoken_strings = []
# consider the candidates longest to shortest, so that if we accept
# a longer subtoken string, we can decrement the counts of its prefixes.
for subtoken_strings in len_to_subtoken_strings[::-1]:
for subtoken_string in subtoken_strings:
count = counts[subtoken_string]
if count < min_count:
if count < min_count and len(subtoken_string) != 1:
# subtoken strings of length 1 are included regardless of count
continue
new_subtoken_strings.append((-count, subtoken_string))
for l in xrange(1, len(subtoken_string)):
counts[subtoken_string[:l]] -= count
# make sure we have all single characters.
new_subtoken_strings.extend([(-counts.get(chr(i), 0), chr(i))
for i in xrange(2**8)])
# Make sure to include the underscore as a subtoken string
new_subtoken_strings.append((0, '_'))
new_subtoken_strings.sort()
self._init_from_list([''] * self._num_reserved_ids +
[p[1] for p in new_subtoken_strings])
@@ -390,13 +407,19 @@ def _load_from_file(self, filename):
subtoken_strings = []
with tf.gfile.Open(filename) as f:
for line in f:
subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
if six.PY2:
subtoken_strings.append(line.strip()[1:-1].decode('string-escape'))
else:
subtoken_strings.append(line.strip()[1:-1])
self._init_from_list(subtoken_strings)

def _store_to_file(self, filename):
with tf.gfile.Open(filename, 'w') as f:
for subtoken_string in self._all_subtoken_strings:
f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
if six.PY2:
f.write('\'' + subtoken_string.encode('string-escape') + '\'\n')
else:
f.write('\'' + subtoken_string + '\'\n')

def _escape_token(self, token):
r"""Translate '\'->'\\' and '_'->'\u', then append '_'.
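The ByteTextEncoder change is the heart of the Unicode fix: a Python 3 str is a sequence of code points, so the encoder now maps UTF-8 byte values (not characters) to ids, and the decoder joins single bytes with six.int2byte before decoding back to text. A rough standalone sketch of that round trip, assuming two reserved ids and UTF-8 text; it mirrors the diff but is simplified (reserved ids are skipped here rather than mapped to RESERVED_TOKENS_BYTES):

    import six

    NUM_RESERVED = 2  # assumed to correspond to PAD and EOS above

    def byte_encode(s):
      if six.PY2:
        return [ord(c) + NUM_RESERVED for c in s]
      # Python 3: operate on UTF-8 byte values, not on code points.
      return [b + NUM_RESERVED for b in s.encode("utf-8")]

    def byte_decode(ids):
      parts = [six.int2byte(i - NUM_RESERVED) for i in ids if i >= NUM_RESERVED]
      joined = b"".join(parts)
      return joined if six.PY2 else joined.decode("utf-8")

    if not six.PY2:
      # A non-ASCII character becomes two UTF-8 bytes, hence two ids, and the
      # round trip recovers the original string.
      assert byte_decode(byte_encode(u"héllo")) == u"héllo"

The other substantive change in this file is the rewritten greedy matcher in _escaped_token_to_subtokens: at each position it takes the longest vocabulary match, and when nothing matches it now advances one character instead of spinning forever (in the old loop, pos = end left pos unchanged when the inner search found nothing). A simplified sketch of that loop against a plain dict vocabulary (function and variable names are illustrative, not the class's own):

    def greedy_subtokens(escaped_token, subtoken_ids):
      """Greedy longest-match segmentation; unmatched characters are skipped."""
      ret = []
      pos = 0
      lesc = len(escaped_token)
      while pos < lesc:
        end = lesc
        while end > pos:
          subtoken = subtoken_ids.get(escaped_token[pos:end], -1)
          if subtoken != -1:
            ret.append(subtoken)
            break
          end -= 1
        # Advance past the match, or past a single unmatched character.
        pos = end if end > pos else pos + 1
      return ret

    # e.g. greedy_subtokens("hello_", {"he": 0, "ll": 1, "o_": 2}) == [0, 1, 2]
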
24 changes: 8 additions & 16 deletions tensor2tensor/data_generators/tokenizer.py
100644 → 100755
@@ -45,29 +45,21 @@
from __future__ import division
from __future__ import print_function

import array
import string

# Dependency imports

from six.moves import xrange # pylint: disable=redefined-builtin

from collections import defaultdict

class Tokenizer(object):
"""Vocab for breaking words into wordpieces.
"""

def __init__(self):
self._separator_chars = string.punctuation + string.whitespace
self._separator_char_mask = array.array(
"l", [chr(i) in self._separator_chars for i in xrange(256)])
self.token_counts = dict()
_SEPARATOR_CHAR_SET = set(string.punctuation + string.whitespace)

def _increment_token_count(self, token):
if token in self.token_counts:
self.token_counts[token] += 1
else:
self.token_counts[token] = 1
def __init__(self):
self.token_counts = defaultdict(int)

def encode(self, raw_text):
"""Encode a raw string as a list of tokens.
@@ -87,11 +79,11 @@ def encode(self, raw_text):
token = raw_text[token_start:pos]
if token != " " or token_start == 0:
ret.append(token)
self._increment_token_count(token)
self.token_counts[token] += 1
token_start = pos
final_token = raw_text[token_start:]
ret.append(final_token)
self._increment_token_count(final_token)
self.token_counts[final_token] += 1
return ret

def decode(self, tokens):
@@ -111,7 +103,7 @@ def decode(self, tokens):
return ret

def _is_separator_char(self, c):
return self._separator_char_mask[ord(c)]
return c in self._SEPARATOR_CHAR_SET

def _is_word_char(self, c):
return not self._is_separator_char(c)
return c not in self._SEPARATOR_CHAR_SET

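On the tokenizer side, the hand-rolled array.array separator mask and the _increment_token_count helper give way to a module-level set and a collections.defaultdict, both of which behave identically on Python 2 and 3. A small sketch of the same idea (standalone names, not the module's):

    import string
    from collections import defaultdict

    SEPARATOR_CHARS = set(string.punctuation + string.whitespace)
    token_counts = defaultdict(int)

    def is_separator_char(c):
      # Set membership replaces the 256-entry array.array lookup table.
      return c in SEPARATOR_CHARS

    for tok in ["hello", ", ", "world"]:
      token_counts[tok] += 1  # defaultdict removes the explicit if/else increment
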