Add RepVGG (#32)
* Add `RepVGG`

* Update tests

* Add test for `get_reparameterized_model`

* Update README

* Minor update

* Fix test

* Fix readme

* Fix readme
james77777778 authored Jan 29, 2024
1 parent 219fe28 commit 69a24f2
Showing 14 changed files with 1,214 additions and 64 deletions.
120 changes: 64 additions & 56 deletions README.md
@@ -17,66 +17,73 @@

**K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner.

KIMM is:

🚀 A model zoo where almost all models come with pre-trained weights on ImageNet.

> [!NOTE]
> The accuracy of the converted models can be found at [results-imagenet.csv (timm)](https://github.com/huggingface/pytorch-image-models/blob/main/results/results-imagenet.csv) and [https://keras.io/api/applications/ (keras)](https://keras.io/api/applications/),
> and the numerical differences of the converted models can be verified in `tools/convert_*.py`.
✨ Exposing a common API identical to official `keras.applications.*`.

```python
model = kimm.models.RegNetY002(
    input_tensor: keras.KerasTensor = None,
    input_shape: typing.Optional[typing.Sequence[int]] = None,
    include_preprocessing: bool = True,
    include_top: bool = True,
    pooling: typing.Optional[str] = None,
    dropout_rate: float = 0.0,
    classes: int = 1000,
    classifier_activation: str = "softmax",
    weights: typing.Optional[str] = "imagenet",
    name: str = "RegNetY002",
)
```
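
A minimal usage sketch of that API (an illustration, not from the README; it assumes the default `weights="imagenet"` download is available):

```python
import keras
import kimm

# Build with defaults: pretrained ImageNet weights, top classifier included.
model = kimm.models.RegNetY002()
x = keras.random.uniform([1, 224, 224, 3])  # stand-in for a real image batch
y = model.predict(x)
print(y.shape)  # (1, 1000) class probabilities
```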

🔥 Integrated feature extraction capability.

```python
import keras
import kimm

model = kimm.models.ConvNeXtAtto(feature_extractor=True)
x = keras.random.uniform([1, 224, 224, 3])
y = model.predict(x)
# y is a dict mapping feature names to tensors
for k, v in y.items():
    print(k, v.shape)
```
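
A possible follow-up sketch for consuming the extracted features; the dict may also contain non-spatial outputs and the key names vary by model, so this filters by tensor rank rather than assuming specific keys:

```python
import keras
import kimm

model = kimm.models.ConvNeXtAtto(feature_extractor=True)
x = keras.random.uniform([1, 224, 224, 3])
features = model(x, training=False)
# Keep only the spatial (rank-4) feature maps and pool the deepest one,
# e.g. as input to a lightweight downstream head.
maps = [v for v in features.values() if len(v.shape) == 4]
pooled = keras.layers.GlobalAveragePooling2D()(maps[-1])
print(pooled.shape)
```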

🧰 Providing APIs to export models to `.tflite` and `.onnx`.

```python
# tensorflow backend
import keras
import kimm

keras.backend.set_image_data_format("channels_last")
model = kimm.models.MobileNet050V3Small()
kimm.export.export_tflite(model, [224, 224, 3], "model.tflite")
```

```python
# torch backend
import keras
import kimm

keras.backend.set_image_data_format("channels_first")
model = kimm.models.MobileNet050V3Small()
kimm.export.export_onnx(model, [3, 224, 224], "model.onnx")
```

> [!IMPORTANT]
> `kimm.export.export_tflite` is currently restricted to `tensorflow` backend and `channels_last`.
> `kimm.export.export_onnx` is currently restricted to `torch` backend and `channels_first`.
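
To sanity-check an export, the resulting file can be run with the corresponding runtime; for example, a sketch using the TensorFlow Lite interpreter (standard `tf.lite` API, not part of `kimm`):

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()
inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]
# Feed a dummy channels_last batch and read back the predictions.
interpreter.set_tensor(inp["index"], np.random.rand(1, 224, 224, 3).astype(np.float32))
interpreter.invoke()
print(interpreter.get_tensor(out["index"]).shape)  # e.g. (1, 1000)
```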

🔧 Supporting the reparameterization technique.

```python
import numpy as np
import keras
import kimm

model = kimm.models.RepVGGA0()
reparameterized_model = kimm.utils.get_reparameterized_model(model)
# or
# reparameterized_model = model.get_reparameterized_model()
x = keras.random.uniform([1, 224, 224, 3])
y1 = model.predict(x)
y2 = reparameterized_model.predict(x)
np.testing.assert_allclose(y1, y2, atol=1e-5)
```
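
As a quick check of what reparameterization buys, the fused model should expose a plainer topology, with each 3x3 + 1x1 + identity branch and its BatchNorm folded into a single convolution, e.g.:

```python
# Reusing `model` and `reparameterized_model` from above; exact layer
# counts depend on the RepVGG variant.
print(len(model.layers), len(reparameterized_model.layers))
```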

## Installation

@@ -175,6 +182,7 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io
|MobileViT|[ICLR 2022](https://arxiv.org/abs/2110.02178)|`timm`|`kimm.models.MobileViT*`|
|MobileViTV2|[arXiv 2022](https://arxiv.org/abs/2206.02680)|`timm`|`kimm.models.MobileViTV2*`|
|RegNet|[CVPR 2020](https://arxiv.org/abs/2003.13678)|`timm`|`kimm.models.RegNet*`|
|RepVGG|[CVPR 2021](https://arxiv.org/abs/2101.03697)|`timm`|`kimm.models.RepVGG*`|
|ResNet|[CVPR 2016](https://arxiv.org/abs/1512.03385)|`timm`|`kimm.models.ResNet*`|
|TinyNet|[NeurIPS 2020](https://arxiv.org/abs/2010.14819)|`timm`|`kimm.models.TinyNet*`|
|VGG|[ICLR 2015](https://arxiv.org/abs/1409.1556)|`timm`|`kimm.models.VGG*`|
1 change: 1 addition & 0 deletions kimm/layers/__init__.py
@@ -1,3 +1,4 @@
from kimm.layers.attention import Attention
from kimm.layers.layer_scale import LayerScale
from kimm.layers.position_embedding import PositionEmbedding
from kimm.layers.rep_conv2d import RepConv2D
2 changes: 1 addition & 1 deletion kimm/layers/attention.py
@@ -30,7 +30,7 @@ def __init__(
        self.qkv = layers.Dense(
            hidden_dim * 3,
            use_bias=use_qkv_bias,
            dtype=self.dtype_policy,
            name=f"{name}_qkv",
        )
        if use_qk_norm:
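
For context on this change: passing `dtype=self.dtype_policy` forwards the sublayer the full dtype policy (compute dtype plus variable dtype), whereas `self.dtype` is only the variable dtype string, so mixed precision would silently be dropped. A minimal sketch with a hypothetical layer (not from this repo):

```python
import keras

keras.config.set_dtype_policy("mixed_float16")

class MyBlock(keras.layers.Layer):
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        # With "mixed_float16", self.dtype == "float32" (variables only),
        # while self.dtype_policy also carries the float16 compute dtype.
        self.dense = keras.layers.Dense(units, dtype=self.dtype_policy)
```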