Skip to content

Commit

Permalink
Add the Segment Anything Model to KerasCV (keras-team#1987)
Browse files Browse the repository at this point in the history
* Start adding components for the segment anything model

* SAMLayerNormalization -> keras.layers.LayerNormalization

They both behave exactly the same when moving_mean and moving_variance are None and epsilon is 1e-6

* Move the image encoder to detectron2 backbone and fix for tf.keras backend

* Address review comments and address saving bug

- Use `keras_cv.export_api.keras_cv_export` instead of `keras.saving.register_keras_serializable`.
- Add a `SerializableSequential` class to address the saving bug with the `Sequential` model.
- Push the helper functions in `keras_cv/layers/detectron2_layers.py` to the bottom of the file.
- Add the detectron2 layers to the `keras_cv/layers/__init__.py` file.
- Add a test for the `ViTDetPatchingAndEmbedding` layer.

* Make the backbone functional; unite MLP and MLPBlock

* Address David's review comments

* Add SAM Task model; make MaskDecoder and PromptEncoder XLA compatible

* Remove a stray file

* Add docs for the Task model

* Add more references

[skip ci]

* Remove SerializableSequential layer

* detectron2 -> vit_det; add SAM presets; fix ViTDet presets

* Increse test tolerence for GCB Run
  • Loading branch information
tirthasheshpatel authored Sep 19, 2023
1 parent 6ce4365 commit c90fa35
Show file tree
Hide file tree
Showing 25 changed files with 3,153 additions and 3 deletions.
1 change: 1 addition & 0 deletions keras_cv/backend/tf_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from keras_core.src.backend.tensorflow.numpy import * # noqa: F403, F401

# Some TF APIs where the numpy API doesn't support raggeds that we need
from tensorflow import broadcast_to # noqa: F403, F401
from tensorflow import concat as concatenate # noqa: F403, F401
from tensorflow import range as arange # noqa: F403, F401
from tensorflow import reduce_all as all # noqa: F403, F401
Expand Down
5 changes: 5 additions & 0 deletions keras_cv/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,9 @@
)
from keras_cv.layers.spatial_pyramid import SpatialPyramidPooling
from keras_cv.layers.transformer_encoder import TransformerEncoder
from keras_cv.layers.vit_det_layers import AddRelativePositionalEmbedding
from keras_cv.layers.vit_det_layers import MultiHeadAttentionWithRelativePE
from keras_cv.layers.vit_det_layers import ViTDetPatchingAndEmbedding
from keras_cv.layers.vit_det_layers import WindowedTransformerEncoder
from keras_cv.layers.vit_det_layers import WindowPartitioning
from keras_cv.layers.vit_layers import PatchingAndEmbedding
Loading

0 comments on commit c90fa35

Please sign in to comment.