Skip to content

Commit

Permalink
fix a bug for using smaller fstride with pretrained model
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanGongND committed Jul 9, 2022
1 parent 87a8004 commit d7d8b4b
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/models/ast_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_
# otherwise interpolate
else:
new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
if f_dim < 12:
new_pos_embed = new_pos_embed[:, :, 6 - int(f_dim/2): 6 - int(f_dim/2) + f_dim, :]
# otherwise interpolate
elif f_dim > 12:
new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

Expand Down

0 comments on commit d7d8b4b

Please sign in to comment.