Skip to content

Commit

Permalink
fix issue: google#285
Browse files Browse the repository at this point in the history
  • Loading branch information
ChinaShrimp committed Jun 23, 2018
1 parent 7f48589 commit 9db1c98
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions seq2seq/contrib/seq2seq/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@

import six

from tensorflow.contrib.distributions.python.ops import bernoulli
from tensorflow.contrib.distributions.python.ops import categorical
from tensorflow.contrib.distributions import Bernoulli
from tensorflow.contrib.distributions import Categorical
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import base as layers_base
Expand Down Expand Up @@ -264,7 +264,7 @@ def sample(self, time, outputs, state, name=None):
select_sample_noise = random_ops.random_uniform(
[self.batch_size], seed=self._scheduling_seed)
select_sample = (self._sampling_probability > select_sample_noise)
sample_id_sampler = categorical.Categorical(logits=outputs)
sample_id_sampler = Categorical(logits=outputs)
return array_ops.where(
select_sample,
sample_id_sampler.sample(seed=self._seed),
Expand Down Expand Up @@ -384,7 +384,7 @@ def initialize(self, name=None):
def sample(self, time, outputs, state, name=None):
with ops.name_scope(name, "ScheduledOutputTrainingHelperSample",
[time, outputs, state]):
sampler = bernoulli.Bernoulli(probs=self._sampling_probability)
sampler = Bernoulli(probs=self._sampling_probability)
return math_ops.cast(
sampler.sample(sample_shape=self.batch_size, seed=self._seed),
dtypes.bool)
Expand Down

0 comments on commit 9db1c98

Please sign in to comment.