train.py
import logging

import hydra
import pandas as pd
import pytorch_lightning as pl
import torch
import wandb
from omegaconf.omegaconf import OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger

from data import DataModule
from model import ColaModel

logger = logging.getLogger(__name__)
class SamplesVisualisationLogger(pl.Callback):
    """Logs misclassified validation samples to W&B after each validation run."""

    def __init__(self, datamodule):
        super().__init__()
        self.datamodule = datamodule

    def on_validation_end(self, trainer, pl_module):
        # Grab one validation batch and run it through the model.
        val_batch = next(iter(self.datamodule.val_dataloader()))
        sentences = val_batch["sentence"]

        # Move inputs to the model's device and disable grads for inference,
        # so this also works when training on GPU.
        with torch.no_grad():
            outputs = pl_module(
                val_batch["input_ids"].to(pl_module.device),
                val_batch["attention_mask"].to(pl_module.device),
            )
        preds = torch.argmax(outputs.logits, 1)
        labels = val_batch["label"]

        df = pd.DataFrame(
            {
                "Sentence": sentences,
                "Label": labels.cpu().numpy(),
                "Predicted": preds.cpu().numpy(),
            }
        )

        # Keep only the wrongly predicted examples and log them as a W&B table.
        wrong_df = df[df["Label"] != df["Predicted"]]
        trainer.logger.experiment.log(
            {
                "examples": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),
                "global_step": trainer.global_step,
            }
        )
@hydra.main(config_path="./configs", config_name="config")
def main(cfg):
    # Log the fully resolved Hydra config so each run is reproducible from its logs.
    logger.info(OmegaConf.to_yaml(cfg, resolve=True))
    logger.info(f"Using the model: {cfg.model.name}")
    logger.info(f"Using the tokenizer: {cfg.model.tokenizer}")

    cola_data = DataModule(
        cfg.model.tokenizer, cfg.processing.batch_size, cfg.processing.max_length
    )
    cola_model = ColaModel(cfg.model.name)

    # Persist the best model (lowest validation loss) under ./models.
    checkpoint_callback = ModelCheckpoint(
        dirpath="./models",
        filename="best-checkpoint",
        monitor="valid/loss",
        mode="min",
    )
    # Stop early if validation loss does not improve for 3 checks.
    early_stopping_callback = EarlyStopping(
        monitor="valid/loss", patience=3, verbose=True, mode="min"
    )

    wandb_logger = WandbLogger(project="MLOps Basics", entity="hoanshiro")
    trainer = pl.Trainer(
        max_epochs=cfg.training.max_epochs,
        logger=wandb_logger,
        callbacks=[
            checkpoint_callback,
            SamplesVisualisationLogger(cola_data),
            early_stopping_callback,
        ],
        log_every_n_steps=cfg.training.log_every_n_steps,
        deterministic=cfg.training.deterministic,
        limit_train_batches=cfg.training.limit_train_batches,
        limit_val_batches=cfg.training.limit_val_batches,
    )
    trainer.fit(cola_model, cola_data)
    wandb.finish()


if __name__ == "__main__":
    main()
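
The script assumes a Hydra config tree under ./configs. Below is a minimal sketch of what configs/config.yaml might look like, inferred from the cfg.* keys the script reads; every value (including the model name) is an illustrative assumption, not the repository's actual settings:

model:
  name: google/bert_uncased_L-2_H-128_A-2  # assumed; any HF sequence-classification checkpoint
  tokenizer: google/bert_uncased_L-2_H-128_A-2  # assumed; usually matches model.name
processing:
  batch_size: 64
  max_length: 128
training:
  max_epochs: 1
  log_every_n_steps: 10
  deterministic: true
  limit_train_batches: 0.25  # fraction of train batches per epoch
  limit_val_batches: 0.25    # fraction of val batches per check

Any of these keys can be overridden from the command line via Hydra's override syntax, e.g. python train.py training.max_epochs=5 processing.batch_size=32.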