Skip to content

Commit

Permalink
Add comment about "traced" naming convention (#1037)
Browse files Browse the repository at this point in the history
  • Loading branch information
mjdenkowski authored Apr 4, 2022
1 parent 7729323 commit 1ff2a1c
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions sockeye/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,13 @@ def save_parameters(self, fname: str):
:param fname: Path to save parameters to.
"""
self.apply(layers.interleave_kv)
# Do not save parameters for traced modules. Traced modules are created
# at runtime and use the same parameters as non-traced versions.
# Sockeye follows the convention of using the "traced" prefix for
# modules that are created at runtime by tracing other modules.
# Ex: traced_encoder = trace(encoder, ...)
# Traced modules use the same parameters as the original versions so we
# filter their names from the state dictionary to avoid saving redundant
# copies of their parameters. Copies can also cause errors at loadtime
# if the traced modules do not yet exist.
filtered_state_dict = {name: param for (name, param) in self.state_dict().items() if 'traced' not in name}
pt.save(filtered_state_dict, fname)
self.apply(layers.separate_kv)
Expand Down

0 comments on commit 1ff2a1c

Please sign in to comment.