-
Notifications
You must be signed in to change notification settings - Fork 46
/
MySumLayer.py
39 lines (34 loc) · 1.22 KB
/
MySumLayer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
class MySumLayer(Layer):
    """Sum the input over ``axis`` after zeroing masked positions.

    If an incoming Keras mask is present, it is multiplied into the input
    before the sum. The layer then returns no mask, so masking stops here.

    A 2-D input is first expanded to 3-D with a trailing unit axis (via
    ``K.expand_dims``). For example, summing a ``(batch, time)`` input over
    ``axis=1`` yields ``(batch, 1)``, not ``(batch,)``.

    Args:
        axis: Axis to sum over, indexed on the (possibly expanded) input.
            Negative values count from the end, as in ``K.sum``.
        **kwargs: Forwarded to the base ``Layer`` constructor.
    """

    def __init__(self, axis, **kwargs):
        super(MySumLayer, self).__init__(**kwargs)
        # Set after the base constructor: the Keras 2 base Layer.__init__
        # assigns supports_masking = False, which would undo this flag if
        # it were set first.
        self.supports_masking = True
        self.axis = axis

    def get_config(self):
        """Return the layer config, including ``axis``.

        Without this, a saved model cannot rebuild the layer, because
        ``from_config`` calls ``cls(**config)`` and ``axis`` is required.
        """
        config = super(MySumLayer, self).get_config()
        config['axis'] = self.axis
        return config

    def compute_mask(self, input, input_mask=None):
        """Return ``None`` so the mask is not passed to later layers.

        The parameter names are kept as-is for caller compatibility, even
        though ``input`` shadows the builtin.
        """
        return None

    def call(self, x, mask=None):
        """Apply ``mask`` (if given) to ``x`` and sum over ``self.axis``.

        Args:
            x: Input tensor.
            mask: Optional mask. It is typically ``(batch, time)`` for
                sequence inputs (TODO confirm for other callers). Trailing
                unit axes are appended until its rank matches ``x``, so it
                broadcasts across the remaining (feature) axes.

        Returns:
            ``x`` (masked, and expanded to 3-D if it was 2-D) summed over
            ``self.axis``.
        """
        if mask is not None:
            # Cast to x's own dtype rather than floatx so the multiply also
            # works for inputs that are not the default float type.
            mask = K.cast(mask, K.dtype(x))
            # Broadcast with unit axes instead of K.repeat(mask, x.shape[-1]).
            # This works even when the last dimension is unknown (None), and
            # it does not require the mask to be exactly 2-D.
            while K.ndim(mask) < K.ndim(x):
                mask = K.expand_dims(mask, -1)
            x = x * mask
        if K.ndim(x) == 2:
            x = K.expand_dims(x)
        return K.sum(x, axis=self.axis)

    def compute_output_shape(self, input_shape):
        """Return the shape ``call`` produces for ``input_shape``.

        This mirrors ``call``: a 2-D shape gets a trailing ``1`` first,
        then ``self.axis`` is removed. Indexing the list with ``self.axis``
        handles negative axes the same way ``K.sum`` does.
        """
        output_shape = list(input_shape)
        if len(output_shape) == 2:
            output_shape.append(1)
        del output_shape[self.axis]
        return tuple(output_shape)