-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain_flow_matching_2d.py
148 lines (118 loc) · 4.4 KB
/
train_flow_matching_2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import gaussian_filter1d
from torch import Tensor, nn
from torch.nn import Module
from flow_matching import visualization
from flow_matching.datasets import TOY_DATASETS
from flow_matching.solver import ModelWrapper
from flow_matching.utils import set_seed
class Swish(Module):
    """Swish activation: each element is scaled by its own sigmoid."""

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x * sigmoid(x)`` computed element-wise."""
        gate = torch.sigmoid(x)
        return gate * x
class Mlp(Module):
    """Time-conditioned multilayer perceptron.

    The network receives a batch of points together with time features,
    concatenates them along the feature axis, and returns a tensor with the
    same shape as the input points.

    Args:
        dim: Size of the feature axis of the input points (and of the output).
        time_dim: Number of time features appended to every input row.
        h: Width of the hidden layers.
    """

    def __init__(self, dim: int = 2, time_dim: int = 1, h: int = 64) -> None:
        super().__init__()
        self.input_dim = dim
        self.time_dim = time_dim
        self.hidden_dim = h
        self.layers = nn.Sequential(
            nn.Linear(dim + time_dim, h),
            Swish(),
            nn.Linear(h, h),
            Swish(),
            nn.Linear(h, h),
            Swish(),
            nn.Linear(h, dim),
        )

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        """Evaluate the network at points ``x_t`` and times ``t``.

        Args:
            x_t: Points whose total element count is a multiple of ``dim``;
                they are flattened to rows of ``dim`` features.
            t: Time features. Either one row of ``time_dim`` values per
                flattened point, or a single row (e.g. a scalar when
                ``time_dim == 1``) that is broadcast to every point.

        Returns:
            The network output reshaped to the original shape of ``x_t``.
        """
        size = x_t.size()
        x_t = x_t.reshape(-1, self.input_dim)
        # Keep all ``time_dim`` features per row. Flattening to a single
        # column here would break the expand (or the first Linear's expected
        # input width) whenever ``time_dim != 1``.
        t = t.reshape(-1, self.time_dim).float()
        t = t.expand(x_t.size(0), self.time_dim)
        h = torch.cat([x_t, t], dim=1)
        output = self.layers(h)
        return output.reshape(*size)
def main():
    """Train a conditional flow matching model on a toy dataset and visualize it.

    Reads ``--dataset`` and ``--output-dir`` from the command line, trains an
    ``Mlp`` with the conditional flow matching loss, and writes the checkpoint
    (``ckpt.pth``), the loss curve (``losses.png``) and the outputs of three
    ``flow_matching.visualization`` helpers into
    ``<output-dir>/cfm/<dataset>``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset", type=str, choices=TOY_DATASETS.keys(), required=True)
    cli.add_argument("--output-dir", type=str, default="outputs")
    args = cli.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seed(42)

    # Every artifact of this run goes below <output-dir>/cfm/<dataset>.
    out_dir = Path(args.output_dir) / "cfm" / args.dataset
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Using device: {device}")
    print(f"Dataset: {args.dataset}")

    # Training parameters
    lr = 1e-3
    batch_size = 4096
    num_steps = 20000
    report_interval = 2000
    width = 512

    dataset = TOY_DATASETS[args.dataset](device=device)
    flow = Mlp(dim=dataset.dim, time_dim=1, h=width).to(device)
    optimizer = torch.optim.AdamW(flow.parameters(), lr)

    # Training
    loss_history = []
    for step in range(1, num_steps + 1):
        data = dataset.sample(batch_size)
        noise = torch.randn_like(data).to(device)
        t = torch.rand(data.size(0), 1).to(device)

        # Conditional Flow Matching objective, eq. (22) and (23) of
        # https://arxiv.org/abs/2210.02747 with \sigma_{\min} = 0.
        x_t = (1 - t) * noise + t * data  # \phi_t(x_0): point on the straight path
        target = data - noise  # u_t(x|x_1): conditional velocity

        optimizer.zero_grad()
        prediction = flow(x_t=x_t, t=t)
        loss = F.mse_loss(prediction, target)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_history.append(loss_value)
        if step % report_interval == 0:
            print(f"| step: {step:6d} | loss: {loss_value:8.4f} |")

    flow.eval()
    torch.save(flow.state_dict(), out_dir / "ckpt.pth")

    # Plot learning curves: raw losses (faint) under a Gaussian-smoothed line.
    curve_path = out_dir / "losses.png"
    xs = np.arange(1, len(loss_history) + 1)
    smooth = gaussian_filter1d(loss_history, sigma=5)
    curve_color = "#1f77b4"

    plt.figure(figsize=(6, 5))
    plt.plot(xs, loss_history, color=curve_color, alpha=0.3)
    plt.plot(xs, smooth, color=curve_color, linewidth=2)
    plt.title("Training dynamics", fontsize=16)
    plt.xlabel("Steps", fontsize=14)
    plt.ylabel("Loss", fontsize=14)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(curve_path)
    print("Training curves saved to", curve_path)

    # Sampling with ODE solver and visualization
    class WrappedModel(ModelWrapper):
        """Adapter translating the wrapper's ``(x, t)`` call into ``Mlp``'s keywords."""

        def forward(self, x: Tensor, t: Tensor, **extras) -> Tensor:
            return self.model(x_t=x, t=t)

    wrapped_model = WrappedModel(flow)
    shared = dict(flow=wrapped_model, dataset=dataset, output_dir=out_dir)

    visualization.plot_ode_sampling_evolution(
        **shared,
        filename=f"sampling_{args.dataset}_w_solver.png",
    )
    visualization.save_vector_field_and_samples_as_gif(
        **shared,
        filename=f"vector_field_{args.dataset}.gif",
    )
    visualization.plot_likelihood(
        **shared,
        filename=f"likelihood_{args.dataset}.png",
    )
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()