Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the PARSeq model for Scene Text Recognition #2036

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_hub.src.models.parseq.parseq_image_converter import (
PARSeqImageConverter,
)
from keras_hub.src.models.resnet.resnet_image_converter import (
ResNetImageConverter,
)
Expand Down
4 changes: 4 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@
from keras_hub.src.models.image_object_detector_preprocessor import (
ImageObjectDetectorPreprocessor,
)
from keras_hub.src.models.image_ocr import ImageOCR
from keras_hub.src.models.image_segmenter import ImageSegmenter
from keras_hub.src.models.image_segmenter_preprocessor import (
ImageSegmenterPreprocessor,
Expand Down Expand Up @@ -250,6 +251,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
from keras_hub.src.models.parseq.parseq_ocr import PARSeqImageOCR
from keras_hub.src.models.parseq.parseq_preprocessor import PARSeqPreprocessor
from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import (
Expand Down
65 changes: 65 additions & 0 deletions keras_hub/src/models/image_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.task import Task


@keras_hub_export("keras_hub.models.ImageOCR")
class ImageOCR(Task):
    """Base class for all OCR tasks.

    An `ImageOCR` task is a `Task` subclass used to build models that
    recognize text in images. It is meant to be paired with a
    `keras_hub.models.Preprocessor`.

    All `ImageOCR` tasks include a `from_preset()` constructor which can
    be used to load a pre-trained config and weights.
    """

    # NOTE(review): the original docstring said these tasks wrap a
    # `keras_hub.models.Task`; that is presumably a slip for a backbone
    # model -- confirm against the other task base classes.

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        *,
        metrics="auto",
        **kwargs,
    ):
        """Configures the `ImageOCR` task for training.

        This extends the `keras.Model.compile` signature with `"auto"`
        defaults for `optimizer`, `loss`, and `metrics`. Any value other
        than `"auto"` is forwarded to the parent `compile` untouched.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which resolves to
                `keras.optimizers.Adam(1e-4)`. See `keras.Model.compile`
                and `keras.optimizers` for more info on possible
                `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, which resolves to a
                `keras.losses.SparseCategoricalCrossentropy` built with
                `from_logits=False`. See `keras.Model.compile` and
                `keras.losses` for more info on possible `loss` values.
            metrics: `"auto"`, or a list of metrics to be evaluated by
                the model during training and testing. Defaults to
                `"auto"`, which resolves to a single-element list holding
                a `keras.metrics.SparseCategoricalAccuracy`. See
                `keras.Model.compile` and `keras.metrics` for more info on
                possible `metrics` values.
            **kwargs: See `keras.Model.compile` for a full list of
                arguments supported by the compile method.
        """
        # Each "auto" sentinel is swapped for its concrete default; the
        # default objects are only constructed when actually needed.
        chosen_optimizer = (
            keras.optimizers.Adam(1e-4) if optimizer == "auto" else optimizer
        )
        chosen_loss = (
            keras.losses.SparseCategoricalCrossentropy(from_logits=False)
            if loss == "auto"
            else loss
        )
        chosen_metrics = (
            [keras.metrics.SparseCategoricalAccuracy()]
            if metrics == "auto"
            else metrics
        )
        super().compile(
            optimizer=chosen_optimizer,
            loss=chosen_loss,
            metrics=chosen_metrics,
            **kwargs,
        )
Empty file.
Loading
Loading