Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PARSeq Model #2089

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_hub.src.models.parseq.parseq_image_converter import (
PARSeqImageConverter,
)
from keras_hub.src.models.resnet.resnet_image_converter import (
ResNetImageConverter,
)
Expand Down
7 changes: 7 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
from keras_hub.src.models.parseq.parseq_preprocessor import PARSeqPreprocessor
from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer
from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import (
Expand Down Expand Up @@ -336,6 +339,10 @@
from keras_hub.src.models.text_classifier_preprocessor import (
TextClassifierPreprocessor,
)
from keras_hub.src.models.text_recognition import TextRecognition
from keras_hub.src.models.text_recognition_preprocessor import (
TextRecognitionPreprocessor,
)
from keras_hub.src.models.text_to_image import TextToImage
from keras_hub.src.models.vgg.vgg_backbone import VGGBackbone
from keras_hub.src.models.vgg.vgg_image_classifier import VGGImageClassifier
Expand Down
1 change: 1 addition & 0 deletions keras_hub/api/tokenizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)
from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer
from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_hub.src.models.roberta.roberta_tokenizer import RobertaTokenizer
from keras_hub.src.models.t5.t5_tokenizer import T5Tokenizer
Expand Down
Empty file.
62 changes: 62 additions & 0 deletions keras_hub/src/models/parseq/parseq_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.PARSeqBackbone")
class PARSeqBackbone(Backbone):
"""Scene Text Detection with PARSeq.
Performs OCR in natural scenes using the PARSeq model described in [Scene
Text Recognition with Permuted Autoregressive Sequence Models](
https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
iterative decoding by performing an autoregressive decoding phase, followed
by a refinement phase.
"""

def __init__(
self,
image_encoder,
decode_autoregressive=True,
alphabet_size=97,
max_text_length=25,
num_decoder_layers=1,
num_decoder_heads=12,
dropout_rate=0.1,
dtype=None,
**kwargs,
):
# === Layers ===
self.image_encoder = image_encoder

image_input = self.image_encoder.input
output = self.image_encoder(image_input)

# === Config ===
self.decode_autoregressive = decode_autoregressive
self.alphabet_size = alphabet_size
self.max_text_length = max_text_length
self.num_decoder_layers = num_decoder_layers
self.num_decoder_heads = num_decoder_heads
self.dropout_rate = dropout_rate

super().__init__(
inputs=image_input,
outputs=output,
dtype=dtype,
**kwargs,
)

def get_config(self):
config = super().get_config()
config.update(
{
"encoder": keras.layers.serialize(self.image_encoder),
"alphabet_size": self.alphabet_size,
"max_text_length": self.max_text_length,
"num_decoder_layers": self.num_decoder_layers,
"num_decoder_heads": self.num_decoder_heads,
"dropout_rate": self.dropout_rate,
}
)
Empty file.
8 changes: 8 additions & 0 deletions keras_hub/src/models/parseq/parseq_image_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone


@keras_hub_export("keras_hub.layers.PARSeqImageConverter")
class PARSeqImageConverter(ImageConverter):
backbone_cls = PARSeqBackbone
14 changes: 14 additions & 0 deletions keras_hub/src/models/parseq/parseq_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
from keras_hub.src.models.parseq.parseq_image_converter import (
PARSeqImageConverter,
)
from keras_hub.src.models.text_recognition_preprocessor import (
TextRecognitionPreprocessor,
)


@keras_hub_export("keras_hub.models.PARSeqPreprocessor")
class PARSeqPreprocessor(TextRecognitionPreprocessor):
backbone_cls = PARSeqBackbone
image_converter_cls = PARSeqImageConverter
142 changes: 142 additions & 0 deletions keras_hub/src/models/parseq/parseq_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import re

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.tokenizers import tokenizer
from keras_hub.src.utils.tensor_utils import is_int_dtype
from keras_hub.src.utils.tensor_utils import is_string_dtype
from keras_hub.src.utils.tensor_utils import preprocessing_function

PARSEQ_VOCAB = (
"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"
"\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
)

try:
import tensorflow as tf
import tensorflow_text as tf_text
except ImportError:
tf = None
tf_text = None


@keras_hub_export(
[
"keras_hub.tokenizers.PARSeqTokenizer",
"keras_hub.models.PARSeqTokenizer",
]
)
class PARSeqTokenizer(tokenizer.Tokenizer):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @mattdangerw,

When you get some time can you take look at this tokenizer, the original implementation is here:

Preprocess for dataset: https://github.com/baudm/parseq/blob/main/strhub/data/dataset.py
Tokenizer: https://github.com/baudm/parseq/blob/main/strhub/data/utils.py#L102

I followed roughly other tokenizers in kerashub and wrote this, we can further discussion on what changes are required.

Thanks!

def __init__(
self,
vocabulary=PARSEQ_VOCAB,
remove_whitespace=True,
normalize_unicode=True,
max_label_length=25,
dtype="int32",
**kwargs,
):
if not is_int_dtype(dtype) and not is_string_dtype(dtype):
raise ValueError(
"Output dtype must be an integer type or a string. "
f"Received: dtype={dtype}"
)
super().__init__(dtype=dtype, **kwargs)
self.vocabulary = vocabulary
self.target_charset = tf.convert_to_tensor(vocabulary, dtype=tf.string)
self.lowercase_only = self.target_charset == tf.strings.lower(
self.target_charset
)
self.uppercase_only = self.target_charset == tf.strings.upper(
self.target_charset
)
escaped_charset = re.escape(vocabulary) # Escape for safe regex
self.unsupported_regex = f"[^{escaped_charset}]"
self._itos = ("[E]",) + tuple(vocabulary) + ("[B]", "[P]")
self._stoi = {s: i for i, s in enumerate(self._itos)}

