From d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c Mon Sep 17 00:00:00 2001 From: ygong Date: Sat, 9 Jul 2022 17:50:29 -0400 Subject: [PATCH] fix a bug for using smaller fstride with pretrained model --- src/models/ast_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/models/ast_models.py b/src/models/ast_models.py index 3e73aa0..be5e0fb 100644 --- a/src/models/ast_models.py +++ b/src/models/ast_models.py @@ -145,6 +145,11 @@ def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_ # otherwise interpolate else: new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear') + if f_dim < 12: + new_pos_embed = new_pos_embed[:, :, 6 - int(f_dim/2): 6 - int(f_dim/2) + f_dim, :] + # otherwise interpolate + elif f_dim > 12: + new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear') new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2) self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))