forked from nicolas-dufour/cheese_classification_challenge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
92 lines (85 loc) · 3.27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
import wandb
import hydra
from tqdm import tqdm
def _train_one_epoch(model, loader, loss_fn, optimizer, device, logger):
    """Run one optimization pass over ``loader``.

    Puts ``model`` in training mode, performs one optimizer step per batch,
    logs each batch loss under the ``"loss"`` key, and returns
    ``(mean_loss, accuracy)`` computed over every sample seen.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    # Must be re-enabled every epoch because evaluation switches to eval mode.
    model.train()
    total_loss = 0.0
    num_correct = 0
    num_samples = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        preds = model(images)
        loss = loss_fn(preds, labels)
        batch_loss = loss.item()
        logger.log({"loss": batch_loss})
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_size = images.size(0)
        # Assumes loss_fn averages over the batch (e.g. reduction="mean");
        # re-weighting by batch size then yields an exact per-sample mean.
        total_loss += batch_loss * batch_size
        num_correct += (preds.argmax(1) == labels).sum().item()
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError("training loader yielded no samples")
    return total_loss / num_samples, num_correct / num_samples


@torch.no_grad()
def _evaluate(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` without building autograd graphs.

    Returns ``(mean_loss, accuracy, y_true, y_pred)`` where ``y_true`` and
    ``y_pred`` are Python lists of class indices; predictions are the
    arg-max over dimension 1 of the model output.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    # Disable dropout and freeze batch-norm statistics during evaluation.
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_samples = 0
    y_true = []
    y_pred = []
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        preds = model(images)
        loss = loss_fn(preds, labels)
        predicted = preds.argmax(1)
        y_true.extend(labels.cpu().tolist())
        y_pred.extend(predicted.cpu().tolist())
        batch_size = images.size(0)
        # Same mean-reduction assumption as in _train_one_epoch.
        total_loss += loss.item() * batch_size
        num_correct += (predicted == labels).sum().item()
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / num_samples, num_correct / num_samples, y_true, y_pred


@hydra.main(config_path="configs/train", config_name="config")
def train(cfg):
    """Train a classifier described by a Hydra config and log to Weights & Biases.

    ``cfg`` is composed and injected by ``hydra.main``. The model, optimizer,
    loss function and datamodule are all built with
    ``hydra.utils.instantiate`` from ``cfg.model.instance``, ``cfg.optim``,
    ``cfg.loss_fn`` and ``cfg.datamodule`` respectively.

    The datamodule must provide:
    - ``train_dataloader()``;
    - ``val_dataloader()``, returning a mapping of split name to loader;
    - ``idx_to_class``, indexable by class index ``0 .. len - 1``.

    For ``cfg.epochs`` epochs the model is trained, then evaluated on every
    validation split. Logged metrics are:
    - per-batch training loss;
    - per-epoch train loss and accuracy;
    - per-split validation loss, accuracy and confusion matrix.

    The final ``state_dict`` is saved to ``cfg.checkpoint_path``.
    """
    logger = wandb.init(project="challenge_cheese", name=cfg.experiment_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = hydra.utils.instantiate(cfg.model.instance).to(device)
    optimizer = hydra.utils.instantiate(cfg.optim, params=model.parameters())
    loss_fn = hydra.utils.instantiate(cfg.loss_fn)
    datamodule = hydra.utils.instantiate(cfg.datamodule)
    train_loader = datamodule.train_dataloader()
    val_loaders = datamodule.val_dataloader()

    # Confusion-matrix labels never change; build them once, truncated and
    # lower-cased exactly as before.
    class_names = [
        datamodule.idx_to_class[idx][:10].lower()
        for idx in range(len(datamodule.idx_to_class))
    ]

    for epoch in tqdm(range(cfg.epochs)):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, loss_fn, optimizer, device, logger
        )
        logger.log(
            {
                "epoch": epoch,
                "train_loss_epoch": train_loss,
                "train_acc": train_acc,
            }
        )

        val_metrics = {}
        for val_set_name, val_loader in val_loaders.items():
            val_loss, val_acc, y_true, y_pred = _evaluate(
                model, val_loader, loss_fn, device
            )
            val_metrics[f"{val_set_name}/loss"] = val_loss
            val_metrics[f"{val_set_name}/acc"] = val_acc
            val_metrics[f"{val_set_name}/confusion_matrix"] = (
                wandb.plot.confusion_matrix(
                    y_true=y_true,
                    preds=y_pred,
                    class_names=class_names,
                )
            )
        logger.log(
            {
                "epoch": epoch,
                **val_metrics,
            }
        )

    torch.save(model.state_dict(), cfg.checkpoint_path)
    # Close the run explicitly so repeated launches in one process (e.g. Hydra
    # multirun) each get their own wandb run.
    logger.finish()
if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator on ``train`` composes the
    # config from configs/train plus any command-line overrides and passes it
    # in, which is why ``train`` is called with no arguments here.
    train()