
Commit

0.2.4
YeongHyeon committed Feb 4, 2022
1 parent a7691f0 commit e9dfcb4
Showing 3 changed files with 60 additions and 24 deletions.
Binary file added dist/whiteboxlayer-0.2.4-py3-none-any.whl
Binary file not shown.
setup.py: 2 changes (1 addition, 1 deletion)
@@ -6,7 +6,7 @@

setup(
name = 'whiteboxlayer',
- version = '0.2.3',
+ version = '0.2.4',
description = 'TensorFlow based custom layers',
author = 'YeongHyeon Park',
author_email = '[email protected]',
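The wheel added under dist/ and the version bump above point to a 0.2.4 release; a minimal check of the installed version, assuming the package is published to PyPI under the name declared in setup.py and Python 3.8+ for importlib.metadata:

# Install the release (PyPI name assumed from setup.py):
#   pip install whiteboxlayer==0.2.4
import importlib.metadata

# Report the installed distribution's version; expected to print 0.2.4.
print(importlib.metadata.version("whiteboxlayer"))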
whiteboxlayer/extensions/recurrent.py: 82 changes (59 additions, 23 deletions)
@@ -45,32 +45,50 @@ def lstm_cell(layer, x_now, h_prev, c_prev, output_dim, \
c_now_act = layer.activation(x=c_now, activation=activation, name="%s-c-act" %(name))
h_now = tf.compat.v1.multiply(o_now, c_now_act, name="%s-h" %(name))

- y_now = layer.fully_connected(x=h_now, c_out=output_dim, \
-     batch_norm=False, activation=activation, name="%s-y" %(name), verbose=verbose)

- if(verbose): print("LSTM Cell (%s)" %(name), x_now.shape, "->", y_now.shape)
- return h_now, c_now, y_now

def lstm_layer(layer, x, output_dim, \
- batch_norm=False, activation="tanh", recurrent_activation="sigmoid", name='lstm', verbose=True):
+ batch_norm=False, activation="tanh", recurrent_activation="sigmoid", bi_direction=False, name='lstm', verbose=True):

x = tf.transpose(x, perm=[1, 0, 2])
dim_seq = x.get_shape().as_list()[0]
- y, h_now, c_now = None, None, None
+ y, y_rev, h_now = None, None, None
for idx_s in range(dim_seq):

- h_now, c_now, x_new = lstm_cell(layer=layer, \
+ h_now, c_now = lstm_cell(layer=layer, \
x_now=x[idx_s, :, :], h_prev=h_now, c_prev=c_now, output_dim=output_dim, \
activation=activation, recurrent_activation=recurrent_activation, \
name=name, verbose=(verbose and idx_s == 0))

- x_new = tf.expand_dims(x_new, 0)
+ x_new = tf.expand_dims(h_now, 0)
if(y is None): y = x_new
else: y = tf.concat([y, x_new], 0)

y = tf.transpose(y, perm=[1, 0, 2])
if(batch_norm): y = layer.batch_normalization(x=y, \
trainable=True, name='%s_bn' %(name), verbose=verbose)
if(bi_direction):
h_now, c_now = lstm_cell(layer=layer, \
x_now=x[idx_s, :, :], h_prev=h_now, c_prev=c_now, output_dim=output_dim, \
activation=activation, recurrent_activation=recurrent_activation, \
name='%s_rev' %(name), verbose=(verbose and idx_s == 0))

x_new = tf.expand_dims(h_now, 0)
if(y_rev is None): y_rev = x_new
else: y_rev = tf.concat([y_rev, x_new], 0)

if(not(bi_direction)):
y = layer.fully_connected(x=tf.transpose(y, perm=[1, 0, 2]), c_out=output_dim, \
batch_norm=batch_norm, activation=activation, name="%s-y" %(name), verbose=verbose)
else:
y = layer.fully_connected(x=tf.transpose(y, perm=[1, 0, 2]), c_out=output_dim, \
batch_norm=False, activation=None, name="%s-y" %(name), verbose=verbose)
y_rev = layer.fully_connected(x=tf.transpose(y_rev, perm=[1, 0, 2]), c_out=output_dim, \
batch_norm=False, activation=None, name="%s-y_rev" %(name), verbose=verbose)
y = y + y_rev
if(batch_norm): y = layer.batch_normalization(x=y, \
trainable=True, name='%s_bn' %(name), verbose=verbose)
y = layer.activation(x=y, \
activation=activation, name="%s-act" %(name))
x = tf.transpose(x, perm=[1, 0, 2])
if(verbose): print("LSTM (%s)" %(name), x.shape, "->", y.shape)
return y
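A minimal usage sketch of the updated lstm_layer (not part of the commit), assuming the layer argument is an instance of a Layers container at whiteboxlayer.layers.Layers exposing the fully_connected, batch_normalization, and activation methods used above, and that the input is batch-major (batch, sequence, feature):

import tensorflow as tf
import whiteboxlayer.layers as wbl            # assumed module path
import whiteboxlayer.extensions.recurrent as rnn

layer = wbl.Layers()                          # assumed Layers container
x = tf.zeros((8, 20, 32))                     # (batch, sequence, feature)

# Default: a single forward pass over the sequence, as before.
y_uni = rnn.lstm_layer(layer=layer, x=x, output_dim=64, name='lstm_uni', verbose=False)

# New in 0.2.4: a second cell named '<name>_rev'; its projection is summed
# with the forward projection before batch normalization and activation.
y_bi = rnn.lstm_layer(layer=layer, x=x, output_dim=64, \
    batch_norm=True, bi_direction=True, name='lstm_bi', verbose=False)

print(y_uni.shape, y_bi.shape)                # both expected to be (8, 20, 64)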

def gru_cell(layer, x_now, h_prev, output_dim, \
@@ -107,30 +125,48 @@ def gru_cell(layer, x_now, h_prev, output_dim, \
h_term2 = tf.compat.v1.multiply(z_now, h_now_hat, name="%s-h_term2" %(name))
h_now = tf.compat.v1.add(h_term1, h_term2, name="%s-c" %(name))

- y_now = layer.fully_connected(x=h_now, c_out=output_dim, \
-     batch_norm=False, activation=activation, name="%s-y" %(name), verbose=verbose)
-
- if(verbose): print("GRU Cell (%s)" %(name), x_now.shape, "->", y_now.shape)
- return h_now, y_now
+ return h_now

def gru_layer(layer, x, output_dim, \
- batch_norm=False, activation="tanh", recurrent_activation="sigmoid", name='gru', verbose=True):
+ batch_norm=False, activation="tanh", recurrent_activation="sigmoid", bi_direction=False, name='gru', verbose=True):

x = tf.transpose(x, perm=[1, 0, 2])
dim_seq = x.get_shape().as_list()[0]
- y, h_now = None, None
+ y, y_rev, h_now = None, None, None
for idx_s in range(dim_seq):

- h_now, x_new = gru_cell(layer=layer, \
+ h_now = gru_cell(layer=layer, \
x_now=x[idx_s, :, :], h_prev=h_now, output_dim=output_dim, \
activation=activation, recurrent_activation=recurrent_activation, \
name=name, verbose=(verbose and idx_s == 0))

- x_new = tf.expand_dims(x_new, 0)
+ x_new = tf.expand_dims(h_now, 0)
if(y is None): y = x_new
else: y = tf.concat([y, x_new], 0)

y = tf.transpose(y, perm=[1, 0, 2])
if(batch_norm): y = layer.batch_normalization(x=y, \
trainable=True, name='%s_bn' %(name), verbose=verbose)
if(bi_direction):
h_now = gru_cell(layer=layer, \
x_now=x[-idx_s, :, :], h_prev=h_now, output_dim=output_dim, \
activation=activation, recurrent_activation=recurrent_activation, \
name='%s_rev' %(name), verbose=(verbose and idx_s == 0))

x_new = tf.expand_dims(h_now, 0)
if(y_rev is None): y_rev = x_new
else: y_rev = tf.concat([y_rev, x_new], 0)

if(not(bi_direction)):
y = layer.fully_connected(x=tf.transpose(y, perm=[1, 0, 2]), c_out=output_dim, \
batch_norm=batch_norm, activation=activation, name="%s-y" %(name), verbose=verbose)
else:
y = layer.fully_connected(x=tf.transpose(y, perm=[1, 0, 2]), c_out=output_dim, \
batch_norm=False, activation=None, name="%s-y" %(name), verbose=verbose)
y_rev = layer.fully_connected(x=tf.transpose(y_rev, perm=[1, 0, 2]), c_out=output_dim, \
batch_norm=False, activation=None, name="%s-y_rev" %(name), verbose=verbose)
y = y + y_rev
if(batch_norm): y = layer.batch_normalization(x=y, \
trainable=True, name='%s_bn' %(name), verbose=verbose)
y = layer.activation(x=y, \
activation=activation, name="%s-act" %(name))
x = tf.transpose(x, perm=[1, 0, 2])
if(verbose): print("GRU (%s)" %(name), x.shape, "->", y.shape)
return y
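gru_layer gains the same bi_direction flag; a minimal sketch under the same assumptions as the LSTM example above (an assumed whiteboxlayer.layers.Layers instance and batch-major input). Merging the two directions by element-wise addition of their projections is comparable to Keras' Bidirectional wrapper with merge_mode='sum', rather than the more common concatenation:

import tensorflow as tf
import whiteboxlayer.layers as wbl            # assumed module path
import whiteboxlayer.extensions.recurrent as rnn

layer = wbl.Layers()                          # assumed Layers container
x = tf.zeros((8, 20, 32))                     # (batch, sequence, feature)

y = rnn.gru_layer(layer=layer, x=x, output_dim=64, \
    batch_norm=True, bi_direction=True, name='gru_bi', verbose=True)
print(y.shape)                                # expected (8, 20, 64)

# Rough Keras analogue of merging two directions by summation.
ref = tf.keras.layers.Bidirectional(
    tf.keras.layers.GRU(64, return_sequences=True), merge_mode='sum')(x)
print(ref.shape)                              # (8, 20, 64)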
