
Merge pull request #40 from philipperemy/issue_39
dilation name to make it unique
Philippe Rémy authored Feb 20, 2019
2 parents d6f5d8e + f0bfea6 commit f5a07fc
Showing 1 changed file with 7 additions and 6 deletions.
tcn/tcn.py
@@ -44,14 +44,15 @@ def wave_net_activation(x):
     return keras.layers.multiply([tanh_out, sigm_out])


-def residual_block(x, s, i, activation, nb_filters, kernel_size, padding, dropout_rate=0, name=''):
-    # type: (Layer, int, int, str, int, int, str, float, str) -> Tuple[Layer, Layer]
+def residual_block(x, s, i, c, activation, nb_filters, kernel_size, padding, dropout_rate=0, name=''):
+    # type: (Layer, int, int, int, str, int, int, str, float, str) -> Tuple[Layer, Layer]
     """Defines the residual block for the WaveNet TCN
     Args:
         x: The previous layer in the model
         s: The stack index i.e. which stack in the overall TCN
         i: The dilation power of 2 we are using for this residual block
+        c: The index of this dilation, used to keep layer names unique when the same dilation appears more than once, e.g. [1, 1, 2, 4]
         activation: The name of the type of activation to use
         nb_filters: The number of convolutional filters to use in this block
         kernel_size: The size of the convolutional kernel
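
To see what the new index buys, here is a small standalone sketch (plain Python string formatting that mirrors the name patterns in the hunks below; the variable values are hypothetical) for the repeated-dilation case the docstring mentions:

# Standalone sketch: compare the old and new conv-layer name patterns
# for a dilation list that repeats a value. Values are hypothetical.
padding, s = 'causal', 0
dilations = [1, 1, 2, 4]

old_names = ['_d_%s_conv_%d_tanh_s%d' % (padding, d, s) for d in dilations]
new_names = ['_d_%s_conv_%d-%d_tanh_s%d' % (padding, d, c, s)
             for c, d in enumerate(dilations)]

print(old_names)  # first two entries collide: '_d_causal_conv_1_tanh_s0' twice
print(new_names)  # the appended index keeps every name unique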
@@ -67,7 +68,7 @@ def residual_block(x, s, i, activation, nb_filters, kernel_size, padding, dropout_rate=0, name=''):
     original_x = x
     conv = Conv1D(filters=nb_filters, kernel_size=kernel_size,
                   dilation_rate=i, padding=padding,
-                  name=name + '_d_%s_conv_%d_tanh_s%d' % (padding, i, s))(x)
+                  name=name + '_d_%s_conv_%d-%d_tanh_s%d' % (padding, i, c, s))(x)
     if activation == 'norm_relu':
         x = Activation('relu')(conv)
         x = Lambda(channel_normalization)(x)
@@ -76,7 +77,7 @@ def residual_block(x, s, i, activation, nb_filters, kernel_size, padding, dropout_rate=0, name=''):
     else:
         x = Activation(activation)(conv)

-    x = SpatialDropout1D(dropout_rate, name=name + '_spatial_dropout1d_%d_s%d_%f' % (i, s, dropout_rate))(x)
+    x = SpatialDropout1D(dropout_rate, name=name + '_spatial_dropout1d_%d-%d_s%d_%f' % (i, c, s, dropout_rate))(x)

     # 1x1 conv.
     x = Convolution1D(nb_filters, 1, padding='same')(x)
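
For background, Keras enforces unique layer names within a model, which is why the duplicated names above were fatal rather than merely cosmetic. A minimal sketch of the failure mode, with hypothetical layer names unrelated to this file:

from keras.layers import Dense, Input
from keras.models import Model

inp = Input(shape=(4,))
h = Dense(8, name='block')(inp)
out = Dense(8, name='block')(h)  # same explicit name as the layer above

# Building the model raises a ValueError along the lines of:
#   The name "block" is used 2 times in the model. All layer names should be unique.
model = Model(inputs=inp, outputs=out)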
@@ -156,8 +157,8 @@ def __call__(self, inputs):
         x = Convolution1D(self.nb_filters, 1, padding=self.padding, name=self.name + '_initial_conv')(x)
         skip_connections = []
         for s in range(self.nb_stacks):
-            for i in self.dilations:
-                x, skip_out = residual_block(x, s, i, self.activation, self.nb_filters,
+            for i, d in enumerate(self.dilations):
+                x, skip_out = residual_block(x, s, d, i, self.activation, self.nb_filters,
                                              self.kernel_size, self.padding, self.dropout_rate, name=self.name)
                 skip_connections.append(skip_out)
         if self.use_skip_connections:
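
With this change, dilation lists that repeat a value should build cleanly. A minimal usage sketch, assuming the TCN layer exported by this repository and the constructor arguments visible in this file (nb_filters, kernel_size, dilations); the import path and values are assumptions, not taken from this diff:

from keras.layers import Dense, Input
from keras.models import Model
from tcn import TCN

inp = Input(shape=(None, 1))
# The repeated dilation (1 appears twice) used to trigger duplicate layer
# names; the per-dilation index added in this commit keeps the names unique.
x = TCN(nb_filters=24, kernel_size=2, dilations=[1, 1, 2, 4])(inp)
out = Dense(1)(x)
model = Model(inputs=inp, outputs=out)
model.summary()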
