
How can I convert the ImageNet checkpoint to a .pkl file to run stylegan2-ada-pytorch? #1

Open
bekir23 opened this issue Mar 24, 2022 · 3 comments


@bekir23

bekir23 commented Mar 24, 2022

Hi. The paper you wrote is very useful for GAN training. I am trying to use the ImageNet checkpoints you shared with stylegan2-ada-pytorch, but that repo expects a .pkl file to start training. I extracted [StyleGAN-ADA-256], and it contains a 'data' folder plus 'data.pkl' and 'version' files. This data.pkl file is very small, so starting GAN training from it seems pointless, since similar model files are 100 MB or larger. I think I have to convert the 'data' folder into 'data.pkl'. How can I do this?
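
For context: since PyTorch 1.6, `torch.save` writes a zip archive by default, which is why extracting a .pt file yields `data.pkl`, a `data/` folder and `version`. The small `data.pkl` only holds the pickled object structure; the tensor bytes live in the `data/` entries, so they don't need to be converted by hand. A minimal stdlib sketch for inspecting such an archive (`summarize_pt_archive` is a hypothetical helper, not part of either repo):

```python
import zipfile


def summarize_pt_archive(path):
    """Report how a zip-format .pt file splits its bytes between the pickled
    object structure (data.pkl) and the raw tensor storages (data/...)."""
    if not zipfile.is_zipfile(path):
        raise ValueError(f"{path} is not a zip-format PyTorch checkpoint")
    summary = {"pickle_bytes": 0, "storage_bytes": 0, "storage_files": 0, "other": []}
    with zipfile.ZipFile(path) as zf:
        for info in zf.infolist():
            if info.is_dir():
                continue
            # Entries sit under one top-level folder, e.g. "<name>/data.pkl".
            parts = info.filename.split("/")
            if parts[-1] == "data.pkl":
                summary["pickle_bytes"] += info.file_size
            elif len(parts) >= 2 and parts[-2] == "data":
                # One file per tensor storage, e.g. "<name>/data/0".
                summary["storage_bytes"] += info.file_size
                summary["storage_files"] += 1
            else:
                summary["other"].append(info.filename)
    return summary
```

Calling `summarize_pt_archive('imagenet-256-state.pt')` on the downloaded file should attribute nearly all of its size to `storage_bytes`, which is why the extracted `data.pkl` looks so small.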

@anvoynov
Contributor

Hi! Are you using the file linked from the README?
https://www.dropbox.com/s/7gll7weysn1ull7/imagenet-256-state.pt
It's a standard PyTorch *.pt file, and you can use it with the official StyleGAN2-ADA code: https://github.com/NVlabs/stylegan2-ada-pytorch
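
To see what the .pt file contains, it can be loaded on the CPU and its top-level entries summarized. A minimal sketch, assuming (as the loading code later in this thread does) that the checkpoint is a dict whose `G`, `D` and `G_ema` entries are state dicts; `describe_checkpoint` is a hypothetical helper:

```python
import torch


def describe_checkpoint(path):
    """Load a .pt checkpoint on the CPU and summarize each top-level entry."""
    ckpt = torch.load(path, map_location="cpu")
    report = {}
    for key, value in ckpt.items():
        if isinstance(value, dict):
            # Treat dict entries as state dicts: parameter/buffer name -> tensor.
            tensors = [v for v in value.values() if torch.is_tensor(v)]
            report[key] = {"tensors": len(tensors),
                           "elements": sum(t.numel() for t in tensors)}
        else:
            # Anything else (step counters, configs, ...) is reported by type name.
            report[key] = type(value).__name__
    return report
```

For example, `describe_checkpoint('imagenet-256-state.pt')` would list each network entry with its tensor and element counts.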

@bekir23
Author

bekir23 commented May 23, 2022

These .pt files can't be used directly with https://github.com/NVlabs/stylegan2-ada-pytorch, which expects a .pkl file, so I had to convert the .pt file to a .pkl file. I searched a little and found a page that showed me a way to solve it (https://pytorch.org/tutorials/beginner/saving_loading_models.html).
Following that tutorial, I decided to build the model structure first and then load the .pt file into it.
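
The reasoning from the tutorial: a state dict holds only tensors keyed by parameter name, not the architecture, so an identically structured module has to be built before the weights can be loaded, and a structural mismatch makes loading fail. A toy illustration, with a small hypothetical `make_model` standing in for the StyleGAN2 networks:

```python
import torch
from torch import nn


def make_model(hidden):
    # Hypothetical toy architecture; `hidden` controls the layer shapes.
    return nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 1))


def roundtrip(hidden_saved, hidden_loaded):
    """Take the weights of one model and load them into a freshly built one."""
    src = make_model(hidden_saved)
    state = src.state_dict()          # name -> tensor only; no architecture inside
    dst = make_model(hidden_loaded)   # the caller must rebuild the structure
    dst.load_state_dict(state)        # raises RuntimeError on key/shape mismatch
    return src, dst
```

With matching structures, `roundtrip(8, 8)` produces two models that give identical outputs; `roundtrip(8, 16)` raises a `RuntimeError` because the layer shapes differ.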

I added some code to training_loop.py in the stylegan2-ada repo.
During training, training_loop.py builds the model structures with these lines:

```python
G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module
G_ema = copy.deepcopy(G).eval()
```
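
In training_loop.py the networks are created from keyword dicts: `G_kwargs` and `D_kwargs` carry a `class_name` entry (such as `training.networks.Generator`) plus constructor arguments, and `common_kwargs` adds dataset-derived values (`c_dim`, `img_resolution`, `img_channels`). To load the ImageNet weights, these must reproduce the architecture the checkpoint was trained with; this thread doesn't state those values. The sketch below is a simplified, hypothetical re-implementation of the class-name lookup (the real `dnnlib.util` helper also resolves nested attribute paths), exercised on a standard-library class instead of the StyleGAN2 networks:

```python
import importlib


def construct_class_by_name(*args, class_name=None, **kwargs):
    """Simplified stand-in for dnnlib.util.construct_class_by_name:
    import the module part of a dotted name and call the named class."""
    if class_name is None:
        raise ValueError("class_name is required")
    module_name, _, attr = class_name.rpartition(".")
    cls = getattr(importlib.import_module(module_name), attr)
    # Every remaining keyword goes to the constructor, which is why the merged
    # G_kwargs/common_kwargs must match the constructor's parameters exactly.
    return cls(*args, **kwargs)
```

The call pattern mirrors the repo's: `construct_class_by_name(**dict(class_name="fractions.Fraction", numerator=6), **dict(denominator=8))` builds `Fraction(6, 8)`, i.e. 3/4.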

This gives us three model structures, and we need to load our weights into them. First, load the .pt file and copy the matching weights into each model:

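```python
import torch

# Load on the CPU first; load_state_dict() copies each tensor onto the
# device the corresponding network already lives on (`device` above).
pt_file = torch.load('/stylegan2-ada-pytorch/pretrained_models/imagenet-256-state.pt', map_location='cpu')
G.load_state_dict(pt_file['G'])
D.load_state_dict(pt_file['D'])
G_ema.load_state_dict(pt_file['G_ema'])
```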
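
If `G_kwargs`, `D_kwargs` or `common_kwargs` don't reproduce the checkpoint's architecture, these `load_state_dict` calls raise a `RuntimeError`. To see the mismatch at a glance, a module's own `state_dict()` can be compared with the checkpoint entry without modifying the module. A hypothetical helper, demonstrated on toy modules:

```python
import torch


def compare_state_dicts(module: torch.nn.Module, state_dict: dict) -> dict:
    """List keys the module expects but the checkpoint lacks, keys the
    checkpoint has but the module doesn't, and shared keys whose shapes differ."""
    expected = module.state_dict()
    missing = sorted(set(expected) - set(state_dict))
    unexpected = sorted(set(state_dict) - set(expected))
    shape_mismatch = sorted(
        k for k in set(expected) & set(state_dict)
        if tuple(expected[k].shape) != tuple(state_dict[k].shape)
    )
    return {"missing": missing, "unexpected": unexpected, "shape_mismatch": shape_mismatch}
```

For example, `compare_state_dicts(G, pt_file['G'])` returning three empty lists means `G` was built with a compatible configuration.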

Next, save the models as a pickle file. The pickle holds a dict of networks, so create a dict and put the models into it:

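```python
import copy

# Store CPU, eval-mode copies (as the repo's own snapshot code does) so the
# pickle can also be loaded on a machine without the same GPU setup.
resume_data = {}
for name, module in [('G', G), ('D', D), ('G_ema', G_ema)]:
    resume_data[name] = copy.deepcopy(module).eval().requires_grad_(False).cpu()
```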

Then save it as a pickle:

```python
import pickle

dest = '/stylegan2-ada-pytorch/pretrained_models/imagenet-256-state.pkl'
with open(dest, 'wb') as f:
    pickle.dump(resume_data, f)
```
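
Before starting a long run, it can help to confirm that the pickle reloads and holds the networks with the weights that were loaded. A minimal check using plain `pickle.load` (the training loop itself reads resume files through `legacy.load_network_pkl`, which is not used here; for the real file, run it from the stylegan2-ada-pytorch directory so the pickled classes can be resolved). `verify_networks` is a hypothetical helper:

```python
import pickle

import torch


def verify_networks(path, **originals):
    """Reload a network pickle and check that each named module's parameters
    and buffers match the in-memory original (compared on the CPU)."""
    with open(path, "rb") as f:
        data = pickle.load(f)
    for name, original in originals.items():
        if name not in data:
            return False
        ref = original.state_dict()
        got = data[name].state_dict()
        if set(ref) != set(got):
            return False
        if not all(torch.equal(ref[k].cpu(), got[k].cpu()) for k in ref):
            return False
    return True
```

Inside training_loop.py this would be called as `verify_networks(dest, G=G, D=D, G_ema=G_ema)` right after the dump.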

You can then use this .pkl file to train your GAN in the stylegan2-ada repo, for example by passing it to train.py with --resume.

If you know a better way to use these weights in the stylegan2-ada repo, please let me know.
Thank you

@naeun0620

naeun0620 commented Mar 29, 2023

Could you also share the `G_kwargs`, `D_kwargs`, and `common_kwargs` you used? @bekir23
