# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial

import paddle
import paddle.nn as nn
from paddle.nn.initializer import Constant, TruncatedNormal

from layer import PatchEmbed, Block

trunc_normal_ = TruncatedNormal(std=.02)
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)


class VisionTransformer(nn.Layer):
    """ Vision Transformer with support for global average pooling
    (the MAE fine-tuning variant) or the standard cls-token head.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool=False):
        super().__init__()
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = self.create_parameter([1, 1, embed_dim], is_bias=True)
        self.pos_embed = self.create_parameter([1, num_patches + 1, embed_dim], is_bias=True)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in paddle.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.LayerList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])

        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim)  # norm applied after average pooling, replaces the original norm
        else:
            self.norm = norm_layer(embed_dim)  # standard pre-head norm over all tokens

        # Classifier head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.pos_embed)
        trunc_normal_(self.cls_token)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward_features(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand([B, -1, -1])  # stole cls_tokens impl from Phil Wang, thanks
        x = paddle.concat((cls_tokens, x), axis=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(axis=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
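

# --------------------------------------------------------
# Illustrative sketch (not part of the original file): one common way to
# consume no_weight_decay() when building a fine-tuning optimizer is to
# route those parameters (plus biases and other 1-D parameters, by the
# usual ViT convention) into a zero-weight-decay group. The helper name
# and hyper-parameter defaults below are assumptions, not this repo's API.
def build_optimizer_sketch(model, lr=1e-3, weight_decay=0.05):
    skip = model.no_weight_decay()
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if p.stop_gradient:  # frozen parameters take no updates
            continue
        if name in skip or name.endswith('.bias') or p.ndim == 1:
            no_decay.append(p)
        else:
            decay.append(p)
    # Paddle optimizers accept parameter groups as a list of dicts,
    # each with its own per-group options such as weight_decay.
    return paddle.optimizer.AdamW(
        learning_rate=lr,
        parameters=[
            {'params': decay, 'weight_decay': weight_decay},
            {'params': no_decay, 'weight_decay': 0.0},
        ])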
def vit_small_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    return model


def vit_base_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    return model


def vit_large_patch16(**kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    return model


def vit_huge_patch14(**kwargs):
    model = VisionTransformer(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, epsilon=1e-6), **kwargs)
    return model
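

# --------------------------------------------------------
# Hedged usage sketch (not in the original file): loading pre-trained
# weights for fine-tuning. The checkpoint path and helper name are
# hypothetical. Keys whose shapes do not match the current model (e.g. the
# freshly initialized head, or 'norm' vs. 'fc_norm' depending on
# global_pool) are skipped so either pooling mode can consume the file.
def load_pretrained_sketch(model, path="checkpoints/vit_base.pdparams"):
    state = paddle.load(path)
    own = model.state_dict()
    filtered = {k: v for k, v in state.items()
                if k in own and list(v.shape) == list(own[k].shape)}
    model.set_state_dict(filtered)
    return model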
if __name__ == "__main__":
    image = paddle.randn([2, 3, 224, 224])
    model = vit_base_patch16(num_classes=1000, drop_path_rate=0.1, global_pool=True)
    output = model(image)
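    # Sanity check (added illustration): with num_classes=1000 and a batch
    # of 2, the head output should be [2, 1000], assuming Block and
    # PatchEmbed follow the standard ViT layout.
    print(output.shape)  # -> [2, 1000]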