-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdtml.py
105 lines (96 loc) · 3.88 KB
/
dtml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import random, os, sys
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.callbacks import *
from tensorflow.keras.initializers import *
import tensorflow.keras.backend as K
class LayerNormalization(Layer):
    """Normalize the last axis of the input, then apply a learned scale and shift.

    The output is ``gamma * (x - mean) / (std + eps) + beta`` where ``mean``
    and ``std`` are computed over the last axis of ``x``. Note that ``eps``
    is added to the standard deviation, not to the variance.

    Args:
        eps: Small constant added to the standard deviation to avoid a
            division by zero.
        **kwargs: Forwarded unchanged to the Keras ``Layer`` constructor.
    """

    def __init__(self, eps=1e-6, **kwargs):
        # Initialise the base layer first so Keras bookkeeping is in place
        # before any attribute is attached to the instance.
        super(LayerNormalization, self).__init__(**kwargs)
        self.eps = eps

    def build(self, input_shape):
        """Create ``gamma`` (ones) and ``beta`` (zeros), sized like the last axis."""
        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:],
                                     initializer=Ones(), trainable=True)
        self.beta = self.add_weight(name='beta', shape=input_shape[-1:],
                                    initializer=Zeros(), trainable=True)
        super(LayerNormalization, self).build(input_shape)

    def call(self, x):
        """Return ``x`` normalized over its last axis, scaled and shifted."""
        mean = K.mean(x, axis=-1, keepdims=True)
        std = K.std(x, axis=-1, keepdims=True)
        return self.gamma * (x - mean) / (std + self.eps) + self.beta

    def compute_output_shape(self, input_shape):
        """The output has the same shape as the input."""
        return input_shape

    def get_config(self):
        """Return the layer config including ``eps``.

        Without this override the custom ``eps`` constructor argument is not
        part of the serialized config, so a model using this layer cannot be
        saved/cloned and rebuilt with the same epsilon.
        """
        config = super(LayerNormalization, self).get_config()
        config['eps'] = self.eps
        return config
# It's safe to use a 1-d mask for self-attention
class ScaledDotProductAttention():
    """Scaled dot-product attention assembled from Keras layers.

    Args:
        attn_dropout: Rate of the dropout applied to the attention weights
            after the softmax.
    """

    def __init__(self, attn_dropout=0.1):
        self.dropout = Dropout(attn_dropout)

    def __call__(self, q, k, v, mask):  # mask_k or mask_qk
        """Attend from ``q`` over ``k`` and combine ``v`` accordingly.

        Args:
            q, k, v: Query, key and value tensors. The scores are a batch dot
                product of ``q`` and ``k`` over axis 2 of both.
            mask: ``None`` to skip masking; otherwise a tensor that is cast
                to float32 and turned into an additive penalty of
                ``-1e9 * (1 - mask)``, i.e. entries equal to 0 are suppressed.
                NOTE(review): its shape must be accepted by ``Add`` together
                with the score tensor - confirm against callers.

        Returns:
            A pair ``(output, attn)``: the weighted values and the attention
            weights after softmax and dropout.
        """
        # Scaling factor: square root of the size of k's last axis.
        key_width = tf.cast(tf.shape(k)[-1], dtype='float32')
        scale = tf.sqrt(key_width)

        scores = Lambda(
            lambda t: K.batch_dot(t[0], t[1], axes=[2, 2]) / t[2]
        )([q, k, scale])  # shape=(batch, q, k)

        if mask is not None:
            # A large negative value pushes masked positions to ~0 after softmax.
            penalty = Lambda(
                lambda m: (-1e+9) * (1. - K.cast(m, 'float32'))
            )(mask)
            scores = Add()([scores, penalty])

        weights = self.dropout(Activation('softmax')(scores))
        context = Lambda(lambda t: K.batch_dot(t[0], t[1]))([weights, v])
        return context, weights
