This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

adding hparam to make encoder self-attention optional.
PiperOrigin-RevId: 313707269
eli7 authored and copybara-github committed May 29, 2020
1 parent f65b5e4 commit c104976
Showing 1 changed file with 27 additions and 26 deletions.
tensor2tensor/models/evolved_transformer.py: 27 additions, 26 deletions

@@ -223,34 +223,35 @@ def evolved_transformer_encoder(encoder_input,
           hidden_state = common_layers.layer_postprocess(
               residual_state, hidden_state, hparams)
 
-        with tf.variable_scope("self_attention"):
-          residual_state = hidden_state
-          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
+        if hparams.get("et_encoder_self_attention", True):
+          with tf.variable_scope("self_attention"):
+            residual_state = hidden_state
+            hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
 
-          hidden_state = common_attention.multihead_attention(
-              hidden_state,
-              None,
-              encoder_self_attention_bias,
-              hparams.attention_key_channels or hparams.hidden_size,
-              hparams.attention_value_channels or hparams.hidden_size,
-              hparams.hidden_size,
-              hparams.num_heads,
-              hparams.attention_dropout,
-              attention_type=hparams.self_attention_type,
-              max_relative_position=hparams.max_relative_position,
-              heads_share_relative_embedding=(
-                  hparams.heads_share_relative_embedding),
-              add_relative_to_values=hparams.add_relative_to_values,
-              save_weights_to=save_weights_to,
-              make_image_summary=make_image_summary,
-              dropout_broadcast_dims=attention_dropout_broadcast_dims,
-              max_length=hparams.get("max_length"),
-              vars_3d=hparams.get("attention_variables_3d"),
-              activation_dtype=hparams.get("activation_dtype", "float32"),
-              weight_dtype=hparams.get("weight_dtype", "float32"))
+            hidden_state = common_attention.multihead_attention(
+                hidden_state,
+                None,
+                encoder_self_attention_bias,
+                hparams.attention_key_channels or hparams.hidden_size,
+                hparams.attention_value_channels or hparams.hidden_size,
+                hparams.hidden_size,
+                hparams.num_heads,
+                hparams.attention_dropout,
+                attention_type=hparams.self_attention_type,
+                max_relative_position=hparams.max_relative_position,
+                heads_share_relative_embedding=(
+                    hparams.heads_share_relative_embedding),
+                add_relative_to_values=hparams.add_relative_to_values,
+                save_weights_to=save_weights_to,
+                make_image_summary=make_image_summary,
+                dropout_broadcast_dims=attention_dropout_broadcast_dims,
+                max_length=hparams.get("max_length"),
+                vars_3d=hparams.get("attention_variables_3d"),
+                activation_dtype=hparams.get("activation_dtype", "float32"),
+                weight_dtype=hparams.get("weight_dtype", "float32"))
 
-          hidden_state = common_layers.layer_postprocess(
-              residual_state, hidden_state, hparams)
+            hidden_state = common_layers.layer_postprocess(
+                residual_state, hidden_state, hparams)
 
         with tf.variable_scope("dense_layers"):
           residual_state = hidden_state

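For context rather than as part of the commit: a minimal sketch of how the new flag could be used, assuming the evolved_transformer_base hparams set registered in this file and the standard HParams add_hparam/get API.

# Minimal usage sketch, not from this commit: turn off the encoder
# self-attention block via the new hparam. Assumes a standard tensor2tensor
# install with the registered evolved_transformer_base hparams set.
from tensor2tensor.models import evolved_transformer

hparams = evolved_transformer.evolved_transformer_base()
# The encoder reads hparams.get("et_encoder_self_attention", True), so hparams
# sets that predate this commit keep the block enabled; add the flag to opt out.
hparams.add_hparam("et_encoder_self_attention", False)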