# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial

import torch
import torch.nn as nn
from timm.models.vision_transformer import Block

from wave_dynamic_layer import Dynamic_MLP_OFA


class OFAViT(nn.Module):
    """Vision Transformer backbone with a wavelength-conditioned (dynamic) patch
    embedding and a linear classification head."""

    def __init__(self, img_size=224, patch_size=16, drop_rate=0.,
                 embed_dim=1024, depth=24, num_heads=16, wv_planes=128, num_classes=45,
                 global_pool=True, mlp_ratio=4., norm_layer=nn.LayerNorm):
        super().__init__()

        self.wv_planes = wv_planes
        self.global_pool = global_pool
        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim)
        else:
            self.norm = norm_layer(embed_dim)

        self.patch_embed = Dynamic_MLP_OFA(
            wv_planes=wv_planes, inter_dim=128, kernel_size=patch_size, embed_dim=embed_dim)
        self.num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding

        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])

        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x, wave_list):
        # embed patches, conditioning the patch projection on the input wavelengths
        wavelist = torch.tensor(wave_list, device=x.device).float()
        self.waves = wavelist
        x, _ = self.patch_embed(x, self.waves)

        x = x + self.pos_embed[:, 1:, :]

        # append cls token
        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]

        return outcome

    def forward_head(self, x, pre_logits=False):
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x, wave_list):
        x = self.forward_features(x, wave_list)
        x = self.forward_head(x)
        return x

def vit_small_patch16(**kwargs):
    model = OFAViT(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_base_patch16(**kwargs):
    model = OFAViT(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_large_patch16(**kwargs):
    model = OFAViT(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model


def vit_huge_patch14(**kwargs):
    model = OFAViT(
        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    return model
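

# ----------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The batch size,
# input resolution, and central wavelengths below are illustrative
# assumptions; wave_list is expected to hold one wavelength (in micrometers)
# per input channel, in channel order.
if __name__ == "__main__":
    model = vit_base_patch16(num_classes=45, global_pool=True)
    images = torch.randn(2, 3, 224, 224)       # assumed batch of 2 RGB images
    wave_list = [0.665, 0.560, 0.490]          # assumed R, G, B central wavelengths (um)
    logits = model(images, wave_list)          # expected shape: (2, 45)
    print(logits.shape)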