"""
Command line interface to run the neural network model!
From the project root directory, do:
python trainer.py fit
References:
- https://lightning.ai/docs/pytorch/2.0.2/cli/lightning_cli.html
- https://pytorch-lightning.medium.com/introducing-lightningcli-v2-supercharge-your-training-c070d43c7dd6
"""
import os
import sys
from pathlib import Path

import lightning as L
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.cli import ArgsType, LightningCLI
from lightning.pytorch.loggers import CSVLogger, WandbLogger

from chabud.callbacks import LogIntermediatePredictions
from chabud.datapipe import ChaBuDDataPipeModule
from chabud.model import ChaBuDNet

def main():
    # Ensure the logs/ directory exists under the current working directory
    cwd = os.getcwd()
    (Path(cwd) / "logs").mkdir(exist_ok=True)

    # The first command line argument is used as the experiment/run name
    name = sys.argv[1]

    # LOGGERS
    wandb_logger = WandbLogger(
        project="chabud2023",
        name=name,
        save_dir="logs",
        log_model=False,
    )
    csv_logger = CSVLogger(save_dir="logs/csv_logger", name=name)

    # CALLBACKS
    lr_cb = LearningRateMonitor(
        logging_interval="step",
        log_momentum=True,
    )
    # Keep the two checkpoints with the highest validation IoU
    ckpt_cb = ModelCheckpoint(
        monitor="val/iou",
        mode="max",
        save_top_k=2,
        verbose=True,
        filename="epoch:{epoch}-step:{step}-loss:{val/loss:.3f}-iou:{val/iou:.3f}",
        auto_insert_metric_name=False,
    )
    log_preds_cb = LogIntermediatePredictions(logger=wandb_logger)

    # DATAMODULE
    dm = ChaBuDDataPipeModule(batch_size=20)
    dm.setup()

    # MODEL
    model = ChaBuDNet(
        lr=1e-3, model_name="tinycd", submission_filepath="submission.csv"
    )

    # TRAINER
    debug = False
    trainer = L.Trainer(
        fast_dev_run=False,
        limit_train_batches=2 if debug else 1.0,
        limit_val_batches=2 if debug else 1.0,
        limit_test_batches=2 if debug else 1.0,
        devices=1,
        accelerator="gpu",
        precision="16-mixed",
        max_epochs=2 if debug else 20,
        accumulate_grad_batches=1,
        logger=[
            csv_logger,
            wandb_logger,
        ],
        callbacks=[ckpt_cb, lr_cb, log_preds_cb],
        log_every_n_steps=1,
    )

    # TRAIN
    print("TRAIN")
    trainer.fit(
        model,
        train_dataloaders=dm.train_dataloader(),
        val_dataloaders=dm.val_dataloader(),
    )

    # EVAL
    # Reload the best checkpoint (highest val/iou) before testing
    device = "cuda"
    print("EVAL")
    model = ChaBuDNet.load_from_checkpoint(ckpt_cb.best_model_path).to(device)
    model.eval()
    model.freeze()
    trainer.test(model, dataloaders=dm.test_dataloader())

def cli_main(
    save_config_callback=None,
    seed_everything_default=42,
    trainer_defaults: dict = {"logger": False},
    args: ArgsType = None,
):
    """
    Command-line interface to run ChaBuDNet with ChaBuDDataPipeModule.
    """
    cli = LightningCLI(
        model_class=ChaBuDNet,
        datamodule_class=ChaBuDDataPipeModule,
        save_config_callback=save_config_callback,
        seed_everything_default=seed_everything_default,
        trainer_defaults=trainer_defaults,
        args=args,
    )
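
# A minimal sketch (not part of the original script): because cli_main takes
# an `args` parameter, it can also be driven programmatically with
# LightningCLI-style argument lists, which is handy for smoke tests. The
# `--trainer.*` flags are standard LightningCLI overrides; any model or
# datamodule options depend on what ChaBuDNet/ChaBuDDataPipeModule expose.
#
#     cli_main(args=["fit", "--trainer.fast_dev_run=true"])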

if __name__ == "__main__":
    # cli_main()
    main()
    print("Done!")