
Cannot use image_size=512 to train the interpreter #30

Open
PapaMadeleine2022 opened this issue Aug 26, 2023 · 3 comments

Comments


PapaMadeleine2022 commented Aug 26, 2023

Hello. First, I use guided-diffusion to train a DDPM on my dataset as follows:

MODEL_FLAGS="--batch_size 1 --attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 512 --learn_sigma True --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"

mpiexec -n 4 python scripts/image_train.py --data_dir /xxx/dataset $MODEL_FLAGS 

Then I use ddpm-segmentation to train the interpreter as follows:

MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 512 --noise_schedule linear --num_channels 128 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
DATASET=SatelliteMap # Available datasets: bedroom_28, ffhq_34, cat_15, horse_21, celeba_19, ade_bedroom_30

python train_interpreter.py --exp experiments/${DATASET}/ddpm.json $MODEL_FLAGS

But it still shows a shape-mismatch error:

File "/xxx/anaconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1667, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNetModel:
	Missing key(s) in state_dict: "input_blocks.7.0.skip_connection.weight", "input_blocks.7.0.skip_connection.bias", "input_blocks.10.1.norm.weight", "input_blocks.10.1.norm.bias", "input_blocks.10.1.qkv.weight", "input_blocks.10.1.qkv.bias", "input_blocks.10.1.proj_out.weight", "input_blocks.10.1.proj_out.bias", "input_blocks.11.1.norm.weight", "input_blocks.11.1.norm.bias", "input_blocks.11.1.qkv.weight", "input_blocks.11.1.qkv.bias", "input_blocks.11.1.proj_out.weight", "input_blocks.11.1.proj_out.bias", "input_blocks.13.0.skip_connection.weight", "input_blocks.13.0.skip_connection.bias".
	Unexpected key(s) in state_dict: "input_blocks.18.0.in_layers.0.weight", "input_blocks.18.0.in_layers.0.bias", "input_blocks.18.0.in_layers.2.weight", "input_blocks.18.0.in_layers.2.bias", "input_blocks.18.0.emb_layers.1.weight", "input_blocks.18.0.emb_layers.1.bias", "input_blocks.18.0.out_layers.0.weight", "input_blocks.18.0.out_layers.0.bias", "input_blocks.18.0.out_layers.3.weight", "input_blocks.18.0.out_layers.3.bias", "input_blocks.19.0.in_layers.0.weight", "input_blocks.19.0.in_layers.0.bias", "input_blocks.19.0.in_layers.2.weight", "input_blocks.19.0.in_layers.2.bias", "input_blocks.19.0.emb_layers.1.weight", "input_blocks.19.0.emb_layers.1.bias", "input_blocks.19.0.out_layers.0.weight", "input_blocks.19.0.out_layers.0.bias", "input_blocks.19.0.out_layers.3.weight", "input_blocks.19.0.out_layers.3.bias", "input_blocks.19.1.norm.weight", "input_blocks.19.1.norm.bias", "input_blocks.19.1.qkv.weight", "input_blocks.19.1.qkv.bias", "input_blocks.19.1.proj_out.weight", "input_blocks.19.1.proj_out.bias", "input_blocks.20.0.in_layers.0.weight", "input_blocks.20.0.in_layers.0.bias", "input_blocks.20.0.in_layers.2.weight", "input_blocks.20.0.in_layers.2.bias", "input_blocks.20.0.emb_layers.1.weight", "input_blocks.20.0.emb_layers.1.bias", "input_blocks.20.0.out_layers.0.weight", "input_blocks.20.0.out_layers.0.bias", "input_blocks.20.0.out_layers.3.weight", "input_blocks.20.0.out_layers.3.bias", "input_blocks.20.1.norm.weight", "input_blocks.20.1.norm.bias", "input_blocks.20.1.qkv.weight", "input_blocks.20.1.qkv.bias", "input_blocks.20.1.proj_out.weight", "input_blocks.20.1.proj_out.bias", "input_blocks.4.0.skip_connection.weight", "input_blocks.4.0.skip_connection.bias", "input_blocks.10.0.skip_connection.weight", "input_blocks.10.0.skip_connection.bias", "input_blocks.16.0.skip_connection.weight", "input_blocks.16.0.skip_connection.bias", "output_blocks.18.0.in_layers.0.weight", "output_blocks.18.0.in_layers.0.bias", "output_blocks.18.0.in_layers.2.weight", 
"output_blocks.18.0.in_layers.2.bias", "output_blocks.18.0.emb_layers.1.weight", "output_blocks.18.0.emb_layers.1.bias", "output_blocks.18.0.out_layers.0.weight", "output_blocks.18.0.out_layers.0.bias", "output_blocks.18.0.out_layers.3.weight", "output_blocks.18.0.out_layers.3.bias", "output_blocks.18.0.skip_connection.weight", "output_blocks.18.0.skip_connection.bias", "output_blocks.19.0.in_layers.0.weight", "output_blocks.19.0.in_layers.0.bias", "output_blocks.19.0.in_layers.2.weight", "output_blocks.19.0.in_layers.2.bias", "output_blocks.19.0.emb_layers.1.weight", "output_blocks.19.0.emb_layers.1.bias", "output_blocks.19.0.out_layers.0.weight", "output_blocks.19.0.out_layers.0.bias", "output_blocks.19.0.out_layers.3.weight", "output_blocks.19.0.out_layers.3.bias", "output_blocks.19.0.skip_connection.weight", "output_blocks.19.0.skip_connection.bias", "output_blocks.20.0.in_layers.0.weight", "output_blocks.20.0.in_layers.0.bias", "output_blocks.20.0.in_layers.2.weight", "output_blocks.20.0.in_layers.2.bias", "output_blocks.20.0.emb_layers.1.weight", "output_blocks.20.0.emb_layers.1.bias", "output_blocks.20.0.out_layers.0.weight", "output_blocks.20.0.out_layers.0.bias", "output_blocks.20.0.out_layers.3.weight", "output_blocks.20.0.out_layers.3.bias", "output_blocks.20.0.skip_connection.weight", "output_blocks.20.0.skip_connection.bias", "output_blocks.17.1.in_layers.0.weight", "output_blocks.17.1.in_layers.0.bias", "output_blocks.17.1.in_layers.2.weight", "output_blocks.17.1.in_layers.2.bias", "output_blocks.17.1.emb_layers.1.weight", "output_blocks.17.1.emb_layers.1.bias", "output_blocks.17.1.out_layers.0.weight", "output_blocks.17.1.out_layers.0.bias", "output_blocks.17.1.out_layers.3.weight", "output_blocks.17.1.out_layers.3.bias".
	size mismatch for input_blocks.0.0.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 3, 3, 3]).
	size mismatch for input_blocks.0.0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for input_blocks.1.0.in_layers.0.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for input_blocks.1.0.in_layers.0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for input_blocks.1.0.in_layers.2.weight: copying a param with shape torch.Size([64, 64, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for input_blocks.1.0.in_layers.2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for input_blocks.1.0.emb_layers.1.weight: copying a param with shape torch.Size([128, 512]) from checkpoint, the shape in current model is torch.Size([256, 512]).
	size mismatch for input_blocks.1.0.emb_layers.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for input_blocks.1.0.out_layers.0.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for input_blocks.1.0.out_layers.0.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([128]).

How can I fix this? Can you give some advice?
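Judging from the trace, the checkpoint's first conv is 64 channels wide and its input_blocks run up to index 20, while the freshly built UNet is 128 wide with fewer blocks, so the model being constructed does not seem to match the 512-pixel architecture the checkpoint was trained with (the experiment JSON may be overriding the command-line flags). A quick way to check is to read these two properties back out of the checkpoint itself (a minimal sketch, assuming a standard guided-diffusion state_dict; `inspect_ckpt` is a hypothetical helper and the `torch.load` path is a placeholder):

```python
import torch


def inspect_ckpt(state):
    """Infer two architecture properties baked into a guided-diffusion
    UNet state_dict: the base channel width of the first conv, and how
    many input blocks the network has (which depends on image_size)."""
    # Output channels of the stem conv = base width the model was built with.
    num_channels = state["input_blocks.0.0.weight"].shape[0]
    # Distinct input_blocks indices = depth of the encoder path.
    n_blocks = len({int(k.split(".")[1]) for k in state
                    if k.startswith("input_blocks.")})
    return num_channels, n_blocks


# Usage (path is a placeholder for the checkpoint the interpreter loads):
# ckpt = torch.load("/xxx/model.pt", map_location="cpu")
# state = ckpt.get("state_dict", ckpt)  # some checkpoints nest the weights
# print(inspect_ckpt(state))
```

If the values printed for the checkpoint differ from what the UNet in train_interpreter.py is built with, the fix is to make the interpreter-side model config match the checkpoint, not the other way around.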

@CyberAI-XDU

Hello, I have the same problem. Have you solved it yet?

@PapaMadeleine2022
Author

> Hello, I have the same problem. Have you solved it yet?

Not yet.


sxm-ljy commented Sep 15, 2023

Can this be trained on one's own dataset?
