-
Notifications
You must be signed in to change notification settings - Fork 5
/
sampling.py
67 lines (57 loc) · 2.18 KB
/
sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from forward_noising import (
get_index_from_list,
sqrt_one_minus_alphas_cumprod,
betas,
posterior_variance,
sqrt_recip_alphas,
)
import matplotlib.pyplot as plt
from dataloader import show_tensor_image
from unet import SimpleUnet
@torch.no_grad()
def sample_timestep(model, x, t):
    """
    Run one reverse-diffusion (denoising) step.

    Calls the model to predict the noise in ``x`` at timestep ``t`` and
    computes the denoised mean. Gaussian noise scaled by the posterior
    standard deviation is added, except for elements that are at the last
    step (``t == 0``).

    Args:
        model: Noise-prediction network, called as ``model(x, t)``.
        x: Current noisy image batch. Assumes the first dimension is the
            batch dimension matching ``t`` -- TODO confirm for all callers.
        t: ``torch.long`` tensor of timesteps, one per batch element
            (a single-element tensor also works and broadcasts).

    Returns:
        Tensor shaped like ``x`` holding the image estimate for step ``t - 1``.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    # As pointed out by Luis Pereira (see YouTube comment)
    # The t's are offset from the t's in the paper, so t == 0 is the final
    # step and gets no added noise.
    is_last = t == 0
    if bool(is_last.all()):
        return model_mean

    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    noise = torch.randn_like(x)
    # Element-wise mask (1 = add noise, 0 = final step) so a batch with mixed
    # timesteps works instead of failing on an ambiguous tensor truth value.
    # For a single non-zero timestep the mask is exactly 1.0.
    keep_noise = (~is_last).to(device=x.device, dtype=x.dtype).reshape(
        -1, *([1] * (x.dim() - 1))
    )
    return model_mean + keep_noise * torch.sqrt(posterior_variance_t) * noise
@torch.no_grad()
def sample_plot_image(model, device, img_size, T, num_images=10, save_path="sample.png"):
    """
    Sample one image by running the full reverse diffusion and save snapshots.

    Starts from Gaussian noise of shape ``(1, 3, img_size, img_size)`` and
    applies ``sample_timestep`` for t = T-1, ..., 0. After every step the
    image is clamped to [-1, 1]. Every ``stepsize`` steps the current image
    is drawn into a 1 x ``num_images`` row of subplots. Panel index grows
    with t, so the final (t == 0) image is the leftmost panel. The figure is
    saved to ``save_path`` and then closed.

    Args:
        model: Noise-prediction network passed to ``sample_timestep``.
        device: Device for the initial noise and the timestep tensors.
        img_size: Height and width of the generated image.
        T: Number of diffusion timesteps.
        num_images: Maximum number of snapshot panels (default 10).
        save_path: Output path for the figure (default "sample.png").

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    # Sample noise
    img = torch.randn((1, 3, img_size, img_size), device=device)
    fig = plt.figure(figsize=(15, 15))
    plt.axis("off")

    # Ceiling division. It guarantees stepsize >= 1, so there is no modulo by
    # zero when T < num_images. It also keeps the largest panel index,
    # (T - 1) // stepsize + 1, <= num_images, so plt.subplot cannot be asked
    # for a panel that does not exist. When T is a multiple of num_images it
    # equals the original T / num_images.
    stepsize = max(1, (T + num_images - 1) // num_images)

    # Reversed iteration
    for i in reversed(range(T)):
        t = torch.tensor([i], device=device, dtype=torch.long)
        img = sample_timestep(model, img, t)
        img = torch.clamp(img, -1.0, 1.0)
        if i % stepsize == 0:
            plt.subplot(1, num_images, i // stepsize + 1)
            show_tensor_image(img.detach().cpu())

    fig.savefig(save_path)
    # pyplot keeps every figure alive until closed; release it so repeated
    # calls (e.g. once per epoch) do not accumulate open figures.
    plt.close(fig)
if __name__ == "__main__":
    # Load a trained SimpleUnet checkpoint and save a sampling strip to sample.png.
    img_size = 64
    T = 300
    model = SimpleUnet()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    # map_location lets a checkpoint that was saved from a GPU load on a
    # CPU-only machine (and vice versa).
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints; consider weights_only=True if the installed torch supports it.
    state_dict = torch.load(
        "trained_models/ddpm_mse_epochs_500.pth", map_location=device
    )
    model.load_state_dict(state_dict)
    model.to(device)
    # Sampling is inference: switch any dropout/normalization layers to
    # eval behaviour instead of training behaviour.
    model.eval()
    sample_plot_image(model=model, device=device, img_size=img_size, T=T)