-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_to_tf2_ckpt.py
59 lines (48 loc) · 2.19 KB
/
convert_to_tf2_ckpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import tensorflow as tf
import numpy as np
from model import UNet
def read_checkpoint_key_names(keys_filename):
    """Return the checkpoint variable names listed in *keys_filename*, in file order.

    Each non-blank line is expected to start with a TF1 checkpoint variable
    name. Anything after the first run of whitespace is ignored (presumably a
    shape annotation -- the exact format is not checked here). TF variable
    names cannot contain whitespace, so splitting on any whitespace (not just
    a single space) is safe.

    Args:
        keys_filename: path to the text file listing variable names.

    Returns:
        A list of variable-name strings, blank lines skipped.
    """
    names = []
    with open(keys_filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if fields:
                names.append(fields[0])
    return names


def _convert(model, ckpt_path, keys_filename, ckpt_name, image_size):
    """Copy TF1 checkpoint tensors into *model* and save them as a TF2 checkpoint.

    Args:
        model: the freshly constructed UNet to load the weights into. Its
            weight order must match the order of names in *keys_filename*,
            because ``set_weights`` assigns positionally.
        ckpt_path: path prefix of the source TF1 checkpoint.
        keys_filename: text file listing the checkpoint variable names.
        ckpt_name: file prefix passed to ``tf.train.Checkpoint.save``.
        image_size: height/width of the dummy input used to build the model.

    Returns:
        The checkpoint path returned by ``tf.train.Checkpoint.save``.
    """
    ckpt = tf.train.Checkpoint(model=model)
    reader = tf.train.load_checkpoint(ckpt_path)
    weights = [reader.get_tensor(name) for name in read_checkpoint_key_names(keys_filename)]

    # set_weights needs the variables to exist. This assumes UNet is a Keras
    # model that creates them on its first call, so run one forward pass on
    # dummy data: 2 images with values in [-1, 1], time steps 0 and 1.
    inputs = tf.constant(
        np.random.uniform(-1, 1, (2, image_size, image_size, 3)).astype("float32")
    )
    timesteps = tf.constant(np.arange(2).astype("float32"))
    model(inputs, timesteps, training=False)
    model.set_weights(weights)
    return ckpt.save(ckpt_name)


def convert_small_model(ckpt_path, keys_filename, ckpt_name="cifar10"):
    """Convert the small UNet (32x32 inputs, attention at resolution 16) checkpoint.

    Args:
        ckpt_path: path prefix of the source TF1 checkpoint.
        keys_filename: text file listing the checkpoint variable names.
        ckpt_name: output prefix for the TF2 checkpoint.

    Returns:
        The checkpoint path returned by ``tf.train.Checkpoint.save``.
    """
    model = UNet(attention_resolutions=(16,))
    # The architecture depends on the input resolution, so it stays fixed here.
    return _convert(model, ckpt_path, keys_filename, ckpt_name, image_size=32)


def convert_large_model(ckpt_path, keys_filename, ckpt_name="celebahq256"):
    """Convert the large UNet checkpoint.

    The large UNet takes 256x256 inputs, uses attention at resolution 16 and
    channel multipliers (1, 1, 2, 2, 4, 4).

    Args:
        ckpt_path: path prefix of the source TF1 checkpoint.
        keys_filename: text file listing the checkpoint variable names.
        ckpt_name: output prefix for the TF2 checkpoint.

    Returns:
        The checkpoint path returned by ``tf.train.Checkpoint.save``.
    """
    model = UNet(attention_resolutions=(16,), multipliers=(1, 1, 2, 2, 4, 4))
    # The architecture depends on the input resolution, so it stays fixed here.
    return _convert(model, ckpt_path, keys_filename, ckpt_name, image_size=256)
if __name__ == "__main__":
    # The CIFAR-10 checkpoint uses the small 32x32 architecture and its own key list.
    convert_small_model("diffusion_cifar10_model/model.ckpt-790000", "small_model_keys")
    print("finished converting cifar10 model")

    # The remaining checkpoints all go through convert_large_model with the same
    # key list, so they are driven from one table of
    # (source checkpoint, output prefix, completion message).
    large_conversions = (
        ("diffusion_lsun_bedroom_model/model.ckpt-2388000", "lsun_bedroom",
         "finished converting lsun bedroom model"),
        ("diffusion_lsun_cat_model/model.ckpt-1761000", "lsun_cat",
         "finished converting lsun cat model"),
        ("diffusion_lsun_church_model/model.ckpt-4432000", "lsun_church",
         "finished converting lsun church model"),
        ("diffusion_celeba_hq_model/model.ckpt-560000", "celebahq256",
         "finished converting celebahq256 model"),
    )
    for source_ckpt, output_prefix, done_message in large_conversions:
        convert_large_model(source_ckpt, "large_model_keys", output_prefix)
        print(done_message)