-
Notifications
You must be signed in to change notification settings - Fork 19
/
checkpoint.py
171 lines (139 loc) · 6.03 KB
/
checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import logging
import numpy
import os
import time
from contextlib import closing
from six.moves import cPickle
from blocks.extensions.saveload import SAVED_TO, LOADED_FROM
from blocks.extensions import TrainingExtension, SimpleExtension
from blocks.serialization import secure_dump, load, BRICK_DELIMITER
from blocks.utils import reraise_as
logger = logging.getLogger(__name__)
class SaveLoadUtils(object):
"""Utility class for checkpointing."""
@property
def path_to_folder(self):
return self.folder
@property
def path_to_parameters(self):
return os.path.join(self.folder, 'params.npz')
@property
def path_to_iteration_state(self):
return os.path.join(self.folder, 'iterations_state.pkl')
@property
def path_to_log(self):
return os.path.join(self.folder, 'log')
def load_parameter_values(self, path):
with closing(numpy.load(path)) as source:
param_values = {}
for name, value in source.items():
if name != 'pkl':
name_ = name.replace(BRICK_DELIMITER, '/')
if not name_.startswith('/'):
name_ = '/' + name_
param_values[name_] = value
return param_values
def save_parameter_values(self, param_values, path):
param_values = {name.replace("/", BRICK_DELIMITER): param
for name, param in param_values.items()}
numpy.savez(path, **param_values)
class CheckpointNMT(SimpleExtension, SaveLoadUtils):
"""Redefines checkpointing for NMT.
Saves only parameters (npz), iteration state (pickle) and log (pickle).
"""
def __init__(self, saveto, **kwargs):
self.folder = saveto
kwargs.setdefault("after_training", True)
super(CheckpointNMT, self).__init__(**kwargs)
def dump_parameters(self, main_loop):
params_to_save = main_loop.model.get_parameter_values()
self.save_parameter_values(params_to_save,
self.path_to_parameters)
def dump_iteration_state(self, main_loop):
secure_dump(main_loop.iteration_state, self.path_to_iteration_state)
def dump_log(self, main_loop):
secure_dump(main_loop.log, self.path_to_log, cPickle.dump)
def dump(self, main_loop):
if not os.path.exists(self.path_to_folder):
os.mkdir(self.path_to_folder)
print("")
logger.info(" Saving model")
start = time.time()
logger.info(" ...saving parameters")
self.dump_parameters(main_loop)
logger.info(" ...saving iteration state")
self.dump_iteration_state(main_loop)
logger.info(" ...saving log")
self.dump_log(main_loop)
logger.info(" Model saved, took {} seconds.".format(time.time()-start))
def do(self, callback_name, *args):
try:
self.dump(self.main_loop)
except Exception:
raise
finally:
already_saved_to = self.main_loop.log.current_row.get(SAVED_TO, ())
self.main_loop.log.current_row[SAVED_TO] = (already_saved_to +
(self.path_to_folder +
'params.npz',))
class LoadNMT(TrainingExtension, SaveLoadUtils):
"""Loads parameters log and iterations state."""
def __init__(self, saveto, **kwargs):
self.folder = saveto
super(LoadNMT, self).__init__(saveto, **kwargs)
def before_training(self):
if not os.path.exists(self.path_to_folder):
logger.info("No dump found")
return
logger.info("Loading the state from {} into the main loop"
.format(self.path_to_folder))
try:
self.load_to(self.main_loop)
self.main_loop.log.current_row[LOADED_FROM] = self.path_to_folder
except Exception:
reraise_as("Failed to load the state")
def load_parameters(self):
return self.load_parameter_values(self.path_to_parameters)
def load_iteration_state(self):
with open(self.path_to_iteration_state, "rb") as source:
return load(source, use_cpickle=True)
def load_log(self):
with open(self.path_to_log, "rb") as source:
return cPickle.load(source)
def load_to(self, main_loop):
"""Loads the dump from the root folder into the main loop."""
logger.info(" Reloading model")
try:
logger.info(" ...loading model parameters")
params_all = self.load_parameters()
params_this = main_loop.model.get_parameter_dict()
missing = set(params_this.keys()) - set(params_all.keys())
for pname in params_this.keys():
if pname in params_all:
val = params_all[pname]
if params_this[pname].get_value().shape != val.shape:
logger.warning(
" Dimension mismatch {}-{} for {}"
.format(params_this[pname].get_value().shape,
val.shape, pname))
params_this[pname].set_value(val)
logger.info(" Loaded to CG {:15}: {}"
.format(str(val.shape), pname))
else:
logger.warning(
" Parameter does not exist: {}".format(pname))
logger.info(
" Number of parameters loaded for computation graph: {}"
.format(len(params_this) - len(missing)))
except Exception as e:
logger.error(" Error {0}".format(str(e)))
try:
logger.info(" Loading iteration state...")
main_loop.iteration_state = self.load_iteration_state()
except Exception as e:
logger.error(" Error {0}".format(str(e)))
try:
logger.info(" Loading log...")
main_loop.log = self.load_log()
except Exception as e:
logger.error(" Error {0}".format(str(e)))