- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- # --------------------------------------------------------
- # References:
- # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
- # DeiT: https://github.com/facebookresearch/deit
- # --------------------------------------------------------
-
+ import json
from functools import partial
- from einops.layers.torch import Rearrange
- from wave_dynamic_layer import Dynamic_MLP_OFA, Dynamic_MLP_Decoder, Dynamic_Patch_Embed
- from operator import mul
- from torch.nn.modules.utils import _pair
- from torch.nn import Conv2d, Dropout

import torch
import torch.nn as nn
- import pdb
- import math
- from functools import reduce
- import json
+ from timm.models.vision_transformer import Block
+
+ from wave_dynamic_layer import Dynamic_MLP_OFA

- from timm.models.vision_transformer import PatchEmbed, Block
- from util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid_torch

class OFAViT(nn.Module):
-     """ Masked Autoencoder with VisionTransformer backbone
-     """
-     def __init__(self, img_size=224, patch_size=16, drop_rate=0.,
-                  embed_dim=1024, depth=24, num_heads=16, wv_planes=128, num_classes=45,
-                  global_pool=True, mlp_ratio=4., norm_layer=nn.LayerNorm):
+     """Masked Autoencoder with VisionTransformer backbone"""
+
+     def __init__(
+         self,
+         img_size=224,
+         patch_size=16,
+         drop_rate=0.0,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         wv_planes=128,
+         num_classes=45,
+         global_pool=True,
+         mlp_ratio=4.0,
+         norm_layer=nn.LayerNorm,
+     ):
        super().__init__()

        self.wv_planes = wv_planes
@@ -45,26 +38,40 @@ def __init__(self, img_size=224, patch_size=16, drop_rate=0.,

        # --------------------------------------------------------------------------
        # MAE encoder specifics
-         self.patch_embed = Dynamic_MLP_OFA(wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim)
+         self.patch_embed = Dynamic_MLP_OFA(
+             wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim
+         )
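+         # the dynamic patch embed generates its projection weights from the
+         # per-channel wavelengths passed in at forward time, so one backbone can
+         # embed inputs with any number of spectral bands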
        self.num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-         #---------------------------------------------------------------------------
-         self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
-
-         self.blocks = nn.ModuleList([
-             Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
-             for i in range(depth)])
+         # ---------------------------------------------------------------------------
+         self.pos_embed = nn.Parameter(
+             torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False
+         )  # fixed sin-cos embedding
+
+         self.blocks = nn.ModuleList(
+             [
+                 Block(
+                     embed_dim,
+                     num_heads,
+                     mlp_ratio,
+                     qkv_bias=True,
+                     norm_layer=norm_layer,
+                 )
+                 for i in range(depth)
+             ]
+         )

        self.head_drop = nn.Dropout(drop_rate)
-         self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-
+         self.head = (
+             nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+         )

    def forward_features(self, x, wave_list):
        # embed patches
        wavelist = torch.tensor(wave_list, device=x.device).float()
        self.waves = wavelist
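+         # wave_list is assumed to supply one wavelength per input channel; these
+         # values drive the dynamic patch embedding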

-         x,_ = self.patch_embed(x, self.waves)
+         x, _ = self.patch_embed(x, self.waves)
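+         # patch_embed returns a pair; only the embedded patch tokens are used here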

        x = x + self.pos_embed[:, 1:, :]
        # append cls token
@@ -96,40 +103,64 @@ def forward(self, x, wave_list):

def vit_small_patch16(**kwargs):
    model = OFAViT(
-         patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+         patch_size=16,
+         embed_dim=384,
+         depth=12,
+         num_heads=6,
+         mlp_ratio=4,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         **kwargs,
+     )
    return model

+
def vit_base_patch16(**kwargs):
    model = OFAViT(
-         patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+         patch_size=16,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         **kwargs,
+     )
    return model


def vit_large_patch16(**kwargs):
    model = OFAViT(
-         patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+         patch_size=16,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         mlp_ratio=4,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         **kwargs,
+     )
    return model


def vit_huge_patch14(**kwargs):
    model = OFAViT(
-         patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+         patch_size=14,
+         embed_dim=1280,
+         depth=32,
+         num_heads=16,
+         mlp_ratio=4,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         **kwargs,
+     )
    return model


- if __name__ == '__main__':
-     check_point = torch.load('ofa_base_checkpoint_e99.pth')
+ if __name__ == "__main__":
+     check_point = torch.load("ofa_base_checkpoint_e99.pth")
    vit_model = vit_base_patch16()
-     vit_model.load_state_dict(check_point['model'], strict=False)
+     vit_model.load_state_dict(check_point["model"], strict=False)
    vit_model = vit_model.cuda()
-     C = 2  # can be 2,3,4,6,9,12,13,202 or any number if you can provide the wavelengths of them
-     inp = torch.randn([1,C, 224,224]).cuda()
-     with open('waves.json', 'r') as wf:
+     C = 2  # number of channels
+     inp = torch.randn([1, C, 224, 224]).cuda()
+     with open("waves.json", "r") as wf:
        wavelists = json.load(wf)
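+     # waves.json is assumed to map each channel count to a list of that many
+     # wavelengths, e.g. {"2": [w1, w2]}; the actual values depend on the sensor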
-     test_out = vit_model(inp, wave_list=wavelists[f'{C}'])
+     test_out = vit_model(inp, wave_list=wavelists[f"{C}"])
    print(test_out.shape)
-