# syntactic_abstractor.py
"""
This module implements what we called a 'Syntactic Abstractor'. This is an experiment that we ran which didn't make it into the paper.
This module retrieves a symbol for each input via "symbolic attention" then performs self-attention on the retrieved symbols.
I.e., it is a mix between an Abstractor and an Encoder. There is no relational cross-attention, but there are 'symbols'.
This didn't work especially well in our experiments.
"""
import tensorflow as tf
from transformer_modules import EncoderLayer, AddPositionalEmbedding
from symbol_retrieving_abstractor import MultiHeadSymbolRetriever, MultiHeadSymbolRetrieval2

class SyntacticAbstractor(tf.keras.layers.Layer):
    """
    An implementation of the Syntactic Abstractor module.

    1) Retrieve symbols
    2) Self-attention on the retrieved symbols
    """

    def __init__(
            self,
            num_layers,
            num_heads,
            dff,
            n_symbols,
            symbol_n_heads=1,
            symbol_binding_dim=None,
            add_pos_embedding=True,
            symbol_retriever_type=1,  # there are two implementations; which one to use
            dropout_rate=0.1,
            **kwargs):
        """
        Parameters
        ----------
        num_layers : int
            number of layers
        num_heads : int
            number of 'heads' in self-attention
        dff : int
            dimension of intermediate layer in feedforward network
        n_symbols : int
            number of symbols
        symbol_n_heads : int, optional
            number of heads in SymbolRetriever, by default 1
        symbol_binding_dim : int, optional
            dimension of binding symbols, by default None
        add_pos_embedding : bool, optional
            whether to add positional embeddings to symbols after retrieval, by default True
        symbol_retriever_type : int, optional
            type of symbol retriever, by default 1
        dropout_rate : float, optional
            dropout rate, by default 0.1
        **kwargs : dict
            kwargs for parent Layer class
        """
        super(SyntacticAbstractor, self).__init__(**kwargs)

        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dff = dff
        self.n_symbols = n_symbols
        self.symbol_n_heads = symbol_n_heads
        self.symbol_binding_dim = symbol_binding_dim
        self.should_add_pos_embedding = add_pos_embedding
        self.symbol_retriever_type = symbol_retriever_type
        self.dropout_rate = dropout_rate

        # NOTE: we choose symbol_dim to be the same as d_model;
        # this is required for the residual connection to work

    def build(self, input_shape):
        _, self.sequence_length, self.d_model = input_shape

        if self.symbol_retriever_type == 1:
            self.symbol_retrieval = MultiHeadSymbolRetriever(
                n_heads=self.symbol_n_heads, n_symbols=self.n_symbols,
                symbol_dim=self.d_model, binding_dim=self.symbol_binding_dim)
        elif self.symbol_retriever_type == 2:
            self.symbol_retrieval = MultiHeadSymbolRetrieval2(
                n_heads=self.symbol_n_heads, n_symbols=self.n_symbols,
                symbol_dim=self.d_model, binding_dim=self.symbol_binding_dim)
        else:
            raise ValueError(
                f"`symbol_retriever_type` must be 1 or 2, received {self.symbol_retriever_type}")

        if self.should_add_pos_embedding:
            self.add_pos_embedding = AddPositionalEmbedding()

        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)

        self.encoder_layers = [
            EncoderLayer(d_model=self.d_model, num_heads=self.num_heads,
                         dff=self.dff, dropout_rate=self.dropout_rate)
            for _ in range(self.num_layers)]

        self.last_attn_scores = None

    def call(self, inputs):
        symbol_seq = self.symbol_retrieval(inputs)  # retrieve a symbol for each input

        if self.should_add_pos_embedding:
            symbol_seq = self.add_pos_embedding(symbol_seq)

        symbol_seq = self.dropout(symbol_seq)

        for i in range(self.num_layers):
            symbol_seq = self.encoder_layers[i](symbol_seq)

        return symbol_seq
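

# A minimal usage sketch (not part of the original module): it assumes the companion
# modules `transformer_modules` and `symbol_retrieving_abstractor` from this repository
# are importable, and simply checks that the layer maps a (batch, sequence, d_model)
# tensor to a tensor of the same shape.
if __name__ == "__main__":
    abstractor = SyntacticAbstractor(
        num_layers=2, num_heads=4, dff=128, n_symbols=10,
        symbol_n_heads=1, add_pos_embedding=True, symbol_retriever_type=1)

    x = tf.random.normal(shape=(8, 16, 64))  # (batch, sequence_length, d_model)
    y = abstractor(x)
    print(y.shape)  # expected: (8, 16, 64)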