Commit d43ade6

Style fixes
1 parent: e9c9ffe

6 files changed: +752 −332 lines

.flake8

+5
@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 88
+extend-ignore =
+    # See https://github.com/PyCQA/pycodestyle/issues/373
+    E203,
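These are the usual flake8 settings for a project formatted with Black: 88 matches Black's default line length, and E203 is suppressed because pycodestyle flags the symmetric spacing Black puts around ":" in compound slices (see the linked issue). A small illustration, assuming Black's default style:

# Black formats compound slice bounds with a space on both sides of ":",
# which pycodestyle reports as "E203 whitespace before ':'":
items = list(range(10))
lower, upper, offset = 2, 8, 1
window = items[lower + offset : upper - offset]  # E203 fires here without the ignore
print(window)  # [3, 4, 5, 6]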

models_dwv.py

+84 −53
@@ -1,37 +1,30 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-# --------------------------------------------------------
-# References:
-# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
-# DeiT: https://github.com/facebookresearch/deit
-# --------------------------------------------------------
-
+import json
 from functools import partial
-from einops.layers.torch import Rearrange
-from wave_dynamic_layer import Dynamic_MLP_OFA, Dynamic_MLP_Decoder, Dynamic_Patch_Embed
-from operator import mul
-from torch.nn.modules.utils import _pair
-from torch.nn import Conv2d, Dropout
 
 import torch
 import torch.nn as nn
-import pdb
-import math
-from functools import reduce
-import json
+from timm.models.vision_transformer import Block
+
+from wave_dynamic_layer import Dynamic_MLP_OFA
 
-from timm.models.vision_transformer import PatchEmbed, Block
-from util.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid_torch
 
 class OFAViT(nn.Module):
-    """ Masked Autoencoder with VisionTransformer backbone
-    """
-    def __init__(self, img_size=224, patch_size=16, drop_rate=0.,
-                 embed_dim=1024, depth=24, num_heads=16, wv_planes=128, num_classes=45,
-                 global_pool=True, mlp_ratio=4., norm_layer=nn.LayerNorm):
+    """Masked Autoencoder with VisionTransformer backbone"""
+
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        drop_rate=0.0,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        wv_planes=128,
+        num_classes=45,
+        global_pool=True,
+        mlp_ratio=4.0,
+        norm_layer=nn.LayerNorm,
+    ):
         super().__init__()
 
         self.wv_planes = wv_planes
@@ -45,26 +38,40 @@ def __init__(self, img_size=224, patch_size=16, drop_rate=0.,
 
         # --------------------------------------------------------------------------
         # MAE encoder specifics
-        self.patch_embed = Dynamic_MLP_OFA(wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim)
+        self.patch_embed = Dynamic_MLP_OFA(
+            wv_planes=128, inter_dim=128, kernel_size=16, embed_dim=embed_dim
+        )
         self.num_patches = (img_size // patch_size) ** 2
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        #---------------------------------------------------------------------------
-        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
-
-        self.blocks = nn.ModuleList([
-            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
-            for i in range(depth)])
+        # ---------------------------------------------------------------------------
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False
+        )  # fixed sin-cos embedding
+
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    embed_dim,
+                    num_heads,
+                    mlp_ratio,
+                    qkv_bias=True,
+                    norm_layer=norm_layer,
+                )
+                for i in range(depth)
+            ]
+        )
 
         self.head_drop = nn.Dropout(drop_rate)
-        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-
+        self.head = (
+            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+        )
 
     def forward_features(self, x, wave_list):
         # embed patches
         wavelist = torch.tensor(wave_list, device=x.device).float()
         self.waves = wavelist
 
-        x,_ = self.patch_embed(x, self.waves)
+        x, _ = self.patch_embed(x, self.waves)
 
         x = x + self.pos_embed[:, 1:, :]
         # append cls token
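One behavior preserved by the reformatting above: pos_embed is created with requires_grad=False, so the fixed sin-cos table is saved and loaded with the state_dict but never receives gradient updates. A minimal sketch of the pattern (shapes here are illustrative, not the model's defaults):

import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        # registered like any parameter (kept in the state_dict) ...
        self.pos_embed = nn.Parameter(torch.zeros(1, 197, 768), requires_grad=False)
        self.head = nn.Linear(768, 45)

m = Tiny()
# ... but excluded from the trainable set:
print([n for n, p in m.named_parameters() if p.requires_grad])
# ['head.weight', 'head.bias']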
@@ -96,40 +103,64 @@ def forward(self, x, wave_list):
 
 def vit_small_patch16(**kwargs):
     model = OFAViT(
-        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        patch_size=16,
+        embed_dim=384,
+        depth=12,
+        num_heads=6,
+        mlp_ratio=4,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
     return model
 
+
 def vit_base_patch16(**kwargs):
     model = OFAViT(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        patch_size=16,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
     return model
 
 
 def vit_large_patch16(**kwargs):
     model = OFAViT(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        patch_size=16,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
     return model
 
 
 def vit_huge_patch14(**kwargs):
     model = OFAViT(
-        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+        patch_size=14,
+        embed_dim=1280,
+        depth=32,
+        num_heads=16,
+        mlp_ratio=4,
+        norm_layer=partial(nn.LayerNorm, eps=1e-6),
+        **kwargs,
+    )
     return model
 
 
-if __name__=='__main__':
-    check_point = torch.load('ofa_base_checkpoint_e99.pth')
+if __name__ == "__main__":
+    check_point = torch.load("ofa_base_checkpoint_e99.pth")
     vit_model = vit_base_patch16()
-    vit_model.load_state_dict(check_point['model'], strict=False)
+    vit_model.load_state_dict(check_point["model"], strict=False)
     vit_model = vit_model.cuda()
-    C = 2 # can be 2,3,4,6,9,12,13,202 or any number if you can provide the wavelengths of them
-    inp = torch.randn([1,C,224,224]).cuda()
-    with open('waves.json','r') as wf:
+    C = 2  # number of channels
+    inp = torch.randn([1, C, 224, 224]).cuda()
+    with open("waves.json", "r") as wf:
         wavelists = json.load(wf)
-    test_out = vit_model(inp, wave_list=wavelists[f'{C}'])
+    test_out = vit_model(inp, wave_list=wavelists[f"{C}"])
     print(test_out.shape)
-
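For a quick smoke test of the refactored module without the checkpoint or a GPU, something like the sketch below should suffice. The wavelength values are illustrative assumptions (one band center per channel, in whatever unit Dynamic_MLP_OFA expects, standing in for the waves.json lookup used in the __main__ block above):

import torch

from models_dwv import vit_base_patch16

model = vit_base_patch16()  # randomly initialized; no torch.load needed
model.eval()

C = 3  # number of channels
x = torch.randn(1, C, 224, 224)  # CPU tensor; .cuda() is optional here
wave_list = [0.665, 0.560, 0.490]  # assumed RGB-like band centers, one per channel

with torch.no_grad():
    out = model(x, wave_list=wave_list)
print(out.shape)  # torch.Size([1, 45]) with the default num_classes=45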
