Bugfix: Make sure lhuc flag is passed as bool to configs (#382)
Author: fhieber (May 7, 2018)
Parent: 531ae4d
Commit: 3c1a8c5
Showing 2 changed files with 8 additions and 5 deletions.
CHANGELOG.md: 3 additions, 0 deletions

@@ -11,6 +11,9 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 Each version section may have have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
 ## [1.18.9]
+### Fixed
+- Fixed a problem with lhuc boolean flags passed as None.
+
 ### Added
 - Reorganized beam search. Normalization is applied only to completed hypotheses, and pruning of
   hypotheses (logprob against highest-scoring completed hypothesis) can be specified with
sockeye/train.py: 5 additions, 5 deletions

@@ -101,7 +101,7 @@ def check_arg_compatibility(args: argparse.Namespace):
                     "Target embedding size must match transformer model size: %s vs. %s"
                     % (args.transformer_model_size, args.num_embed[1]))
 
-    if args.lhuc:
+    if args.lhuc is not None:
         check_condition(args.encoder == C.RNN_NAME or args.decoder == C.RNN_NAME,
                         "LHUC is only supported for RNN models for now.")
@@ -443,7 +443,7 @@ def create_encoder_config(args: argparse.Namespace,
                 residual=args.rnn_residual_connections,
                 first_residual_layer=args.rnn_first_residual_layer,
                 forget_bias=args.rnn_forget_bias,
-                lhuc=args.lhuc and (C.LHUC_ENCODER in args.lhuc or C.LHUC_ALL in args.lhuc)),
+                lhuc=args.lhuc is not None and (C.LHUC_ENCODER in args.lhuc or C.LHUC_ALL in args.lhuc)),
             conv_config=config_conv,
             reverse_input=args.rnn_encoder_reverse_input)
         encoder_num_hidden = args.rnn_num_hidden
@@ -531,14 +531,14 @@ def create_decoder_config(args: argparse.Namespace, encoder_num_hidden: int) ->
                 residual=args.rnn_residual_connections,
                 first_residual_layer=args.rnn_first_residual_layer,
                 forget_bias=args.rnn_forget_bias,
-                lhuc=args.lhuc and (C.LHUC_ENCODER in args.lhuc or C.LHUC_ALL in args.lhuc)),
+                lhuc=args.lhuc is not None and (C.LHUC_DECODER in args.lhuc or C.LHUC_ALL in args.lhuc)),
             attention_config=config_attention,
             hidden_dropout=args.rnn_decoder_hidden_dropout,
             state_init=args.rnn_decoder_state_init,
             context_gating=args.rnn_context_gating,
             layer_normalization=args.layer_normalization,
             attention_in_upper_layers=args.rnn_attention_in_upper_layers,
-            state_init_lhuc=args.lhuc and (C.LHUC_STATE_INIT in args.lhuc or C.LHUC_ALL in args.lhuc))
+            state_init_lhuc=args.lhuc is not None and (C.LHUC_STATE_INIT in args.lhuc or C.LHUC_ALL in args.lhuc))
 
     return config_decoder
 
@@ -625,7 +625,7 @@ def create_model_config(args: argparse.Namespace,
         weight_tying=args.weight_tying,
         weight_tying_type=args.weight_tying_type if args.weight_tying else None,
         weight_normalization=args.weight_normalization,
-        lhuc=bool(args.lhuc))
+        lhuc=args.lhuc is not None)
     return model_config
 

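Why the change was needed: --lhuc takes a list of component names, so when the flag is omitted args.lhuc is None. Python's `and` returns its first operand when that operand is falsy, so `args.lhuc and (...)` evaluates to None rather than False, and the configs received None where a bool was expected. The decoder hunk also corrects a copy-paste slip: the decoder config previously tested C.LHUC_ENCODER instead of C.LHUC_DECODER. Below is a minimal standalone sketch of the truthiness issue; the argparse definition is an illustrative stand-in, not Sockeye's actual CLI code.

import argparse

parser = argparse.ArgumentParser()
# --lhuc takes one or more component names; when the flag is omitted,
# argparse leaves args.lhuc as None (choices here are illustrative).
parser.add_argument("--lhuc", nargs="+",
                    choices=["encoder", "decoder", "state_init", "all"])
args = parser.parse_args([])  # simulate a run without --lhuc

# Buggy pattern: `and` short-circuits and returns its first operand,
# so the expression yields None, not False.
buggy = args.lhuc and ("encoder" in args.lhuc or "all" in args.lhuc)
print(buggy)  # None

# Fixed pattern: `is not None` always produces a bool, so the whole
# expression is guaranteed to be True or False.
fixed = args.lhuc is not None and ("encoder" in args.lhuc or "all" in args.lhuc)
print(fixed)  # False

The `bool(args.lhuc)` call in create_model_config already coerced None to False; rewriting it as `args.lhuc is not None` puts all four call sites on the same explicit idiom.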