-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathmain.py
131 lines (102 loc) · 4.32 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from argparse import ArgumentParser
from data import get_nli_dataset
import logging
import models
import numpy as np
from pathlib import Path
import random
import torch
from tqdm import tqdm
from transformers import AdamW
# Command-line interface for the NLI training/evaluation script.
parser = ArgumentParser(description="NLI with Transformers")
# Optional language filters for the train/test splits (None = use all).
parser.add_argument("--train_language", type=str, default=None)
parser.add_argument("--test_language", type=str, default=None)
# Optimisation hyper-parameters.
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--log_every", type=int, default=100)
parser.add_argument("--learning_rate", type=float, default=0.00005)
# Hardware / reproducibility / output locations.
parser.add_argument("--gpu", type=int, default=None)
parser.add_argument("--seed", type=int, default=1234)
parser.add_argument("--output_path", type=str, default="output")
# Which pretrained encoder to fine-tune.
parser.add_argument("--model", type=str, choices=["bert", "roberta"], default="roberta")
parser.add_argument("--data_path", type=str)

logging.basicConfig(level=logging.INFO)
def train(config, train_loader, model, optim, device, epoch):
    """Run one training epoch over `train_loader`, logging progress periodically.

    Args:
        config: parsed CLI namespace; reads `epochs` and `log_every`.
        train_loader: iterable of batches, each a dict with "input_ids",
            "attention_mask" and "labels" tensors.
        model: a transformers-style model returning (loss, logits, ...) when
            called with labels.
        optim: optimizer over `model.parameters()`.
        device: torch.device the batch tensors are moved to.
        epoch: zero-based epoch index (only used for logging).
    """
    logging.info("Starting training...")
    model.train()
    logging.info(f"Epoch: {epoch + 1}/{config.epochs}")
    num_batches = len(train_loader)
    for step, batch in enumerate(train_loader):
        optim.zero_grad()
        # Move this batch onto the training device.
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]  # first element of the output tuple is the loss
        loss.backward()
        optim.step()
        # Log the first batch, every `log_every`-th batch, and the final batch.
        should_log = step == 0 or step % config.log_every == 0 or step + 1 == num_batches
        if should_log:
            logging.info(
                "Epoch: {} - Progress: {:3.0f}% - Batch: {:>4.0f}/{:<4.0f} - Loss: {:<.4f}".format(
                    epoch + 1,
                    100.0 * (step + 1) / num_batches,
                    step + 1,
                    num_batches,
                    loss.item(),
                )
            )
def evaluate(model, dataloader, device):
    """Run inference over `dataloader` and collect gold labels and predictions.

    Args:
        model: transformers-style model returning (loss, logits, ...) when
            called with labels.
        dataloader: iterable of batch dicts with "input_ids",
            "attention_mask" and "labels" tensors.
        device: torch.device the batch tensors are moved to.

    Returns:
        Tuple of two 1-D numpy arrays: (gold_labels, predicted_labels),
        concatenated over all batches in iteration order.
    """
    logging.info("Starting evaluation...")
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            # outputs[1] holds the logits; take the highest-scoring class.
            batch_preds = outputs[1].argmax(dim=-1)
            all_preds.append(batch_preds.cpu().numpy())
            all_labels.append(batch["labels"].cpu().numpy())
    logging.info("Done evaluation")
    return np.concatenate(all_labels), np.concatenate(all_preds)
def main():
    """Fine-tune an NLI model, checkpoint after each epoch, and report test accuracy.

    Reads all settings from the module-level argument parser. Writes per-epoch
    snapshots, a final snapshot, and a results text file under
    `config.output_path`.
    """
    config = parser.parse_args()

    # Seed every RNG used so runs are reproducible.
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)  # no-op without CUDA
    random.seed(config.seed)

    # BUG FIX: with the default --gpu of None, f"cuda:{config.gpu}" produced
    # the invalid device string "cuda:None" and torch.device raised. Fall back
    # to the default CUDA device when no index is given.
    if torch.cuda.is_available():
        device = torch.device("cuda" if config.gpu is None else f"cuda:{config.gpu}")
    else:
        device = torch.device("cpu")
    logging.info(f"Training on {device}.")

    tokenizer, model = models.get_model(config)
    train_loader, dev_loader, test_loader = get_nli_dataset(config, tokenizer)
    optim = AdamW(model.parameters(), lr=config.learning_rate)
    model.to(device)

    Path(config.output_path).mkdir(parents=True, exist_ok=True)

    # BUG FIX: dev_accuracy was only bound inside the loop, so --epochs 0
    # raised NameError when building the final snapshot name below.
    dev_accuracy = 0.0
    for epoch in range(config.epochs):
        train(config, train_loader, model, optim, device, epoch)
        dev_labels, dev_preds = evaluate(model, dev_loader, device)
        dev_accuracy = (dev_labels == dev_preds).mean()
        logging.info(f"Dev accuracy after epoch {epoch+1}: {dev_accuracy}")
        # NOTE(review): this pickles the whole nn.Module (original behavior);
        # saving model.state_dict() would be more portable across code changes.
        snapshot_path = f"{config.output_path}/{config.model}-mnli_snapshot_epoch_{epoch+1}_devacc_{round(dev_accuracy, 3)}.pt"
        torch.save(model, snapshot_path)

    test_labels, test_preds = evaluate(model, test_loader, device)
    test_accuracy = (test_labels == test_preds).mean()
    logging.info(f"Test accuracy for model {config.model}: {test_accuracy}")
    with open(
        f"{config.output_path}/{config.model}.results.txt",
        "w",
    ) as resultfile:
        resultfile.write(f"Test accuracy: {test_accuracy}")

    final_snapshot_path = f"{config.output_path}/{config.model}-mnli_final_snapshot_epochs_{config.epochs}_devacc_{round(dev_accuracy, 3)}.pt"
    torch.save(model, final_snapshot_path)


if __name__ == "__main__":
    main()