'ASTModel' object has no attribute 'module' #80
Hi, could you paste the full error message? And which code are you running?

I'm running my own code. I'd like to use ASTModel as an encoder for the input signals.
If you could paste your code for model creation and weight loading, I can take a look. But it seems that you forgot to wrap the model in `torch.nn.DataParallel`:

```python
if not isinstance(audio_model, torch.nn.DataParallel):
    audio_model = torch.nn.DataParallel(audio_model)
```

The reason is that our weights were trained with `DataParallel`.
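If you would rather keep a bare ASTModel as an encoder instead of wrapping it, another common workaround is to strip the `module.` prefix that `DataParallel` adds to every key in a saved `state_dict`. A minimal sketch on a plain dict; the key names below are made up for illustration only:

```python
# Hypothetical keys, as they would appear in a checkpoint
# saved from a DataParallel-wrapped model.
checkpoint = {
    'module.patch_embed.weight': 0.1,
    'module.mlp_head.bias': 0.2,
}

# Remove the 'module.' prefix so the keys match an unwrapped model.
stripped = {k.removeprefix('module.'): v for k, v in checkpoint.items()}

print(sorted(stripped))  # ['mlp_head.bias', 'patch_embed.weight']
```

With the prefix removed, loading into the unwrapped model via `load_state_dict` should work; wrapping in `DataParallel`, as above, is what the repo's own scripts do.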
One simple example of correct model creation and weight loading is in the Colab script at https://github.com/YuanGongND/ast/blob/master/Audio_Spectrogram_Transformer_Inference_Demo.ipynb; it can be run online with one click. I guess what you need is:

```python
input_tdim = 1024
checkpoint_path = '/content/ast/pretrained_models/audio_mdl.pth'
ast_mdl = ASTModel(label_dim=527, input_tdim=input_tdim, imagenet_pretrain=False, audioset_pretrain=False)
print(f'[*INFO] load checkpoint: {checkpoint_path}')
checkpoint = torch.load(checkpoint_path, map_location='cuda')
audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0])
audio_model.load_state_dict(checkpoint)
```
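If you are ever unsure whether a given checkpoint was saved from a wrapped model, you can inspect its keys before deciding whether to wrap yours. A hypothetical helper, not part of the AST repo:

```python
def saved_from_dataparallel(state_dict):
    """True if every key carries the 'module.' prefix that
    torch.nn.DataParallel adds when a wrapped model is saved."""
    return all(k.startswith('module.') for k in state_dict)

# Illustrative keys only; real AST checkpoints have many more entries.
print(saved_from_dataparallel({'module.v.pos_embed': 0}))  # True
print(saved_from_dataparallel({'v.pos_embed': 0}))         # False
```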
Yeah, I've tried adding DataParallel, and it does work. By the way, does AST only support inputs of length 1024?
No, it supports any input length, with or without pretrained weights. If you don't use pretrained weights, just change the `input_tdim` argument. If you want to use pretrained weights, do NOT do:

```python
ast_mdl = ASTModel(label_dim=527, input_tdim=input_tdim, imagenet_pretrain=False, audioset_pretrain=False)
print(f'[*INFO] load checkpoint: {checkpoint_path}')
checkpoint = torch.load(checkpoint_path, map_location='cuda')
audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0])
audio_model.load_state_dict(checkpoint)
```

Do:

```python
ast_mdl = ASTModel(label_dim=527, input_tdim=input_tdim, imagenet_pretrain=True, audioset_pretrain=True)
audio_model = torch.nn.DataParallel(ast_mdl, device_ids=[0])
```

The reason is that the model needs to internally adjust for the positional-embedding mismatch between pretraining and fine-tuning caused by the difference in input length.
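The mismatch arises because a different `input_tdim` changes the number of time patches, and hence the length of the positional-embedding sequence. The idea behind the internal adjustment can be sketched as a 1-D interpolation of the pretrained embeddings to the new length; this is a simplified illustration, not the repo's actual code:

```python
import numpy as np

def resize_pos_embed(pos_embed, new_len):
    """Linearly interpolate positional embeddings from shape
    (old_len, dim) to (new_len, dim), one embedding dimension
    at a time."""
    old_len, dim = pos_embed.shape
    old_x = np.linspace(0.0, 1.0, old_len)
    new_x = np.linspace(0.0, 1.0, new_len)
    return np.stack(
        [np.interp(new_x, old_x, pos_embed[:, d]) for d in range(dim)],
        axis=1,
    )

pos = np.arange(12, dtype=float).reshape(6, 2)  # pretrained: 6 patches, dim 2
resized = resize_pos_embed(pos, 10)             # fine-tuning: 10 patches
print(resized.shape)  # (10, 2)
```

Because the interpolation is linear, the first and last embeddings are preserved exactly while intermediate positions are blended, which is why the model must perform this step itself at construction time rather than letting `load_state_dict` fail on a shape mismatch.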
I encountered this error when applying this repo to my data. How can I fix it?