From 69a24f212b648eec73faca5642fa535ef2f32457 Mon Sep 17 00:00:00 2001
From: HongYu <20734616+james77777778@users.noreply.github.com>
Date: Mon, 29 Jan 2024 22:51:27 +0800
Subject: [PATCH] Add `RepVGG` (#32)

* Add `RepVGG`

* Update tests

* Add test for `get_reparameterized_model`

* Update README

* Minor update

* Fix test

* Fix readme

* Fix readme

---
 README.md                         | 120 ++++----
 kimm/layers/__init__.py           |   1 +
 kimm/layers/attention.py          |   2 +-
 kimm/layers/rep_conv2d.py         | 260 +++++++++++++++++
 kimm/layers/rep_conv2d_test.py    | 134 +++++++++
 kimm/models/__init__.py           |   1 +
 kimm/models/models_test.py        |  27 ++
 kimm/models/repvgg.py             | 463 ++++++++++++++++++++++++++++++
 kimm/utils/__init__.py            |   1 +
 kimm/utils/model_utils.py         |  30 ++
 kimm/utils/model_utils_test.py    |  48 ++++
 kimm/utils/timm_utils.py          |  28 +-
 shell/export.sh                   |   1 +
 tools/convert_repvgg_from_timm.py | 162 +++++++++++
 14 files changed, 1214 insertions(+), 64 deletions(-)
 create mode 100644 kimm/layers/rep_conv2d.py
 create mode 100644 kimm/layers/rep_conv2d_test.py
 create mode 100644 kimm/models/repvgg.py
 create mode 100644 kimm/utils/model_utils.py
 create mode 100644 kimm/utils/model_utils_test.py
 create mode 100644 tools/convert_repvgg_from_timm.py

diff --git a/README.md b/README.md
index bbfcbd1..a21774c 100644
--- a/README.md
+++ b/README.md
@@ -17,66 +17,73 @@

 **K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner.

-`kimm` is:
+KIMM is:

-- 🚀 A model zoo where almost all models come with pre-trained weights on ImageNet.
+🚀 A model zoo where almost all models come with pre-trained weights on ImageNet.

-  > **Note:**
-  > The accuracy of the converted models can be found at [results-imagenet.csv (timm)](https://github.com/huggingface/pytorch-image-models/blob/main/results/results-imagenet.csv) and [https://keras.io/api/applications/ (keras)](https://keras.io/api/applications/),
-  > and the numerical differences of the converted models can be verified in `tools/convert_*.py`
+> [!NOTE]
+> The accuracy of the converted models can be found at [results-imagenet.csv (timm)](https://github.com/huggingface/pytorch-image-models/blob/main/results/results-imagenet.csv) and [https://keras.io/api/applications/ (keras)](https://keras.io/api/applications/),
+> and the numerical differences of the converted models can be verified in `tools/convert_*.py`.

-- ✨ Exposing a common API identical to offcial `keras.applications.*`.
+✨ Exposing a common API identical to official `keras.applications.*`.

-  ```python
-  model = kimm.models.RegNetY002(
-      input_tensor: keras.KerasTensor = None,
-      input_shape: typing.Optional[typing.Sequence[int]] = None,
-      include_preprocessing: bool = True,
-      include_top: bool = True,
-      pooling: typing.Optional[str] = None,
-      dropout_rate: float = 0.0,
-      classes: int = 1000,
-      classifier_activation: str = "softmax",
-      weights: typing.Optional[str] = "imagenet",
-      name: str = "RegNetY002",
-  )
-  ```
-
-- 🔥 Integrated with feature extraction capability.
-
-  ```python
-  from keras import random
-  import kimm
-
-  model = kimm.models.ConvNeXtAtto(feature_extractor=True)
-  x = random.uniform([1, 224, 224, 3])
-  y = model(x, training=False)
-  # y becomes a dict
-  for k, v in y.items():
-      print(k, v.shape)
-  ```
-
-- 🧰 Providing APIs to export models to `.tflite` and `.onnx`.
- - ```python - # in tensorflow backend - from keras import backend - import kimm - - backend.set_image_data_format("channels_last") - model = kimm.models.MobileNet050V3Small() - kimm.export.export_tflite(model, [224, 224, 3], "model.tflite") - ``` - - ```python - # in torch backend - from keras import backend - import kimm - - backend.set_image_data_format("channels_first") - model = kimm.models.MobileNet050V3Small() - kimm.export.export_onnx(model, [3, 224, 224], "model.onnx") - ``` +```python +model = kimm.models.RegNetY002( + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "RegNetY002", +) +``` + +🔥 Integrated with feature extraction capability. + +```python +model = kimm.models.ConvNeXtAtto(feature_extractor=True) +x = keras.random.uniform([1, 224, 224, 3]) +y = model.predict(x) +# y becomes a dict +for k, v in y.items(): + print(k, v.shape) +``` + +🧰 Providing APIs to export models to `.tflite` and `.onnx`. + +```python +# tensorflow backend +keras.backend.set_image_data_format("channels_last") +model = kimm.models.MobileNet050V3Small() +kimm.export.export_tflite(model, [224, 224, 3], "model.tflite") +``` + +```python +# torch backend +keras.backend.set_image_data_format("channels_first") +model = kimm.models.MobileNet050V3Small() +kimm.export.export_onnx(model, [3, 224, 224], "model.onnx") +``` + +> [!IMPORTANT] +> `kimm.export.export_tflite` is currently restricted to `tensorflow` backend and `channels_last`. +> `kimm.export.export_onnx` is currently restricted to `torch` backend and `channels_first`. + +🔧 Supporting the reparameterization technique. 
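+The technique folds each `RepConv2D` block's parallel branches (kxk conv, 1x1 conv and identity) into a single convolution at inference time, so the fused model runs faster yet matches the original within numerical tolerance, as the snippet below verifies (it assumes `kimm`, `keras` and `numpy as np` are imported):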
+
+```python
+model = kimm.models.RepVGGA0()
+reparameterized_model = kimm.utils.get_reparameterized_model(model)
+# or
+# reparameterized_model = model.get_reparameterized_model()
+x = keras.random.uniform([1, 224, 224, 3])
+y1 = model.predict(x)
+y2 = reparameterized_model.predict(x)
+np.testing.assert_allclose(y1, y2, atol=1e-5)
+```

 ## Installation

@@ -175,6 +182,7 @@ Reference: [Grad-CAM class activation visualization (keras.io)](https://keras.io
 |MobileViT|[ICLR 2022](https://arxiv.org/abs/2110.02178)|`timm`|`kimm.models.MobileViT*`|
 |MobileViTV2|[arXiv 2022](https://arxiv.org/abs/2206.02680)|`timm`|`kimm.models.MobileViTV2*`|
 |RegNet|[CVPR 2020](https://arxiv.org/abs/2003.13678)|`timm`|`kimm.models.RegNet*`|
+|RepVGG|[CVPR 2021](https://arxiv.org/abs/2101.03697)|`timm`|`kimm.models.RepVGG*`|
 |ResNet|[CVPR 2015](https://arxiv.org/abs/1512.03385)|`timm`|`kimm.models.ResNet*`|
 |TinyNet|[NeurIPS 2020](https://arxiv.org/abs/2010.14819)|`timm`|`kimm.models.TinyNet*`|
 |VGG|[ICLR 2015](https://arxiv.org/abs/1409.1556)|`timm`|`kimm.models.VGG*`|
diff --git a/kimm/layers/__init__.py b/kimm/layers/__init__.py
index b495002..577f0aa 100644
--- a/kimm/layers/__init__.py
+++ b/kimm/layers/__init__.py
@@ -1,3 +1,4 @@
 from kimm.layers.attention import Attention
 from kimm.layers.layer_scale import LayerScale
 from kimm.layers.position_embedding import PositionEmbedding
+from kimm.layers.rep_conv2d import RepConv2D
diff --git a/kimm/layers/attention.py b/kimm/layers/attention.py
index 6b58cfa..d610557 100644
--- a/kimm/layers/attention.py
+++ b/kimm/layers/attention.py
@@ -30,7 +30,7 @@ def __init__(
         self.qkv = layers.Dense(
             hidden_dim * 3,
             use_bias=use_qkv_bias,
-            dtype=self.dtype,
+            dtype=self.dtype_policy,
             name=f"{name}_qkv",
         )
         if use_qk_norm:
diff --git a/kimm/layers/rep_conv2d.py b/kimm/layers/rep_conv2d.py
new file mode 100644
index 0000000..92faef2
--- /dev/null
+++ b/kimm/layers/rep_conv2d.py
@@ -0,0 +1,260 @@
+import keras
+import numpy as np
+from keras import Sequential
+from keras import layers
+from keras import ops
+from keras.src.backend import standardize_data_format
+from keras.src.layers import Layer
+from keras.src.utils.argument_validation import standardize_tuple
+
+
+@keras.saving.register_keras_serializable(package="kimm")
+class RepConv2D(Layer):
+    def __init__(
+        self,
+        filters,
+        kernel_size,
+        strides=(1, 1),
+        padding=None,
+        has_skip: bool = True,
+        reparameterized: bool = False,
+        data_format=None,
+        activation=None,
+        name="rep_conv2d",
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.filters = filters
+        self.kernel_size = standardize_tuple(kernel_size, 2, "kernel_size")
+        self.strides = standardize_tuple(strides, 2, "strides")
+        self.padding = padding
+        self.has_skip = has_skip
+        self._reparameterized = reparameterized
+        self.data_format = standardize_data_format(data_format)
+        self.activation = activation
+        self.name = name
+
+        if has_skip is True and (self.strides[0] != 1 or self.strides[1] != 1):
+            raise ValueError(
+                "strides must be `1` when `has_skip=True`. "
+                f"Received: has_skip={has_skip}, strides={strides}"
+            )
+
+        self.zero_padding = layers.Identity(dtype=self.dtype_policy)
+        if padding is None:
+            padding = "same"
+            if self.strides[0] > 1:
+                padding = "valid"
+                self.zero_padding = layers.ZeroPadding2D(
+                    (self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+                    data_format=self.data_format,
+                    dtype=self.dtype_policy,
+                    name=f"{name}_pad",
+                )
+            self.padding = padding
+        else:
+            self.padding = padding
+
+        channel_axis = -1 if self.data_format == "channels_last" else -3
+        if self._reparameterized:
+            self.rep_conv2d = layers.Conv2D(
+                filters,
+                kernel_size,
+                strides,
+                padding,
+                data_format=self.data_format,
+                use_bias=True,
+                dtype=self.dtype_policy,
+                name=f"{name}_reparam_conv",
+            )
+            self.identity = None
+            self.conv_kxk = None
+            self.conv_1x1 = None
+        else:
+            self.rep_conv2d = None
+            if self.has_skip:
+                self.identity = layers.BatchNormalization(
+                    axis=channel_axis,
+                    momentum=0.9,
+                    epsilon=1e-5,
+                    dtype=self.dtype_policy,
+                    name=f"{name}_identity",
+                )
+            else:
+                self.identity = None
+            self.conv_kxk = Sequential(
+                [
+                    layers.Conv2D(
+                        filters,
+                        kernel_size,
+                        strides,
+                        padding=self.padding,
+                        data_format=self.data_format,
+                        use_bias=False,
+                        dtype=self.dtype_policy,
+                    ),
+                    layers.BatchNormalization(
+                        axis=channel_axis,
+                        momentum=0.9,
+                        epsilon=1e-5,
+                        dtype=self.dtype_policy,
+                    ),
+                ],
+                name=f"{name}_conv_kxk",
+            )
+            self.conv_1x1 = Sequential(
+                [
+                    layers.Conv2D(
+                        filters,
+                        1,
+                        strides,
+                        padding=self.padding,
+                        data_format=self.data_format,
+                        use_bias=False,
+                        dtype=self.dtype_policy,
+                    ),
+                    layers.BatchNormalization(
+                        axis=channel_axis,
+                        momentum=0.9,
+                        epsilon=1e-5,
+                        dtype=self.dtype_policy,
+                    ),
+                ],
+                name=f"{name}_conv_1x1",
+            )
+
+        if activation is None:
+            self.act = layers.Identity(dtype=self.dtype_policy)
+        else:
+            self.act = layers.Activation(activation, dtype=self.dtype_policy)
+
+        # attach extra layers
+        self.extra_layers = []
+        if self.rep_conv2d is not None:
+            self.extra_layers.append(self.rep_conv2d)
+        if self.identity is not None:
+            self.extra_layers.append(self.identity)
+        if self.conv_kxk is not None:
+            self.extra_layers.append(self.conv_kxk)
+        if self.conv_1x1 is not None:
+            self.extra_layers.append(self.conv_1x1)
+        self.extra_layers.append(self.act)
+
+    def build(self, input_shape):
+        if isinstance(self.zero_padding, layers.ZeroPadding2D):
+            padded_shape = self.zero_padding.compute_output_shape(input_shape)
+        else:
+            padded_shape = input_shape
+
+        if self.rep_conv2d is not None:
+            self.rep_conv2d.build(padded_shape)
+        if self.identity is not None:
+            self.identity.build(input_shape)
+        if self.conv_kxk is not None:
+            self.conv_kxk.build(padded_shape)
+        if self.conv_1x1 is not None:
+            self.conv_1x1.build(input_shape)
+
+        self.built = True
+
+    def call(self, inputs, **kwargs):
+        x = ops.cast(inputs, self.compute_dtype)
+        padded_x = self.zero_padding(x)
+
+        # Deploy mode
+        if self._reparameterized:
+            return self.act(self.rep_conv2d(padded_x, **kwargs))
+
+        if self.identity is None:
+            x = self.conv_1x1(x, **kwargs) + self.conv_kxk(padded_x, **kwargs)
+        else:
+            identity = self.identity(x, **kwargs)
+            x = self.conv_1x1(x, **kwargs) + self.conv_kxk(padded_x, **kwargs)
+            x = x + identity
+        return self.act(x)
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "filters": self.filters,
+                "kernel_size": self.kernel_size,
+                "strides": self.strides,
+                "padding": self.padding,
+                "has_skip": self.has_skip,
+                "reparameterized": self._reparameterized,
+                "data_format": self.data_format,
+                "activation": self.activation,
+                "name": self.name,
+            }
+        )
+        return config
+
+    def _get_reparameterized_weights_from_layer(self, layer):
+        channel_axis = -1 if self.data_format == "channels_last" else -3
+        if isinstance(layer, Sequential):
+            if not isinstance(layer.layers[0], layers.Conv2D):
+                raise ValueError(
+                    "The first layer must be a `Conv2D`. "
+                    f"Received: {layer.layers[0]}"
+                )
+            if not isinstance(layer.layers[1], layers.BatchNormalization):
+                raise ValueError(
+                    "The second layer must be a `BatchNormalization`. "
+                    f"Received: {layer.layers[1]}"
+                )
+            kernel = ops.convert_to_numpy(layer.layers[0].kernel)
+            gamma = ops.convert_to_numpy(layer.layers[1].gamma)
+            beta = ops.convert_to_numpy(layer.layers[1].beta)
+            running_mean = ops.convert_to_numpy(layer.layers[1].moving_mean)
+            running_var = ops.convert_to_numpy(layer.layers[1].moving_variance)
+            eps = layer.layers[1].epsilon
+        elif isinstance(layer, layers.BatchNormalization):
+            # calculate identity tensor
+            in_chs = self.conv_kxk.layers[0].input.shape[channel_axis]
+            kernel_size = self.conv_kxk.layers[0].kernel_size
+            kernel_value = ops.convert_to_numpy(
+                ops.zeros_like(self.conv_kxk.layers[0].kernel)
+            )
+            kernel_value = kernel_value.copy()
+            for i in range(in_chs):
+                kernel_value[kernel_size[0] // 2, kernel_size[1] // 2, i, i] = 1
+            kernel = kernel_value
+            gamma = ops.convert_to_numpy(layer.gamma)
+            beta = ops.convert_to_numpy(layer.beta)
+            running_mean = ops.convert_to_numpy(layer.moving_mean)
+            running_var = ops.convert_to_numpy(layer.moving_variance)
+            eps = layer.epsilon
+
+        # use float64 for better precision
+        kernel = kernel.astype("float64")
+        gamma = gamma.astype("float64")
+        beta = beta.astype("float64")
+        running_mean = running_mean.astype("float64")
+        running_var = running_var.astype("float64")
+
+        # fold the batch norm into the conv:
+        # gamma * (conv(x) - mean) / std + beta
+        #     == conv'(x) + (beta - mean * gamma / std)
+        # where conv'.kernel == kernel * gamma / std
+        std = np.sqrt(running_var + eps)
+        t = np.reshape(gamma / std, [1, 1, 1, -1])
+        return kernel * t, beta - running_mean * gamma / std
+
+    def get_reparameterized_weights(self):
+        kernel_1x1 = 0.0
+        bias_1x1 = 0.0
+        if self.conv_1x1 is not None:
+            kernel_1x1, bias_1x1 = self._get_reparameterized_weights_from_layer(
+                self.conv_1x1
+            )
+            # pad the 1x1 kernel to the kxk spatial size before summing
+            pad = self.conv_kxk.layers[0].kernel_size[0] // 2
+            kernel_1x1 = np.pad(
+                kernel_1x1, [[pad, pad], [pad, pad], [0, 0], [0, 0]]
+            )
+
+        kernel_identity = 0.0
+        bias_identity = 0.0
+        if self.identity is not None:
+            (
+                kernel_identity,
+                bias_identity,
+            ) = self._get_reparameterized_weights_from_layer(self.identity)
+
+        kernel_conv, bias_conv = self._get_reparameterized_weights_from_layer(
+            self.conv_kxk
+        )
+
+        kernel_final = kernel_conv + kernel_1x1 + kernel_identity
+        bias_final = bias_conv + bias_1x1 + bias_identity
+        return kernel_final, bias_final
diff --git a/kimm/layers/rep_conv2d_test.py b/kimm/layers/rep_conv2d_test.py
new file mode 100644
index 0000000..ca6c8df
--- /dev/null
+++ b/kimm/layers/rep_conv2d_test.py
@@ -0,0 +1,134 @@
+import pytest
+from absl.testing import parameterized
+from keras import backend
+from keras import random
+from keras.src import testing
+
+from kimm.layers.rep_conv2d import RepConv2D
+
+TEST_CASES = [
+    {
+        "filters": 16,
+        "kernel_size": 3,
+        "has_skip": True,
+        "data_format": "channels_last",
+        "input_shape": (1, 4, 4, 16),
+        "output_shape": (1, 4, 4, 16),
+        "num_trainable_weights": 8,
+        "num_non_trainable_weights": 6,
+    },
+    {
+        "filters": 16,
+        "kernel_size": 3,
+        "has_skip": False,
+        "data_format": "channels_last",
+        "input_shape": (1, 4, 4, 8),
+        "output_shape": (1, 4, 4, 16),
+        "num_trainable_weights": 6,
+        "num_non_trainable_weights": 4,
+    },
+    {
+        "filters": 16,
+        "kernel_size": 5,
+        "has_skip": True,
+        "data_format": "channels_last",
+        "input_shape": (1, 4, 4, 16),
+        "output_shape": (1, 4, 4, 16),
+        "num_trainable_weights": 8,
+        "num_non_trainable_weights": 6,
+    },
+    {
+        "filters": 16,
+        "kernel_size": 3,
+        "has_skip": True,
+        "data_format": "channels_first",
+        "input_shape": (1, 16, 4, 4),
+        "output_shape": (1, 16, 4, 4),
+        "num_trainable_weights": 8,
+        "num_non_trainable_weights": 6,
+    },
+]
+
+
+class RepConv2DTest(testing.TestCase, parameterized.TestCase):
+    @parameterized.parameters(TEST_CASES)
+    @pytest.mark.requires_trainable_backend
+    def test_rep_conv2d_basic(
+        self,
+        filters,
+        kernel_size,
+        has_skip,
+        data_format,
+        input_shape,
+        output_shape,
+        num_trainable_weights,
+        num_non_trainable_weights,
+    ):
+        if (
+            backend.backend() == "tensorflow"
+            and data_format == "channels_first"
+        ):
+            self.skipTest(
+                "Conv2D with 'channels_first' is not well supported in the "
+                "tensorflow backend"
+            )
+        self.run_layer_test(
+            RepConv2D,
+            init_kwargs={
+                "filters": filters,
+                "kernel_size": kernel_size,
+                "has_skip": has_skip,
+                "data_format": data_format,
+            },
+            input_shape=input_shape,
+            expected_output_shape=output_shape,
+            expected_num_trainable_weights=num_trainable_weights,
+            expected_num_non_trainable_weights=num_non_trainable_weights,
+            expected_num_losses=0,
+            supports_masking=False,
+        )
+
+    @parameterized.parameters(TEST_CASES)
+    def test_rep_conv2d_get_reparameterized_weights(
+        self,
+        filters,
+        kernel_size,
+        has_skip,
+        data_format,
+        input_shape,
+        output_shape,
+        num_trainable_weights,
+        num_non_trainable_weights,
+    ):
+        if (
+            backend.backend() == "tensorflow"
+            and data_format == "channels_first"
+        ):
+            self.skipTest(
+                "Conv2D with 'channels_first' is not well supported in the "
+                "tensorflow backend"
+            )
+        layer = RepConv2D(
+            filters=filters,
+            kernel_size=kernel_size,
+            has_skip=has_skip,
+            data_format=data_format,
+        )
+        layer.build(input_shape)
+        reparameterized_layer = RepConv2D(
+            filters=filters,
+            kernel_size=kernel_size,
+            has_skip=has_skip,
+            reparameterized=True,
+            data_format=data_format,
+        )
+        reparameterized_layer.build(input_shape)
+        x = random.uniform(input_shape)
+
+        kernel, bias = layer.get_reparameterized_weights()
+        reparameterized_layer.rep_conv2d.kernel.assign(kernel)
+        reparameterized_layer.rep_conv2d.bias.assign(bias)
+        y1 = layer(x, training=False)
+        y2 = reparameterized_layer(x, training=False)
+
+        self.assertAllClose(y1, y2, atol=1e-5)
diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py
index b2ec736..2e7fa46 100644
--- a/kimm/models/__init__.py
+++ b/kimm/models/__init__.py
@@ -10,6 +10,7 @@
 from kimm.models.mobilenet_v3 import *  # noqa:F403
 from kimm.models.mobilevit import *  # noqa:F403
 from kimm.models.regnet import *  # noqa:F403
+from kimm.models.repvgg import *  # noqa:F403
 from kimm.models.resnet import *  # noqa:F403
 from kimm.models.vgg import *  # noqa:F403
 from kimm.models.vision_transformer import *  # noqa:F403
diff --git a/kimm/models/models_test.py b/kimm/models/models_test.py
index d0a6630..229f421 100644
--- a/kimm/models/models_test.py
+++ b/kimm/models/models_test.py
@@ -265,6 +265,19 @@
         ("BLOCK3_S32", [1, 7, 7, 368]),
     ],
 ),
+    # repvgg
+    (
+        kimm_models.RepVGGA0.__name__,
+        kimm_models.RepVGGA0,
+        224,
+        [
+            ("STEM_S2", [1, 112, 112, 48]),
+            ("BLOCK0_S4", [1, 56, 56, 48]),
+            ("BLOCK1_S8", [1, 28, 28, 96]),
+            ("BLOCK2_S16", [1, 14, 14, 192]),
+            ("BLOCK3_S32", [1, 7, 7, 1280]),
+        ],
+    ),
     # resnet
     (
         kimm_models.ResNet18.__name__,
@@ -335,6 +348,7 @@
 ]
 
 
+@pytest.mark.requires_trainable_backend  # the numpy backend is too slow to test
 class ModelTest(testing.TestCase, parameterized.TestCase):
     @classmethod
     def 
setUpClass(cls): @@ -419,6 +433,19 @@ def test_model_feature_extractor( name, shape = feature_info self.assertEqual(list(y[name].shape), shape) + @parameterized.named_parameters( + (kimm_models.RepVGGA0.__name__, kimm_models.RepVGGA0, 224) + ) + def test_model_get_reparameterized_model(self, model_class, image_size): + x = random.uniform([1, image_size, image_size, 3]) * 255.0 + model = model_class() + reparameterized_model = model.get_reparameterized_model() + + y1 = model(x, training=False) + y2 = reparameterized_model(x, training=False) + + self.assertAllClose(y1, y2, atol=1e-5) + @pytest.mark.serialization @parameterized.named_parameters(MODEL_CONFIGS) def test_model_serialization( diff --git a/kimm/models/repvgg.py b/kimm/models/repvgg.py new file mode 100644 index 0000000..43a8d1a --- /dev/null +++ b/kimm/models/repvgg.py @@ -0,0 +1,463 @@ +import typing + +import keras +from keras import backend + +from kimm import layers as kimm_layers +from kimm.models.base_model import BaseModel +from kimm.utils import add_model_to_registry + + +@keras.saving.register_keras_serializable(package="kimm") +class RepVGG(BaseModel): + available_feature_keys = [ + "STEM_S2", + *[f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])], + ] + + def __init__( + self, + num_blocks: typing.Sequence[int], + num_channels: typing.Sequence[int], + stem_channels: int = 48, + reparameterized: bool = False, + **kwargs, + ): + kwargs["weights_url"] = self.get_weights_url(kwargs["weights"]) + if kwargs["weights_url"] is not None and reparameterized is True: + raise ValueError( + "Weights can only be loaded with `reparameterized=False`. " + "You can first initialize the model with " + "`reparameterized=False` then use " + "`get_reparameterized_model` to get the converted model. 
" + f"Received: weights={kwargs['weights']}, " + f"reparameterized={reparameterized}" + ) + + input_tensor = kwargs.pop("input_tensor", None) + self.set_properties(kwargs) + channels_axis = ( + -1 if backend.image_data_format() == "channels_last" else -3 + ) + + inputs = self.determine_input_tensor( + input_tensor, + self._input_shape, + self._default_size, + ) + x = inputs + + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} + + # stem + x = kimm_layers.RepConv2D( + stem_channels, + 3, + 2, + has_skip=False, + reparameterized=reparameterized, + activation="relu", + name="stem", + )(x) + features["STEM_S2"] = x + + # stages + current_strides = 2 + for current_stage_idx, (c, n) in enumerate( + zip(num_channels, num_blocks) + ): + strides = 2 + current_strides *= strides + # blocks + for current_block_idx in range(n): + strides = strides if current_block_idx == 0 else 1 + input_channels = x.shape[channels_axis] + has_skip = input_channels == c and strides == 1 + name = f"stages_{current_stage_idx}_{current_block_idx}" + x = kimm_layers.RepConv2D( + c, + 3, + strides, + has_skip=has_skip, + reparameterized=reparameterized, + activation="relu", + name=name, + )(x) + + # add feature + features[f"BLOCK{current_stage_idx}_S{current_strides}"] = x + + # Head + x = self.build_head(x) + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.num_blocks = num_blocks + self.num_channels = num_channels + self.stem_channels = stem_channels + self.reparameterized = reparameterized + + def get_config(self): + config = super().get_config() + config.update( + { + "num_blocks": self.num_blocks, + "num_channels": self.num_channels, + "stem_channels": self.stem_channels, + "reparameterized": self.reparameterized, + } + ) + return config + + def fix_config(self, config): + unused_kwargs = [ + "num_blocks", + "num_channels", + "stem_channels", + ] + for k in unused_kwargs: + config.pop(k, None) + return config + + def get_reparameterized_model(self): + config = self.get_config() + config["reparameterized"] = True + config["weights"] = None + model = RepVGG(**config) + for layer, reparameterized_layer in zip(self.layers, model.layers): + if hasattr(layer, "get_reparameterized_weights"): + kernel, bias = layer.get_reparameterized_weights() + reparameterized_layer.rep_conv2d.kernel.assign(kernel) + reparameterized_layer.rep_conv2d.bias.assign(bias) + else: + for weight, target_weight in zip( + layer.weights, reparameterized_layer.weights + ): + target_weight.assign(weight) + return model + + +""" +Model Definition +""" + + +class RepVGGA0(RepVGG): + available_weights = [ + ( + "imagenet", + RepVGG.default_origin, + "repvgga0_repvgg_a0.rvgg_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "RepVGGA0", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [2, 4, 14, 1], + [48, 96, 192, 1280], + 48, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + 
dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RepVGGA1(RepVGG): + available_weights = [ + ( + "imagenet", + RepVGG.default_origin, + "repvgga1_repvgg_a1.rvgg_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "RepVGGA1", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [2, 4, 14, 1], + [64, 128, 256, 1280], + 64, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RepVGGA2(RepVGG): + available_weights = [ + ( + "imagenet", + RepVGG.default_origin, + "repvgga2_repvgg_a2.rvgg_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "RepVGGA2", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [2, 4, 14, 1], + [96, 192, 384, 1408], + 64, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RepVGGB0(RepVGG): + available_weights = [ + ( + "imagenet", + RepVGG.default_origin, + "repvggb0_repvgg_b0.rvgg_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "RepVGGB0", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [4, 6, 16, 1], + [64, 128, 256, 1280], + 64, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RepVGGB1(RepVGG): + available_weights = [ + ( + "imagenet", + RepVGG.default_origin, + "repvggb1_repvgg_b1.rvgg_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = 
True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "RepVGGB1", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [4, 6, 16, 1], + [128, 256, 512, 2048], + 64, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RepVGGB2(RepVGG): + available_weights = [ + ( + "imagenet", + RepVGG.default_origin, + "repvggb2_repvgg_b2.rvgg_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "RepVGGB2", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [4, 6, 16, 1], + [160, 320, 640, 2560], + 64, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +class RepVGGB3(RepVGG): + available_weights = [ + ( + "imagenet", + RepVGG.default_origin, + "repvggb3_repvgg_b3.rvgg_in1k.keras", + ) + ] + + def __init__( + self, + reparameterized: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = "imagenet", + name: str = "RepVGGB3", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + [4, 6, 16, 1], + [192, 384, 768, 2560], + 64, + reparameterized=reparameterized, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +add_model_to_registry(RepVGGA0, "imagenet") +add_model_to_registry(RepVGGA1, "imagenet") +add_model_to_registry(RepVGGA2, "imagenet") +add_model_to_registry(RepVGGB0, "imagenet") +add_model_to_registry(RepVGGB1, "imagenet") +add_model_to_registry(RepVGGB2, "imagenet") +add_model_to_registry(RepVGGB3, "imagenet") diff --git a/kimm/utils/__init__.py b/kimm/utils/__init__.py index b817228..1b6830e 100644 --- a/kimm/utils/__init__.py +++ b/kimm/utils/__init__.py @@ -1,5 +1,6 @@ from kimm.utils.make_divisble import make_divisible from kimm.utils.model_registry import add_model_to_registry +from kimm.utils.model_utils import get_reparameterized_model from kimm.utils.timm_utils import assign_weights from kimm.utils.timm_utils import is_same_weights from kimm.utils.timm_utils import separate_torch_state_dict 
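The `get_reparameterized_model` utility added below rebuilds a model with `reparameterized=True` and copies in the batch-norm-folded weights computed by `RepConv2D.get_reparameterized_weights`. The folding identity it relies on (a conv followed by batch norm collapses into a single conv with a rescaled kernel and a shifted bias) can be checked with a minimal, self-contained numpy sketch; all names here are illustrative, and a 1x1 conv is written as a matrix product over channels:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8))         # batch of 2 vectors, 8 input channels
w = rng.normal(size=(8, 16))        # a 1x1 conv over channels is a matmul
gamma, beta = rng.normal(size=16), rng.normal(size=16)
mean, var, eps = rng.normal(size=16), rng.uniform(0.5, 1.5, 16), 1e-5

std = np.sqrt(var + eps)
# conv followed by batch norm...
y_conv_bn = gamma * (x @ w - mean) / std + beta
# ...equals one conv with a folded kernel and bias
y_folded = x @ (w * (gamma / std)) + (beta - mean * gamma / std)
np.testing.assert_allclose(y_conv_bn, y_folded, atol=1e-6)
```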
diff --git a/kimm/utils/model_utils.py b/kimm/utils/model_utils.py new file mode 100644 index 0000000..283d99f --- /dev/null +++ b/kimm/utils/model_utils.py @@ -0,0 +1,30 @@ +from kimm.models.base_model import BaseModel + + +def get_reparameterized_model(model: BaseModel): + if not hasattr(model, "get_reparameterized_model"): + raise ValueError( + "There is no 'get_reparameterized_model' method in the model. " + f"Received: model type={type(model)}" + ) + + config = model.get_config() + if config["reparameterized"] is True: + return model + + config["reparameterized"] = True + config["weights"] = None + reparameterized_model = type(model).from_config(config) + for layer, reparameterized_layer in zip( + model.layers, reparameterized_model.layers + ): + if hasattr(layer, "get_reparameterized_weights"): + kernel, bias = layer.get_reparameterized_weights() + reparameterized_layer.rep_conv2d.kernel.assign(kernel) + reparameterized_layer.rep_conv2d.bias.assign(bias) + else: + for weight, target_weight in zip( + layer.weights, reparameterized_layer.weights + ): + target_weight.assign(weight) + return reparameterized_model diff --git a/kimm/utils/model_utils_test.py b/kimm/utils/model_utils_test.py new file mode 100644 index 0000000..1bb44ed --- /dev/null +++ b/kimm/utils/model_utils_test.py @@ -0,0 +1,48 @@ +from keras import random +from keras.src import testing + +from kimm.models.regnet import RegNetX002 +from kimm.models.repvgg import RepVGG +from kimm.utils.model_utils import get_reparameterized_model + + +class ModelUtilsTest(testing.TestCase): + def test_get_reparameterized_model(self): + # dummy RepVGG with random initialization + model = RepVGG( + [1, 1, 1, 1], + [8, 8, 8, 8], + 8, + include_preprocessing=False, + weights=None, + ) + reparameterized_model = get_reparameterized_model(model) + x = random.uniform([1, 32, 32, 3]) + + y1 = model(x, training=False) + y2 = reparameterized_model(x, training=False) + + self.assertAllClose(y1, y2, atol=1e-5) + + def test_get_reparameterized_model_already(self): + # dummy RepVGG with random initialization and reparameterized=True + model = RepVGG( + [1, 1, 1, 1], + [8, 8, 8, 8], + 8, + reparameterized=True, + include_preprocessing=False, + weights=None, + ) + reparameterized_model = get_reparameterized_model(model) + + # same object + self.assertEqual(id(model), id(reparameterized_model)) + + def test_get_reparameterized_model_invalid(self): + model = RegNetX002(weights=None) + + with self.assertRaisesRegex( + ValueError, "There is no 'get_reparameterized_model' method" + ): + get_reparameterized_model(model) diff --git a/kimm/utils/timm_utils.py b/kimm/utils/timm_utils.py index b26db04..1fd33cb 100644 --- a/kimm/utils/timm_utils.py +++ b/kimm/utils/timm_utils.py @@ -43,13 +43,27 @@ def separate_keras_weights(keras_model: keras.Model): trainable_weights = [] non_trainable_weights = [] for layer in keras_model.layers: - layer: keras.Layer - for weight in layer.trainable_weights: - trainable_weights.append((weight, layer.name + "_" + weight.name)) - for weight in layer.non_trainable_weights: - non_trainable_weights.append( - (weight, layer.name + "_" + weight.name) - ) + if hasattr(layer, "extra_layers"): + for sub_layer in layer.extra_layers: + sub_layer: keras.Layer + for weight in sub_layer.trainable_weights: + trainable_weights.append( + (weight, sub_layer.name + "_" + weight.name) + ) + for weight in sub_layer.non_trainable_weights: + non_trainable_weights.append( + (weight, sub_layer.name + "_" + weight.name) + ) + else: + layer: keras.Layer + 
for weight in layer.trainable_weights: + trainable_weights.append( + (weight, layer.name + "_" + weight.name) + ) + for weight in layer.non_trainable_weights: + non_trainable_weights.append( + (weight, layer.name + "_" + weight.name) + ) return trainable_weights, non_trainable_weights diff --git a/shell/export.sh b/shell/export.sh index 7ec49b8..37613dd 100755 --- a/shell/export.sh +++ b/shell/export.sh @@ -15,6 +15,7 @@ python3 -m tools.convert_mobilenet_v2_from_timm python3 -m tools.convert_mobilenet_v3_from_timm python3 -m tools.convert_mobilevit_from_timm python3 -m tools.convert_regnet_from_timm +python3 -m tools.convert_repvgg_from_timm python3 -m tools.convert_resnet_from_timm python3 -m tools.convert_vgg_from_timm python3 -m tools.convert_vit_from_timm diff --git a/tools/convert_repvgg_from_timm.py b/tools/convert_repvgg_from_timm.py new file mode 100644 index 0000000..a47cbef --- /dev/null +++ b/tools/convert_repvgg_from_timm.py @@ -0,0 +1,162 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" + +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import repvgg +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "repvgg_a0.rvgg_in1k", + "repvgg_a1.rvgg_in1k", + "repvgg_a2.rvgg_in1k", + "repvgg_b0.rvgg_in1k", + "repvgg_b1.rvgg_in1k", + "repvgg_b2.rvgg_in1k", + "repvgg_b3.rvgg_in1k", +] +keras_model_classes = [ + repvgg.RepVGGA0, + repvgg.RepVGGA1, + repvgg.RepVGGA2, + repvgg.RepVGGB0, + repvgg.RepVGGB1, + repvgg.RepVGGB2, + repvgg.RepVGGB3, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + weights=None, + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # for torch_name, (_, keras_name) in zip( + # non_trainable_state_dict.keys(), non_trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(non_trainable_state_dict.keys())) + # print(len(non_trainable_weights)) + + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # skip reparam_conv + if "reparam_conv_conv2d" in keras_name: + continue + # repconv2d + torch_name = torch_name.replace( + "conv.kxk.kernel", "conv_kxk.conv.kernel" + ) + torch_name = torch_name.replace("conv.kxk.gamma", "conv_kxk.bn.gamma") + torch_name = torch_name.replace("conv.kxk.beta", "conv_kxk.bn.beta") + torch_name = torch_name.replace( + "conv.1x1.kernel", "conv_1x1.conv.kernel" + ) + torch_name = torch_name.replace("conv.1x1.gamma", "conv_1x1.bn.gamma") + torch_name = 
torch_name.replace("conv.1x1.beta", "conv_1x1.bn.beta")
        # repconv2d bn
        torch_name = torch_name.replace(
            "conv.kxk.moving.mean", "conv_kxk.bn.moving.mean"
        )
        torch_name = torch_name.replace(
            "conv.kxk.moving.variance", "conv_kxk.bn.moving.variance"
        )
        torch_name = torch_name.replace(
            "conv.1x1.moving.mean", "conv_1x1.bn.moving.mean"
        )
        torch_name = torch_name.replace(
            "conv.1x1.moving.variance", "conv_1x1.bn.moving.variance"
        )
        # head
        torch_name = torch_name.replace("classifier", "head.fc")

        # weights naming mapping
        torch_name = torch_name.replace("kernel", "weight")  # conv2d
        torch_name = torch_name.replace("gamma", "weight")  # bn
        torch_name = torch_name.replace("beta", "bias")  # bn
        torch_name = torch_name.replace("moving.mean", "running_mean")  # bn
        torch_name = torch_name.replace("moving.variance", "running_var")  # bn

        # assign weights
        if torch_name in trainable_state_dict:
            torch_weights = trainable_state_dict[torch_name].numpy()
        elif torch_name in non_trainable_state_dict:
            torch_weights = non_trainable_state_dict[torch_name].numpy()
        else:
            raise ValueError(
                "Can't find the corresponding torch weights. "
                f"Got keras_name={keras_name}, torch_name={torch_name}"
            )
        if is_same_weights(keras_name, keras_weight, torch_name, torch_weights):
            assign_weights(keras_name, keras_weight, torch_weights)
        else:
            raise ValueError(
                "Found the corresponding torch weights, but the shapes are "
                f"mismatched. Got keras_name={keras_name}, "
                f"keras_weight shape={keras_weight.shape}, "
                f"torch_name={torch_name}, "
                f"torch_weights shape={torch_weights.shape}"
            )

    """
    Verify model outputs
    """
    np.random.seed(2023)
    keras_data = np.random.uniform(size=[1] + input_shape).astype("float32")
    torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2]))
    torch_y = torch_model(torch_data)
    keras_y = keras_model(keras_data, training=False)
    torch_y = torch_y.detach().cpu().numpy()
    keras_y = keras.ops.convert_to_numpy(keras_y)
    np.testing.assert_allclose(torch_y, keras_y, atol=1e-5)
    print(f"{keras_model_class.__name__}: output matched!")

    """
    Save converted model
    """
    os.makedirs("exported", exist_ok=True)
    export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras"
    keras_model.save(export_path)
    print(f"Exported to {export_path}")
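For completeness, loading one of the exported checkpoints and fusing it for deployment could look like the following minimal sketch (the path assumes the RepVGG-A0 export produced by the loop above; `import kimm` registers the custom classes needed for deserialization):

```python
import keras
import numpy as np

import kimm  # registers `RepVGG` and `RepConv2D` for deserialization

model = keras.models.load_model("exported/repvgga0_repvgg_a0.rvgg_in1k.keras")
reparameterized_model = kimm.utils.get_reparameterized_model(model)

# The fused model should match the training-time model within tolerance.
x = np.random.uniform(size=[1, 224, 224, 3]).astype("float32")
y1 = model.predict(x)
y2 = reparameterized_model.predict(x)
np.testing.assert_allclose(y1, y2, atol=1e-5)
```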