-
Notifications
You must be signed in to change notification settings - Fork 254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PARSeq Model #2089
Draft
sineeli
wants to merge
21
commits into
keras-team:master
Choose a base branch
from
sineeli:parseq
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
PARSeq Model #2089
Changes from 13 commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
528d3a4
Base for parseq model
sineeli 3bf11cd
make it vit compatiable with diff height and width sizes
sineeli a8fb177
correct vit conv scripts
sineeli 6f4363a
make class token optional in backbone by default its included
sineeli d1cece0
add flags to adjust vit network
sineeli 92b2745
add test case for without class_token
sineeli ed00b73
Merge branch 'master' into parseq
sineeli 25f661c
decoder file
sineeli f97fab1
parseq tokenizer base
sineeli d424210
add api for parseq tokenizer
sineeli 3f3ad0d
Add missing arg max_label_length.
sineeli bb4457e
nit
sineeli 68829f8
Merge branch 'master' into parseq
sineeli 1bde466
add missing normalization step using tf_text
sineeli e6c5379
add missing config for preprocessor
sineeli 5b08c93
add default start, pad and end tokens
sineeli 49260ef
nit
sineeli b4150ed
correct special token order
sineeli ed8b9d7
return padding mask as well
sineeli 4e4511c
use proper keras ops
sineeli 9222331
nit
sineeli File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.backbone import Backbone | ||
|
||
|
||
@keras_hub_export("keras_hub.models.PARSeqBackbone") | ||
class PARSeqBackbone(Backbone): | ||
"""Scene Text Detection with PARSeq. | ||
Performs OCR in natural scenes using the PARSeq model described in [Scene | ||
Text Recognition with Permuted Autoregressive Sequence Models]( | ||
https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows | ||
iterative decoding by performing an autoregressive decoding phase, followed | ||
by a refinement phase. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
image_encoder, | ||
decode_autoregressive=True, | ||
alphabet_size=97, | ||
max_text_length=25, | ||
num_decoder_layers=1, | ||
num_decoder_heads=12, | ||
dropout_rate=0.1, | ||
dtype=None, | ||
**kwargs, | ||
): | ||
# === Layers === | ||
self.image_encoder = image_encoder | ||
|
||
image_input = self.image_encoder.input | ||
output = self.image_encoder(image_input) | ||
|
||
# === Config === | ||
self.decode_autoregressive = decode_autoregressive | ||
self.alphabet_size = alphabet_size | ||
self.max_text_length = max_text_length | ||
self.num_decoder_layers = num_decoder_layers | ||
self.num_decoder_heads = num_decoder_heads | ||
self.dropout_rate = dropout_rate | ||
|
||
super().__init__( | ||
inputs=image_input, | ||
outputs=output, | ||
dtype=dtype, | ||
**kwargs, | ||
) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"encoder": keras.layers.serialize(self.image_encoder), | ||
"alphabet_size": self.alphabet_size, | ||
"max_text_length": self.max_text_length, | ||
"num_decoder_layers": self.num_decoder_layers, | ||
"num_decoder_heads": self.num_decoder_heads, | ||
"dropout_rate": self.dropout_rate, | ||
} | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter | ||
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone | ||
|
||
|
||
@keras_hub_export("keras_hub.layers.PARSeqImageConverter") | ||
class PARSeqImageConverter(ImageConverter): | ||
backbone_cls = PARSeqBackbone |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone | ||
from keras_hub.src.models.parseq.parseq_image_converter import ( | ||
PARSeqImageConverter, | ||
) | ||
from keras_hub.src.models.text_recognition_preprocessor import ( | ||
TextRecognitionPreprocessor, | ||
) | ||
|
||
|
||
@keras_hub_export("keras_hub.models.PARSeqPreprocessor") | ||
class PARSeqPreprocessor(TextRecognitionPreprocessor): | ||
backbone_cls = PARSeqBackbone | ||
image_converter_cls = PARSeqImageConverter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import re | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.tokenizers import tokenizer | ||
from keras_hub.src.utils.tensor_utils import is_int_dtype | ||
from keras_hub.src.utils.tensor_utils import is_string_dtype | ||
from keras_hub.src.utils.tensor_utils import preprocessing_function | ||
|
||
PARSEQ_VOCAB = ( | ||
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!" | ||
"\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" | ||
) | ||
|
||
try: | ||
import tensorflow as tf | ||
except ImportError: | ||
tf = None | ||
|
||
|
||
@keras_hub_export( | ||
[ | ||
"keras_hub.tokenizers.PARSeqTokenizer", | ||
"keras_hub.models.PARSeqTokenizer", | ||
] | ||
) | ||
class PARSeqTokenizer(tokenizer.Tokenizer): | ||
def __init__( | ||
self, | ||
vocabulary=PARSEQ_VOCAB, | ||
remove_whitespace=True, | ||
normalize_unicode=True, | ||
max_label_length=25, | ||
dtype="int32", | ||
**kwargs, | ||
): | ||
if not is_int_dtype(dtype) and not is_string_dtype(dtype): | ||
raise ValueError( | ||
"Output dtype must be an integer type or a string. " | ||
f"Received: dtype={dtype}" | ||
) | ||
self._add_special_token("[E]", "start_token") | ||
self._add_special_token("[B]", "end_token") | ||
self._add_special_token("[P]", "pad_token") | ||
super().__init__(dtype=dtype, **kwargs) | ||
self.vocabulary = vocabulary | ||
self.target_charset = tf.convert_to_tensor(vocabulary, dtype=tf.string) | ||
self.lowercase_only = self.target_charset == tf.strings.lower( | ||
self.target_charset | ||
) | ||
self.uppercase_only = self.target_charset == tf.strings.upper( | ||
self.target_charset | ||
) | ||
escaped_charset = re.escape(vocabulary) # Escape for safe regex | ||
self.unsupported_regex = f"[^{escaped_charset}]" | ||
self._itos = (self.start_token,) + tuple(vocabulary) + (self.end_token,) | ||
self._stoi = {s: i for i, s in enumerate(self._itos)} | ||
|
||
# Create lookup tables. | ||
self.char_to_id = tf.lookup.StaticHashTable( | ||
initializer=tf.lookup.KeyValueTensorInitializer( | ||
keys=list(self._stoi.keys()), | ||
values=list(self._stoi.values()), | ||
key_dtype=tf.string, | ||
value_dtype=tf.int32, | ||
), | ||
default_value=0, | ||
) | ||
self.id_to_char = tf.lookup.StaticHashTable( | ||
initializer=tf.lookup.KeyValueTensorInitializer( | ||
keys=list(self._stoi.values()), | ||
values=list(self._stoi.keys()), | ||
key_dtype=tf.int32, | ||
value_dtype=tf.string, | ||
), | ||
default_value=self.pad_token, | ||
) | ||
|
||
self.remove_whitespace = remove_whitespace | ||
self.normalize_unicode = normalize_unicode | ||
self.max_label_length = max_label_length | ||
|
||
def id_to_token(self, id): | ||
if id >= self.vocabulary_size() or id < 0: | ||
raise ValueError( | ||
f"`id` must be in range [0, {self.vocabulary_size() - 1}]. " | ||
f"Received: {id}" | ||
) | ||
return self._itos[id] | ||
|
||
def token_to_id(self, token): | ||
return self._stoi[token] | ||
|
||
def _preprocess(self, label): | ||
"""Performs preprocessing include only characters from ASCII.""" | ||
if self.remove_whitespace: | ||
label = tf.strings.regex_replace(label, r"\s+", "") | ||
|
||
if self.normalize_unicode: | ||
label = tf.strings.regex_replace(label, "[^!-~]", "") | ||
|
||
if self.lowercase_only: | ||
label = tf.strings.lower(label) | ||
elif self.uppercase_only: | ||
label = tf.strings.upper(label) | ||
|
||
label = tf.strings.regex_replace(label, self.unsupported_regex, "") | ||
label = tf.strings.substr(label, 0, self.max_label_length) | ||
|
||
return label | ||
|
||
@preprocessing_function | ||
def tokenize(self, inputs): | ||
self._check_vocabulary() | ||
inputs = tf.convert_to_tensor(inputs) | ||
unbatched = inputs.shape.rank == 0 | ||
if unbatched: | ||
inputs = tf.expand_dims(inputs, 0) | ||
|
||
inputs = tf.map_fn(self._preprocess, inputs, dtype=tf.string) | ||
|
||
if tf.size(inputs) > 0: | ||
chars = tf.strings.unicode_split(inputs, "UTF-8") | ||
token_ids = self.char_to_id.lookup(chars) | ||
token_ids = tf.cast(token_ids, dtype=tf.int32) | ||
else: | ||
token_ids = tf.ragged.constant([], dtype=tf.int32) | ||
|
||
return token_ids | ||
|
||
def vocabulary_size(self): | ||
"""Get the integer size of the tokenizer vocabulary.""" | ||
self._check_vocabulary() | ||
return len(self.vocabulary) | ||
|
||
def _check_vocabulary(self): | ||
if self.vocabulary is None: | ||
raise ValueError( | ||
"No vocabulary has been set for PARSeqTokenizer. Make sure " | ||
"to pass a `vocabulary` argument when creating the layer." | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.task import Task | ||
|
||
|
||
@keras_hub_export("keras_hub.models.TextRecognition") | ||
class TextRecognition(Task): | ||
"""Base class for all TextRecognition tasks. | ||
`TextRecognition` tasks wrap a `keras_hub.models.Task` and | ||
a `keras_hub.models.Preprocessor` to create a model that can be used for | ||
recognizing text in images. | ||
All `TextRecognition` tasks include a `from_preset()` constructor which can | ||
be used to load a pre-trained config and weights. | ||
""" | ||
|
||
def compile( | ||
self, | ||
optimizer="auto", | ||
loss="auto", | ||
*, | ||
metrics="auto", | ||
**kwargs, | ||
): | ||
"""Configures the `ImageOCR` task for training. | ||
The `ImageOCR` task extends the default compilation signature of | ||
`keras.Model.compile` with defaults for `optimizer`, `loss`, and | ||
`metrics`. To override these defaults, pass any value | ||
to these arguments during compilation. | ||
Args: | ||
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` | ||
instance. Defaults to `"auto"`, which uses the default optimizer | ||
for the given model and task. See `keras.Model.compile` and | ||
`keras.optimizers` for more info on possible `optimizer` values. | ||
loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. | ||
Defaults to `"auto"`, where a | ||
`keras.losses.SparseCategoricalCrossentropy` loss will be | ||
applied for the classification task. See | ||
`keras.Model.compile` and `keras.losses` for more info on | ||
possible `loss` values. | ||
metrics: `"auto"`, or a list of metrics to be evaluated by | ||
the model during training and testing. Defaults to `"auto"`, | ||
where a `keras.metrics.SparseCategoricalAccuracy` will be | ||
applied to track the accuracy of the model during training. | ||
See `keras.Model.compile` and `keras.metrics` for | ||
more info on possible `metrics` values. | ||
**kwargs: See `keras.Model.compile` for a full list of arguments | ||
supported by the compile method. | ||
""" | ||
if optimizer == "auto": | ||
optimizer = keras.optimizers.Adam(1e-4) | ||
if loss == "auto": | ||
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False) | ||
if metrics == "auto": | ||
metrics = [keras.metrics.SparseCategoricalAccuracy()] | ||
super().compile( | ||
optimizer=optimizer, | ||
loss=loss, | ||
metrics=metrics, | ||
**kwargs, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.preprocessor import Preprocessor | ||
from keras_hub.src.utils.tensor_utils import preprocessing_function | ||
|
||
|
||
@keras_hub_export("keras_hub.models.TextRecognitionPreprocessor") | ||
class TextRecognitionPreprocessor(Preprocessor): | ||
"""Base class for image segmentation preprocessing layers. | ||
|
||
`TextRecognitionPreprocessor` wraps a | ||
`keras_hub.layers.ImageConverter` to create a preprocessing layer for | ||
text recognition tasks. It is intended to be paired with a | ||
`keras_hub.models.TextRecognition` task. | ||
|
||
All `TextRecognitionPreprocessor` instances take three inputs: `x`, `y`, and | ||
`sample_weight`. | ||
|
||
- `x`: The first input, should always be included. It can be an image or | ||
a batch of images. | ||
- `y`: (Optional) text representation of the letters/words from image. | ||
- `sample_weight`: (Optional) Will be passed through unaltered. | ||
|
||
The layer will output either `x`, an `(x, y)` tuple if labels were provided, | ||
or an `(x, y, sample_weight)` tuple if labels and sample weight were | ||
provided. `x` will be the input images after all model preprocessing has | ||
been applied. | ||
|
||
All `TextRecognitionPreprocessor` tasks include a `from_preset()` | ||
constructor which can be used to load a pre-trained config. | ||
You can call the `from_preset()` constructor directly on this base class, in | ||
which case the correct class for your model will be automatically | ||
instantiated. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
image_converter=None, | ||
tokenizer=None, | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
self.image_converter = image_converter | ||
self.tokenizer = tokenizer | ||
|
||
@preprocessing_function | ||
def call(self, x, y=None, sample_weight=None): | ||
if self.image_converter: | ||
x = self.image_converter(x) | ||
if y is not None: | ||
y = self.tokenizer(y) | ||
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @mattdangerw,
When you get some time can you take look at this tokenizer, the original implementation is here:
Preprocess for dataset: https://github.com/baudm/parseq/blob/main/strhub/data/dataset.py
Tokenizer: https://github.com/baudm/parseq/blob/main/strhub/data/utils.py#L102
I followed roughly other tokenizers in kerashub and wrote this, we can further discussion on what changes are required.
Thanks!