Skip to content

Commit

Permalink
Add vit_base2_patch32_256 for a model between base_patch16 and patch3…
Browse files Browse the repository at this point in the history
…2 with a slightly larger img size and width
  • Loading branch information
rwightman committed Jan 24, 2022
1 parent cf43343 commit 07379c6
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions timm/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def _cfg(url='', **kwargs):
'vit_giant_patch14_224': _cfg(url=''),
'vit_gigantic_patch14_224': _cfg(url=''),

'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),

# patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
Expand Down Expand Up @@ -202,6 +204,7 @@ def _cfg(url='', **kwargs):
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
Expand Down Expand Up @@ -634,6 +637,16 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
return model


@register_model
def vit_base2_patch32_256(pretrained=False, **kwargs):
    """ ViT-Base variant (ViT-B/32) at 256x256 with a wider 896-dim embedding and 14 heads.
    # FIXME experiment
    """
    # NOTE: dict(..., **kwargs) deliberately raises TypeError if a caller
    # re-specifies one of the fixed architecture args (patch_size, embed_dim, ...).
    cfg = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs)
    return _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **cfg)


@register_model
def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
Expand Down

0 comments on commit 07379c6

Please sign in to comment.