class MultiHeadAttention():
    """Multi-head attention built from Keras layers.

    mode 0 - big matrices, faster; mode 1 - clearer implementation.

    Args:
        n_head: Number of attention heads.
        d_model: Width of the output projection; each head works with
            ``d_model // n_head`` features for keys/queries and values.
        dropout: Rate of the dropout applied to the projected output.
        mode: 0 to project all heads with a single Dense layer per q/k/v,
            1 to use separate per-head projection layers.

    Raises:
        ValueError: If ``mode`` is neither 0 nor 1.
    """

    def __init__(self, n_head, d_model, dropout, mode=0):
        # Fail fast: an unsupported mode would otherwise build no projection
        # layers and only blow up later, inside __call__.
        if mode not in (0, 1):
            raise ValueError('mode must be 0 or 1, got %r' % (mode,))
        self.mode = mode
        self.n_head = n_head
        self.d_k = self.d_v = d_k = d_v = d_model // n_head
        self.dropout = dropout
        if mode == 0:
            self.qs_layer = Dense(n_head*d_k, use_bias=False)
            self.ks_layer = Dense(n_head*d_k, use_bias=False)
            self.vs_layer = Dense(n_head*d_v, use_bias=False)
        elif mode == 1:
            self.qs_layers = []
            self.ks_layers = []
            self.vs_layers = []
            for _ in range(n_head):
                self.qs_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
                self.ks_layers.append(TimeDistributed(Dense(d_k, use_bias=False)))
                self.vs_layers.append(TimeDistributed(Dense(d_v, use_bias=False)))
        self.attention = ScaledDotProductAttention()
        self.w_o = TimeDistributed(Dense(d_model))

    def __call__(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q, k, v: Query, key and value tensors.
            mask: Optional mask handed to the scaled dot-product attention.
                In mode 0 it is replicated ``n_head`` times along axis 0 so
                that it lines up with the head-major stacking of q/k/v.

        Returns:
            A pair ``(outputs, attn)``: the projected output after dropout,
            and the attention weights. In mode 0 ``attn`` keeps the stacked
            ``n_head * batch_size`` leading axis; in mode 1 the per-head
            weights are joined with ``Concatenate()``.

        Raises:
            ValueError: If ``self.mode`` is neither 0 nor 1.
        """
        d_v = self.d_v
        n_head = self.n_head

        if self.mode == 0:
            qs = self.qs_layer(q)  # [batch_size, len_q, n_head*d_k]
            ks = self.ks_layer(k)
            vs = self.vs_layer(v)

            def reshape1(x):
                # Split the feature axis into heads and stack the heads on
                # the batch axis, head-major: row index = head * batch + b.
                s = tf.shape(x)   # [batch_size, len_q, n_head * d_k]
                x = tf.reshape(x, [s[0], s[1], n_head, s[2]//n_head])
                x = tf.transpose(x, [2, 0, 1, 3])
                x = tf.reshape(x, [-1, s[1], s[2]//n_head])  # [n_head * batch_size, len_q, d_k]
                return x
            qs = Lambda(reshape1)(qs)
            ks = Lambda(reshape1)(ks)
            vs = Lambda(reshape1)(vs)

            if mask is not None:
                # Tile (not element-repeat) the mask: reshape1 orders rows
                # head-major, so the whole batch of masks must be repeated
                # once per head. Element-wise repetition would produce a
                # batch-major order and pair masks with the wrong samples.
                mask = Lambda(lambda x: K.concatenate([x] * n_head, axis=0))(mask)
            head, attn = self.attention(qs, ks, vs, mask=mask)

            def reshape2(x):
                # Undo reshape1: move the heads back onto the feature axis.
                s = tf.shape(x)   # [n_head * batch_size, len_v, d_v]
                x = tf.reshape(x, [n_head, -1, s[1], s[2]])
                x = tf.transpose(x, [1, 2, 0, 3])
                x = tf.reshape(x, [-1, s[1], n_head*d_v])  # [batch_size, len_v, n_head * d_v]
                return x
            head = Lambda(reshape2)(head)
        elif self.mode == 1:
            heads = []
            attns = []
            for i in range(n_head):
                qs = self.qs_layers[i](q)
                ks = self.ks_layers[i](k)
                vs = self.vs_layers[i](v)
                head, attn = self.attention(qs, ks, vs, mask)
                heads.append(head)
                attns.append(attn)
            head = Concatenate()(heads) if n_head > 1 else heads[0]
            attn = Concatenate()(attns) if n_head > 1 else attns[0]
        else:
            # Only reachable if self.mode was changed after construction.
            raise ValueError('mode must be 0 or 1, got %r' % (self.mode,))

        outputs = self.w_o(head)
        # A fresh Dropout layer is created on every call.
        outputs = Dropout(self.dropout)(outputs)
        return outputs, attn