diff --git a/models/convnext.py b/models/convnext.py index 94d42384d..8d57b904b 100644 --- a/models/convnext.py +++ b/models/convnext.py @@ -75,7 +75,7 @@ def __init__(self, in_chans=3, num_classes=1000, LayerNorm(dims[0], eps=1e-6, data_format="channels_first") ) self.downsample_layers.append(stem) - for i in range(3): + for i in range(len(dims) - 1): downsample_layer = nn.Sequential( LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), @@ -85,7 +85,7 @@ def __init__(self, in_chans=3, num_classes=1000, self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] cur = 0 - for i in range(4): + for i in range(len(dims)): stage = nn.Sequential( *[Block(dim=dims[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] @@ -106,7 +106,7 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) def forward_features(self, x): - for i in range(4): + for i in range(len(self.stages)): x = self.downsample_layers[i](x) x = self.stages[i](x) return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)