self.remove_whitespace = remove_whitespace
self.normalize_unicode = normalize_unicode
self.max_label_length = max_label_length
self._add_special_token("[B]", "start_token")
self._add_special_token("[E]", "end_token")
self._add_special_token("[P]", "pad_token")
# Create lookup tables.
self.char_to_id = tf.lookup.StaticHashTable(
initializer=tf.lookup.KeyValueTensorInitializer(
keys=list(self._stoi.keys()),
values=list(self._stoi.values()),
key_dtype=tf.string,
value_dtype=tf.int32,
),
default_value=0,
)
self.id_to_char = tf.lookup.StaticHashTable(
initializer=tf.lookup.KeyValueTensorInitializer(
keys=list(self._stoi.values()),
values=list(self._stoi.keys()),
key_dtype=tf.int32,
value_dtype=tf.string,
),
default_value=self.pad_token,
)

def id_to_token(self, id):
if id >= self.vocabulary_size() or id < 0:
raise ValueError(
f"`id` must be in range [0, {self.vocabulary_size() - 1}]. "
f"Received: {id}"
)
return self._itos[id]

def token_to_id(self, token):
return self._stoi[token]

def _preprocess(self, label):
"""Performs preprocessing include only characters from ASCII."""
if self.remove_whitespace:
label = tf.strings.regex_replace(label, r"\s+", "")

if self.normalize_unicode:
label = tf_text.normalize_utf8(label, normalization_form="NFKD")
label = tf.strings.regex_replace(label, r"[^!-~]", "")

if self.lowercase_only:
label = tf.strings.lower(label)
elif self.uppercase_only:
label = tf.strings.upper(label)

label = tf.strings.regex_replace(label, self.unsupported_regex, "")
label = tf.strings.substr(label, 0, self.max_label_length)

return label

@preprocessing_function
def tokenize(self, inputs):
self._check_vocabulary()
inputs = tf.convert_to_tensor(inputs)
unbatched = inputs.shape.rank == 0
if unbatched:
inputs = tf.expand_dims(inputs, 0)

inputs = tf.map_fn(self._preprocess, inputs, dtype=tf.string)

if tf.size(inputs) > 0:
chars = tf.strings.unicode_split(inputs, "UTF-8")
token_ids = self.char_to_id.lookup(chars)
token_ids = tf.cast(token_ids, dtype=tf.int32)
else:
token_ids = tf.ragged.constant([], dtype=tf.int32)

return token_ids

def vocabulary_size(self):
"""Get the integer size of the tokenizer vocabulary."""
self._check_vocabulary()
return len(self.vocabulary)

def _check_vocabulary(self):
if self.vocabulary is None:
raise ValueError(
"No vocabulary has been set for PARSeqTokenizer. Make sure "
"to pass a `vocabulary` argument when creating the layer."
)
65 changes: 65 additions & 0 deletions keras_hub/src/models/text_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task


@keras_hub_export("keras_hub.models.TextRecognition")
class TextRecognition(Task):
"""Base class for all TextRecognition tasks.
`TextRecognition` tasks wrap a `keras_hub.models.Task` and
a `keras_hub.models.Preprocessor` to create a model that can be used for
recognizing text in images.
All `TextRecognition` tasks include a `from_preset()` constructor which can
be used to load a pre-trained config and weights.
"""

def compile(
self,
optimizer="auto",
loss="auto",
*,
metrics="auto",
**kwargs,
):
"""Configures the `ImageOCR` task for training.
The `ImageOCR` task extends the default compilation signature of
`keras.Model.compile` with defaults for `optimizer`, `loss`, and
`metrics`. To override these defaults, pass any value
to these arguments during compilation.
Args:
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
instance. Defaults to `"auto"`, which uses the default optimizer
for the given model and task. See `keras.Model.compile` and
`keras.optimizers` for more info on possible `optimizer` values.
loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
Defaults to `"auto"`, where a
`keras.losses.SparseCategoricalCrossentropy` loss will be
applied for the classification task. See
`keras.Model.compile` and `keras.losses` for more info on
possible `loss` values.
metrics: `"auto"`, or a list of metrics to be evaluated by
the model during training and testing. Defaults to `"auto"`,
where a `keras.metrics.SparseCategoricalAccuracy` will be
applied to track the accuracy of the model during training.
See `keras.Model.compile` and `keras.metrics` for
more info on possible `metrics` values.
**kwargs: See `keras.Model.compile` for a full list of arguments
supported by the compile method.
"""
if optimizer == "auto":
optimizer = keras.optimizers.Adam(1e-4)
if loss == "auto":
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
if metrics == "auto":
metrics = [keras.metrics.SparseCategoricalAccuracy()]
super().compile(
optimizer=optimizer,
loss=loss,
metrics=metrics,
**kwargs,
)
Loading
Loading