train.py
#!/usr/bin/env python
"""Training entry point for the NMT model."""
import os

import torch

from Utils.log import trace
from Utils.config import Config
from Utils.DataLoader import DataBatchIterator
from NMT import Trainer
from NMT import model_factory
from NMT import dump_checkpoint


def main():
    # Load the training configuration.
    config = Config("train", training=True)
    trace(config)
    # Let cuDNN autotune its kernels; fastest when input sizes are fixed.
    torch.backends.cudnn.benchmark = True

    # Load the training dataset.
    train_data = load_dataset(
        config.train_dataset,
        config.train_batch_size,
        config, prefix="Training:")
    # Load the validation dataset.
    valid_data = load_dataset(
        config.valid_dataset,
        config.valid_batch_size,
        config, prefix="Validation:")

    # Build the model, optionally restoring weights from a checkpoint.
    vocab = train_data.get_vocab()
    model = model_factory(config, config.checkpoint, *vocab)
    if config.verbose:
        trace(model)

    # Start training.
    trg_vocab = train_data.trg_vocab
    padding_idx = trg_vocab.padding_idx
    trainer = Trainer(model, trg_vocab, padding_idx, config)
    start_epoch = 1
    for epoch in range(start_epoch, config.epochs + 1):
        trainer.train(epoch, config.epochs,
                      train_data, valid_data,
                      train_data.num_batches)
        # Save a checkpoint after every epoch.
        dump_checkpoint(trainer.model, config.save_model)


def load_dataset(dataset, batch_size, config, prefix):
    # Build a batch iterator over one (source, target) parallel corpus:
    # <data_path>/<dataset>.<src_lang> and <data_path>/<dataset>.<trg_lang>.
    src_path = os.path.join(
        config.data_path, dataset + "." + config.src_lang)
    trg_path = os.path.join(
        config.data_path, dataset + "." + config.trg_lang)
    data = DataBatchIterator(
        src_path, trg_path,
        share_vocab=config.share_vocab,
        training=config.training,
        shuffle=config.shuffle_data,
        batch_size=batch_size,
        max_length=config.max_seq_len,
        vocab=config.save_vocab,
        mini_batch_sort_order=config.mini_batch_sort_order)
    trace(prefix, data)
    return data


if __name__ == "__main__":
    main()
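
# A minimal sketch of resuming from the saved checkpoint (an assumption,
# not part of the original script: it relies on model_factory restoring
# weights when config.checkpoint points at a file written by dump_checkpoint):
#
#   config = Config("train", training=True)
#   config.checkpoint = config.save_model
#   train_data = load_dataset(config.train_dataset,
#                             config.train_batch_size, config, "Training:")
#   model = model_factory(config, config.checkpoint, *train_data.get_vocab())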