
Commit d191392

vince62s and Thai Chau Truong authored; Thai Chau Truong committed
Load data at the correct position when resuming from a checkpoint
1 parent 4193a87 commit d191392

6 files changed: +173 −16 lines changed

onmt/inputters/dynamic_iterator.py

+24 −2

@@ -129,6 +129,7 @@ def __init__(
         batch_type,
         batch_size,
         batch_size_multiple,
+        resume_corpora_info={},
         data_type="text",
         bucket_size=2048,
         bucket_size_init=-1,
@@ -144,6 +145,7 @@ def __init__(
         self.transforms = transforms
         self.vocabs = vocabs
         self.corpora_info = corpora_info
+        self.resume_corpora_info = resume_corpora_info
         self.task = task
         self.init_iterators = False
         self.batch_size = batch_size
@@ -171,7 +173,17 @@ def __init__(

     @classmethod
     def from_opt(
-        cls, corpora, transforms, vocabs, opt, task, copy, device, stride=1, offset=0
+        cls,
+        corpora,
+        transforms,
+        vocabs,
+        opt,
+        task,
+        copy,
+        device,
+        resume_corpora_info={},
+        stride=1,
+        offset=0,
     ):
         """Initilize `DynamicDatasetIter` with options parsed from `opt`."""
         corpora_info = {}
@@ -206,6 +218,7 @@ def from_opt(
             opt.batch_type,
             batch_size,
             batch_size_multiple,
+            resume_corpora_info=resume_corpora_info,
             data_type=opt.data_type,
             bucket_size=bucket_size,
             bucket_size_init=bucket_size_init,
@@ -388,6 +401,7 @@ def build_dynamic_dataset_iter(
     vocabs,
     copy=False,
     task=CorpusTask.TRAIN,
+    resume_corpora_info={},
     stride=1,
     offset=0,
     src=None,
@@ -412,7 +426,14 @@ def build_dynamic_dataset_iter(
         advance to avoid the GPU waiting during the refilling of the bucket.
     """
     transforms = make_transforms(opt, transforms_cls, vocabs)
-    corpora = get_corpora(opt, task, src=src, tgt=tgt, align=align)
+    corpora = get_corpora(
+        opt,
+        task,
+        src=src,
+        tgt=tgt,
+        align=align,
+        resume_corpora_info=resume_corpora_info,
+    )
     if corpora is None:
         assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
         return None
@@ -442,6 +463,7 @@ def build_dynamic_dataset_iter(
         vocabs,
         opt,
         task,
+        resume_corpora_info=resume_corpora_info,
         copy=copy,
         stride=stride,
         offset=offset,
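The new `resume_corpora_info` argument is threaded from `build_dynamic_dataset_iter` through `DynamicDatasetIter.from_opt` and `__init__` down to `get_corpora`. A minimal sketch of how a caller might pass it, assuming a mapping restored from a checkpoint's "corpus_info" entry (the surrounding variable names are illustrative, not part of this diff):

    from onmt.constants import CorpusTask
    from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter

    # Hypothetical mapping: corpus id -> first line number to read on
    # resumption (produced by load_corpora_info in model_saver.py below).
    resume_corpora_info = {"corpus_1": 12346}

    train_iter = build_dynamic_dataset_iter(
        opt,             # parsed training options
        transforms_cls,  # transform classes, as elsewhere in this module
        vocabs,
        task=CorpusTask.TRAIN,
        resume_corpora_info=resume_corpora_info,
    )

An empty dict (the default) preserves the old behavior of reading every corpus from line 0. Note that a mutable `{}` default is shared across calls in Python; it is harmless here because the dict is only read, but `None` with an in-body fallback is the more defensive idiom.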

onmt/inputters/text_corpus.py

+38 −8

@@ -99,7 +99,14 @@ class ParallelCorpus(object):
     """A parallel corpus file pair that can be loaded to iterate."""

     def __init__(
-        self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None
+        self,
+        name,
+        src,
+        tgt,
+        align=None,
+        n_src_feats=0,
+        src_feats_defaults=None,
+        line_number_to_resume=0,
     ):
         """Initialize src & tgt side file path."""
         self.id = name
@@ -108,6 +115,12 @@ def __init__(
         self.align = align
         self.n_src_feats = n_src_feats
         self.src_feats_defaults = src_feats_defaults
+        self.line_number_to_resume = line_number_to_resume
+        self.can_read_file = False
+
+    def activate_reading_mode(self, line_number):
+        if line_number >= self.line_number_to_resume:
+            self.can_read_file = True

     def load(self, offset=0, stride=1):
         """
@@ -116,7 +129,7 @@ def load(self, offset=0, stride=1):
         `stride` example, starting from `offset`.
         """

-        def make_ex(sline, tline, align):
+        def make_ex(sline, tline, align, line_number):
             sline, sfeats = parse_features(
                 sline,
                 n_feats=self.n_src_feats,
@@ -131,6 +144,7 @@ def make_ex(sline, tline, align):
                 "tgt": tline,
                 "src_original": sline,
                 "tgt_original": tline,
+                "cid_line_number": line_number,
             }
             if align is not None:
                 example["align"] = align
@@ -145,19 +159,25 @@ def make_ex(sline, tline, align):
             for i, (sline, tline, align) in enumerate(
                 itertools.zip_longest(fs, ft, fa)
             ):
+                self.activate_reading_mode(line_number=i)
+                if not self.can_read_file:
+                    continue
                 if (i // stride) % stride == offset:
-                    yield make_ex(sline, tline, align)
+                    yield make_ex(sline, tline, align, i)
         else:
             with exfile_open(self.src, mode="rb") as fs, exfile_open(
                 self.tgt, mode="rb"
             ) as ft, exfile_open(self.align, mode="rb") as fa:
                 for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
+                    self.activate_reading_mode(line_number=i)
+                    if not self.can_read_file:
+                        continue
                     if (i // stride) % stride == offset:
                         if tline is not None:
                             tline = tline.decode("utf-8")
                         if align is not None:
                             align = align.decode("utf-8")
-                        yield make_ex(sline.decode("utf-8"), tline, align)
+                        yield make_ex(sline.decode("utf-8"), tline, align, i)

     def __str__(self):
         cls_name = type(self).__name__
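The resume logic above fast-forwards the corpus reader: `activate_reading_mode(i)` flips `can_read_file` once the enumerator reaches the saved line, and earlier lines are skipped with `continue`. A self-contained sketch of the same control flow over a single file (the function name and file handling are illustrative, not part of the diff):

    def resumable_lines(path, line_number_to_resume=0, stride=1, offset=0):
        """Yield (line_number, line), skipping lines before the resume
        point, then applying the same stride/offset sharding as
        ParallelCorpus.load."""
        with open(path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i < line_number_to_resume:
                    continue  # before the saved position: skip
                if (i // stride) % stride == offset:
                    yield i, line

Note that every skipped line is still read and decoded, so resuming deep into a large corpus costs one linear scan up to the saved position.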
@@ -169,19 +189,25 @@ def __str__(self):
         )


-def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
+def get_corpora(
+    opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None, resume_corpora_info={}
+):
     corpora_dict = {}
     if task == CorpusTask.TRAIN:
         for corpus_id, corpus_dict in opts.data.items():
             if corpus_id != CorpusName.VALID:
                 if corpus_dict.get("path_txt", None) is None:
+                    resume_line = 0
+                    if corpus_id in resume_corpora_info:
+                        resume_line = resume_corpora_info[corpus_id]
                     corpora_dict[corpus_id] = ParallelCorpus(
                         corpus_id,
                         corpus_dict["path_src"],
                         corpus_dict["path_tgt"],
                         corpus_dict["path_align"],
                         n_src_feats=opts.n_src_feats,
                         src_feats_defaults=opts.src_feats_defaults,
+                        line_number_to_resume=resume_line,
                     )
                 else:
                     corpora_dict[corpus_id] = BlockwiseCorpus(
@@ -244,8 +270,6 @@ def _process(self, stream):
                 example["src_feats"] = [
                     feat.strip().split(" ") for feat in example["src_feats"]
                 ]
-                line_number = i * self.stride + self.offset
-                example["cid_line_number"] = line_number
                 example["cid"] = self.cid
                 if "align" in example:
                     example["align"] = example["align"].strip().split(" ")
@@ -258,6 +282,7 @@ def _process(self, stream):
                 or ("align" in example and example["align"] == 0)
             ):
                 # empty example: skip
+                line_number = example["cid_line_number"]
                 empty_msg = f"Empty line in {self.cid}#{line_number}."
                 if self.skip_empty_level == "error":
                     raise IOError(empty_msg)
@@ -282,7 +307,12 @@ def __iter__(self):


 def build_corpora_iters(
-    corpora, transforms, corpora_info, skip_empty_level="warning", stride=1, offset=0
+    corpora,
+    transforms,
+    corpora_info,
+    skip_empty_level="warning",
+    stride=1,
+    offset=0,
 ):
     """Return `ParallelCorpusIterator` for all corpora defined in opts."""
     corpora_iters = dict()
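`get_corpora` (above) now looks up each training corpus id in `resume_corpora_info` and hands the saved position to `ParallelCorpus` as `line_number_to_resume`; corpora absent from the mapping start from line 0. A small hypothetical example (corpus ids are illustrative):

    # opts.data defines two corpora, "europarl" and "news"; only one has a
    # saved position in the checkpoint.
    resume_corpora_info = {"europarl": 120000}

    corpora = get_corpora(
        opts, task=CorpusTask.TRAIN, resume_corpora_info=resume_corpora_info
    )
    assert corpora["europarl"].line_number_to_resume == 120000
    assert corpora["news"].line_number_to_resume == 0

Moving the `cid_line_number` assignment out of `ParallelCorpusIterator._process` and into `make_ex` also means the recorded number is the true position in the file rather than one reconstructed from `i * stride + offset`, which is what makes the saved value safe to resume from.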

onmt/models/model_saver.py

+90 −2

@@ -1,13 +1,17 @@
 import os
 import torch
 import re
+import subprocess
 from collections import deque
+import onmt.utils
 from onmt.utils.logging import logger
 from onmt.inputters.inputter import vocabs_to_dict
 from onmt.modules.lora import lora_state_dict


-def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
+def build_model_saver(
+    model_opt, opt, model, vocabs, optim, resume_corpora_info, device_id
+):
     # _check_save_model_path
     save_model_path = os.path.abspath(opt.save_model)
     os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
@@ -20,6 +24,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
         optim,
         opt.keep_checkpoint,
         opt.save_format,
+        resume_corpora_info,
         device_id,
     )
     return model_saver
@@ -81,6 +86,65 @@ def fix_key(s):
     return checkpoint


+def load_corpora_info(opts, checkpoint):
+    message_resume_from_beginning = (
+        "The training will resume from the beginning of each corpus."
+    )
+    # Check if resume_from_corpora is True
+    if not opts.resume_from_corpora:
+        logger.info(
+            "No resume from corpora is specified. " + message_resume_from_beginning
+        )
+        return {}
+
+    # Check if the corpus list from the last training
+    # and in the new training are identical.
+    checkpoint_corpora = checkpoint.get("corpus_info", None)
+    if checkpoint_corpora is None:
+        logger.info(
+            "Incoherent info: Some corpora in the last training "
+            + "and in the new list do not match. "
+            + message_resume_from_beginning
+        )
+        return {}
+
+    checkpoint_corpus_names = [name for name in checkpoint_corpora]
+    new_corpus_names = [name for name in opts.data]
+    if set(checkpoint_corpus_names) != set(new_corpus_names):
+        logger.info(
+            "Incoherent info: Some corpora in the last training "
+            + "and in the new list do not match. "
+            + message_resume_from_beginning
+        )
+        return {}
+
+    # For each corpus, check if the last line number to resume
+    # is smaller than or equal to the number of text lines.
+    message_incoherent_line_number = (
+        "Incoherent info: text line numbers "
+        + "to resume in some corpora exceed their total numbers of lines. "
+        + message_resume_from_beginning
+    )
+    for c_name in checkpoint_corpora:
+        number_of_text_lines = int(
+            subprocess.getoutput(
+                "wc -l " + opts.data[c_name]["path_src"] + " | awk '{print $1}'"
+            )
+        )
+        if checkpoint_corpora[c_name] > number_of_text_lines - 1:
+            logger.info(message_incoherent_line_number)
+            return {}
+
+        # To set the text lines to resume, we increase all text lines by 1
+        # (and return to the beginning if the end is reached)
+        checkpoint_corpora[c_name] = (
+            checkpoint_corpora[c_name] + 1
+        ) % number_of_text_lines
+
+    logger.info("The training will resume from the saved text line in each corpus.")
+    return checkpoint_corpora
+
+
 class ModelSaverBase(object):
     """Base class for model saving operations
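`load_corpora_info` validates each saved position against the corpus size by shelling out to `wc -l`. A portable sketch of the same count in pure Python, assuming newline-terminated files (the helper name is illustrative):

    def count_lines(path):
        # Equivalent of `wc -l path | awk '{print $1}'`: counts
        # newline-terminated lines, streaming instead of loading the file.
        with open(path, "rb") as f:
            return sum(1 for _ in f)

The final `(saved_line + 1) % number_of_text_lines` step makes resumption start at the line after the last one saved, wrapping back to line 0 when the checkpoint was written on the corpus's final line.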
@@ -98,6 +162,7 @@ def __init__(
         optim,
         keep_checkpoint=-1,
         save_format="pytorch",
+        resume_corpora_info={},
         device_id=0,
     ):
         self.base_path = base_path
@@ -108,14 +173,35 @@ def __init__(
         self.last_saved_step = None
         self.keep_checkpoint = keep_checkpoint
         self.save_format = save_format
+        self.corpus_info = resume_corpora_info
         self.device_id = device_id

         if keep_checkpoint > 0:
             self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
             if save_format == "safetensors":
                 self.model_queue = deque([], maxlen=keep_checkpoint)

-    def save(self, step, moving_average=None):
+    def update_corpus_info_from_batches(self, batches, distributed=False):
+        # Update the last text line of each corpus
+        if batches is not None:
+            # Gather corpus line numbers to save to checkpoints
+            batch_cids = sum([batch["cid"] for batch in batches], [])
+            batch_cid_line_numbers = sum(
+                [batch["cid_line_number"] for batch in batches], []
+            )
+            if distributed:
+                batch_cids = sum(onmt.utils.distributed.all_gather_list(batch_cids), [])
+                batch_cid_line_numbers = sum(
+                    onmt.utils.distributed.all_gather_list(batch_cid_line_numbers), []
+                )
+            # Save the last processed line number of each corpus
+            new_corpus_info = {
+                c_name: cid_line_number
+                for c_name, cid_line_number in zip(batch_cids, batch_cid_line_numbers)
+            }
+            self.corpus_info = {**self.corpus_info, **new_corpus_info}
+
+    def save(self, step, moving_average=None, batches=None, distributed=False):
         """Main entry point for model saver

         It wraps the `_save` method with checks and apply `keep_checkpoint`
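`update_corpus_info_from_batches` flattens the per-batch `cid` and `cid_line_number` lists (gathering them across ranks first when `distributed=True`), then builds a dict in which later pairs overwrite earlier ones, so each corpus keeps the last line number seen. A toy illustration with a hypothetical `saver` instance:

    # Two batches: two examples from "europarl", one from "news".
    batches = [
        {"cid": ["europarl", "europarl"], "cid_line_number": [100, 101]},
        {"cid": ["news"], "cid_line_number": [7]},
    ]
    saver.update_corpus_info_from_batches(batches)
    # saver.corpus_info now maps {"europarl": 101, "news": 7} and is written
    # to the checkpoint under "corpus_info" by _save/_st_save below.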
@@ -266,6 +352,7 @@ def _save(self, step, model):
             "vocab": vocabs_to_dict(self.vocabs),
             "opt": self.model_opt,
             "optim": self.optim.state_dict(),
+            "corpus_info": self.corpus_info,
         }
         if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
             logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
@@ -355,6 +442,7 @@ def _st_save(self, step, model):
             "vocab": vocabs_to_dict(self.vocabs),
             "opt": self.model_opt,
             "optim": self.optim.state_dict(),
+            "corpus_info": self.corpus_info,
         }

         if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:

onmt/opts.py

+7 −0

@@ -1263,6 +1263,13 @@ def _add_train_general_opts(parser):
         help="If training from a checkpoint then this is the "
         "path to the pretrained model's state_dict.",
     )
+    group.add(
+        "--resume_from_corpora",
+        "-resume_from_corpora",
+        action="store_true",
+        help="If training from a checkpoint and this is set to True "
+        " then the data generator will resume from the last line of each corpora.",
+    )
     group.add(
         "--reset_optim",
         "-reset_optim",

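With the new flag, restarting training from a checkpoint can pick up each corpus at the saved line instead of at the top. A hypothetical invocation (the config and checkpoint paths are illustrative):

    onmt_train -config my_config.yaml \
               -train_from run/model_step_10000.pt \
               -resume_from_corpora

Without `-resume_from_corpora`, or when the checkpoint's corpus list does not match the configuration, `load_corpora_info` returns an empty mapping and training reads every corpus from the beginning.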