@@ -44,7 +44,7 @@ def _get_rel_pos_bias(self, window_size):
     old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3]
 
     old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
-    new_sub_table = F.interpolate(old_sub_table, size=(new_height, new_width), mode="bilinear")
+    new_sub_table = F.interpolate(old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear")
     new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
 
     new_relative_position_bias_table = torch.cat(
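The cast above is needed because `torch.nn.functional.interpolate` requires integer entries in `size`; when `new_height`/`new_width` are produced by arithmetic on the window size they can arrive as floats, and the un-cast call raises a `TypeError`. A minimal sketch of the resize under assumed shapes (12 heads, a 7x7 pretraining window scaled to 9x9; the variable names mirror the patched `_get_rel_pos_bias`):

```python
import torch
import torch.nn.functional as F

# Assumed setup: 12 attention heads, 7x7 pretraining window resized to 9x9,
# so the per-axis relative-offset counts are 2*7 - 1 = 13 and 2*9 - 1 = 17.
num_heads = 12
old_width = old_height = 2 * 7 - 1
new_height = new_width = (2 * 9 - 1) / 1.0  # floats, as can happen upstream

old_sub_table = torch.randn(old_width * old_height, num_heads)
old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)

# `size` must contain ints; passing the raw floats here raises a TypeError.
new_sub_table = F.interpolate(
    old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
)
print(new_sub_table.shape)  # torch.Size([1, 12, 17, 17])
```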
@@ -96,12 +96,12 @@ def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
     Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
     """
     if self.gamma_1 is None:
-        x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
-        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        x = x + self.drop_path1(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
+        x = x + self.drop_path2(self.mlp(self.norm2(x)))
     else:
-        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution,
+        x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), resolution,
                                                         shared_rel_pos_bias=shared_rel_pos_bias))
-        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
+        x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
     return x
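The `drop_path` to `drop_path1`/`drop_path2` rename tracks newer timm releases, in which `Block` holds two independent stochastic-depth modules, one per residual branch, instead of a single shared `drop_path`. A minimal sketch of that residual pattern, with `nn.Identity` standing in for timm's `DropPath` and `nn.Linear` standing in for the attention and MLP branches (the class and its members here are illustrative, not timm's API):

```python
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    # Each residual branch owns its own stochastic-depth module, mirroring
    # the drop_path1/drop_path2 split used in the patched forward above.
    def __init__(self, dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.Linear(dim, dim)   # stand-in for the attention branch
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Linear(dim, dim)    # stand-in for the MLP branch
        self.drop_path1 = nn.Identity()   # DropPath at drop rate 0
        self.drop_path2 = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.drop_path1(self.attn(self.norm1(x)))
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x

x = torch.randn(2, 16, 32)
print(TinyBlock(32)(x).shape)  # torch.Size([2, 16, 32])
```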