forked from gle-bellier/discrete-fm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
69 lines (52 loc) · 2.19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import List
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import torch
from einops import rearrange
def tensor_to_numpy(tensor, dict_size):
    """Render a batch of index maps as a single uint8 image grid.

    ``tensor`` is unpacked as ``(B, H, W)``. Every index is shifted down by
    one (floored at zero), scaled by the integer step ``255 // dict_size``,
    tiled into a grid with ``make_grid`` and returned as a channels-last
    NumPy array of dtype ``uint8`` that Matplotlib can display.
    """
    batch, height, width = tensor.shape
    # make_grid is given (B, C, H, W), so insert a singleton channel axis.
    images = tensor.reshape(batch, 1, height, width)
    # Drop the mask class: shift every index down by one, flooring at zero.
    images = (images - 1).clamp(min=0)
    gray_step = 255 // dict_size
    images = images * gray_step
    # Near-square layout: floor(sqrt(B)) images per grid row.
    per_row = int(batch ** 0.5)
    grid = make_grid(images, nrow=per_row, pad_value=0).cpu()
    # Reorder to channels-last, i.e. (H, W, C).
    grid = grid.permute(1, 2, 0)
    return grid.numpy().astype(np.uint8)
def create_animation(tensors, output_path, duration=5, dict_size=10, fps=25):
    """Create an animation from a series of tensors and save it to disk.

    The sequence is resampled to ``fps * duration`` evenly spaced frames,
    each rendered with ``tensor_to_numpy``; the last frame is then held for
    one extra second. The file is written with Matplotlib's ``'pillow'``
    movie writer.

    Args:
        tensors: Non-empty, indexable sequence; each item is passed to
            ``tensor_to_numpy``.
        output_path: Destination path handed to ``Animation.save``.
        duration: Length in seconds of the animated part (before the hold).
        dict_size: Forwarded to ``tensor_to_numpy``.
        fps: Frames per second of the saved animation.

    Raises:
        ValueError: If ``tensors`` is empty or ``fps`` is not positive.
    """
    # len() rather than truthiness: ``tensors`` may be an array-like whose
    # boolean value is ambiguous.
    if len(tensors) == 0:
        raise ValueError("tensors must contain at least one element")
    if fps <= 0:
        raise ValueError("fps must be positive")

    # np.linspace needs an integer sample count; keep at least one frame so
    # a fractional or zero duration still produces a valid animation.
    num_frames = max(1, int(round(fps * duration)))
    sampled_indices = np.linspace(0, len(tensors) - 1, num_frames).astype(int)
    frames = [tensor_to_numpy(tensors[i], dict_size) for i in sampled_indices]
    # Hold the final frame for one second so the end result stays visible.
    frames += [frames[-1]] * int(round(fps))

    fig, ax = plt.subplots()
    try:
        img = ax.imshow(frames[0])
        ax.axis('off')

        def update(frame):
            img.set_data(frame)
            return [img]

        # FuncAnimation expects the delay between frames in milliseconds.
        ani = animation.FuncAnimation(
            fig, update, frames=frames, blit=True, interval=1000 / fps
        )

        # Save the animation
        Writer = animation.writers['pillow']
        writer = Writer(fps=fps, metadata=dict(artist='Me'), bitrate=1800)
        ani.save(output_path, dpi=150, writer=writer)
    finally:
        # Release the figure even when rendering or saving fails.
        plt.close(fig)
def plot_generation(xts: List[torch.Tensor], n_plots: int = 5) -> None:
    """Plot evenly strided snapshots of a generation trajectory as one image.

    ``n_plots`` snapshots are taken at a constant stride starting from the
    first element of ``xts`` and the final element is always appended.
    Snapshots are stacked top to bottom and batch items left to right, then
    drawn in grayscale on a new Matplotlib figure. The figure is neither
    shown nor saved here.

    Args:
        xts: Non-empty list of tensors, each rearranged as ``(b, h, w)``.
        n_plots: Number of strided snapshots taken before the final one;
            capped at ``len(xts)``.

    Raises:
        ValueError: If ``xts`` is empty or ``n_plots`` is smaller than 1.
    """
    if len(xts) == 0:
        raise ValueError("xts must contain at least one tensor")
    if n_plots < 1:
        raise ValueError("n_plots must be at least 1")

    # Asking for more snapshots than there are states would give a stride of
    # zero and repeat the first state n_plots times, so cap the count.
    n_plots = min(n_plots, len(xts))
    stride = len(xts) // n_plots
    snapshots = [xts[i * stride] for i in range(n_plots)]
    snapshots.append(xts[-1])

    grid = torch.stack(snapshots, dim=0)
    # Shift indices down by one, flooring at zero, which drops the mask class
    # (same convention as tensor_to_numpy).
    grid = torch.clamp(grid - 1, 0).cpu().numpy()
    # Time steps become rows of tiles, batch items become columns.
    grid = rearrange(grid, 't b h w -> (t h) (b w)')

    plt.figure(figsize=(10, 10))
    plt.imshow(grid, cmap='gray')
    plt.axis('off')