
Commit 847aa95 ("same for first")

1 parent 9d6df41 · commit 847aa95

4 files changed (+50 −30 lines)

first.py

Lines changed: 14 additions & 9 deletions
@@ -7,6 +7,17 @@
 
 log_sigmoid = torch.nn.LogSigmoid()
 
+class PositionEncoding(torch.nn.Module):
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+
+    def forward(self, n):
+        zero = torch.zeros(n)
+        pos = torch.arange(0, n).to(torch.float)
+        pe = torch.stack([pos == 1] + [zero]*(self.size-1), dim=1)
+        return pe
+
 ap = argparse.ArgumentParser()
 ap.add_argument('--train_length', type=int, default=50)
 ap.add_argument('--test_length', type=int, default=1000)
@@ -17,30 +28,24 @@
 
 alphabet = ["0", "1", "$"]
 alphabet_index = {a:i for i,a in enumerate(alphabet)}
-max_pos = 10000
 size = 16
 
 class Model(torch.nn.Module):
     def __init__(self, alphabet_size, size):
         super().__init__()
 
         self.word_embedding = torch.nn.Embedding(num_embeddings=alphabet_size, embedding_dim=size)
-        self.pos_embedding = torch.stack([
-            torch.arange(0, max_pos, dtype=torch.float) == 0,
-            torch.arange(0, max_pos, dtype=torch.float) == 1,
-            torch.arange(0, max_pos, dtype=torch.float) >= 2,
-        ], dim=1).to(torch.float)
-        self.pos_adapter = torch.nn.Linear(self.pos_embedding.size()[1], size)
+        self.pos_encoding = PositionEncoding(size)
 
-        encoder_layer = encoder.PostnormTransformerEncoderLayer(d_model=size, nhead=1, dim_feedforward=size*4, dropout=0.)
+        encoder_layer = encoder.TransformerEncoderLayer(d_model=size, nhead=1, dim_feedforward=size*4, dropout=0.)
         #encoder_layer = encoder.ScaledTransformerEncoderLayer(d_model=size, nhead=1, dim_feedforward=size*4, dropout=0.)
         #encoder_layer.norm1.eps = encoder_layer.norm2.eps = 0.
         self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=2)
 
         self.output_layer = torch.nn.Linear(size, 1)
 
     def forward(self, w):
-        x = self.word_embedding(w) + self.pos_adapter(self.pos_embedding[:len(w)])
+        x = self.word_embedding(w) + self.pos_encoding(len(w))
         y = self.encoder(x.unsqueeze(1)).squeeze(1)
         y = y[0]
         z = self.output_layer(y)
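
As a reading aid (not part of the commit), here is a minimal sketch of the tensor the new PositionEncoding in first.py is meant to produce: an (n, size) matrix whose first coordinate flags position 1 and whose remaining coordinates are zero. The booleans are cast to float explicitly here for clarity.

import torch

# Sketch of what PositionEncoding(size).forward(n) is intended to build:
# shape (n, size); coordinate 0 is the indicator of position 1, the rest are zero.
n, size = 5, 16
pos = torch.arange(0, n).to(torch.float)
pe = torch.stack([(pos == 1).to(torch.float)] + [torch.zeros(n)] * (size - 1), dim=1)
print(pe.shape)   # torch.Size([5, 16])
print(pe[:, 0])   # tensor([0., 1., 0., 0., 0.])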

first_exact.py

Lines changed: 17 additions & 9 deletions
@@ -5,8 +5,6 @@
 import sys
 import argparse
 
-log_sigmoid = torch.nn.LogSigmoid()
-
 ap = argparse.ArgumentParser()
 ap.add_argument('--length', type=int, default=100)
 ap.add_argument('--steps', type=int, default=100)
@@ -17,6 +15,21 @@
 alphabet_index = {a:i for i,a in enumerate(alphabet)}
 max_pos = 10000
 
+log_sigmoid = torch.nn.LogSigmoid()
+
+class PositionEncoding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, n):
+        zero = torch.zeros(n)
+        pos = torch.arange(0, n).to(torch.float)
+        pe = torch.stack([zero]*3 +
+                         [pos == 1] +
+                         [zero]*2,
+                         dim=1)
+        return pe
+
 class FirstLayer(torch.nn.TransformerEncoderLayer):
     def __init__(self):
         super().__init__(6, 1, 1, dropout=0.)
@@ -96,20 +109,15 @@ def __init__(self):
         super().__init__()
 
         self.word_embedding = torch.eye(3, 6)
-        self.pos_embedding = torch.stack(
-            [torch.zeros(max_pos)]*3 +
-            [torch.arange(0, max_pos, dtype=torch.float) == 1] +
-            [torch.zeros(max_pos)]*2,
-            dim=1)
-
+        self.pos_encoding = PositionEncoding()
         self.transformer_encoder = MyTransformerEncoder()
         self.output_layer = torch.nn.Linear(6, 1)
         self.output_layer.weight = torch.nn.Parameter(torch.tensor(
            [[0,0,0,0,0,1]], dtype=torch.float))
         self.output_layer.bias = torch.nn.Parameter(torch.tensor([0.]))
 
     def forward(self, w):
-        x = self.word_embedding[w] + self.pos_embedding[:len(w)]
+        x = self.word_embedding[w] + self.pos_encoding(len(w))
         y = self.transformer_encoder(x.unsqueeze(1)).squeeze(1)
         z = self.output_layer(y[0])
         return z
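
A hedged sketch (not part of the diff) of the 6-dimensional input that first_exact.py now assembles: coordinates 0–2 hold the one-hot word embedding from torch.eye(3, 6), coordinate 3 flags position 1, and the remaining coordinates stay at zero, with the fixed output weight [[0,0,0,0,0,1]] reading coordinate 5 after the encoder. The input string below is hypothetical.

import torch

# Hypothetical input "$01" over the alphabet ["0", "1", "$"] (indices 0, 1, 2).
w = torch.tensor([2, 0, 1])
n = len(w)
pos = torch.arange(0, n).to(torch.float)
zero = torch.zeros(n)
# Same layout the new PositionEncoding builds: [word dims 0-2 | pos==1 flag | scratch].
pe = torch.stack([zero] * 3 + [(pos == 1).to(torch.float)] + [zero] * 2, dim=1)
word_embedding = torch.eye(3, 6)
x = word_embedding[w] + pe
print(x.shape)  # torch.Size([3, 6])
print(x[1])     # tensor([1., 0., 0., 1., 0., 0.])  -> symbol "0" at position 1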

first_exact_layernorm.py

Lines changed: 17 additions & 9 deletions
@@ -5,8 +5,6 @@
 import sys
 import argparse
 
-log_sigmoid = torch.nn.LogSigmoid()
-
 ap = argparse.ArgumentParser()
 ap.add_argument('--length', type=int, default=100)
 ap.add_argument('--steps', type=int, default=100)
@@ -19,6 +17,21 @@
 alphabet_index = {a:i for i,a in enumerate(alphabet)}
 max_pos = 10000
 
+log_sigmoid = torch.nn.LogSigmoid()
+
+class PositionEncoding(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, n):
+        zero = torch.zeros(n)
+        pos = torch.arange(0, n).to(torch.float)
+        pe = torch.stack([zero]*3 +
+                         [pos == 1] +
+                         [zero]*2,
+                         dim=1)
+        return pe
+
 class FirstLayer(torch.nn.TransformerEncoderLayer):
     def __init__(self):
         super().__init__(12, 1, 1, dropout=0.)
@@ -110,20 +123,15 @@ def __init__(self):
         super().__init__()
 
         self.word_embedding = torch.eye(3, 6)
-        self.pos_embedding = torch.stack(
-            [torch.zeros(max_pos)]*3 +
-            [torch.arange(0, max_pos, dtype=torch.float) == 1] +
-            [torch.zeros(max_pos)]*2,
-            dim=1)
-
+        self.pos_encoding = PositionEncoding()
         self.transformer_encoder = MyTransformerEncoder()
         self.output_layer = torch.nn.Linear(12, 1)
         self.output_layer.weight = torch.nn.Parameter(torch.tensor(
            [[0,0,0,0,0,1,0,0,0,0,0,0]], dtype=torch.float))
         self.output_layer.bias = torch.nn.Parameter(torch.tensor([0.]))
 
     def forward(self, w):
-        x = self.word_embedding[w] + self.pos_embedding[:len(w)]
+        x = self.word_embedding[w] + self.pos_encoding(len(w))
         x = torch.cat([x, -x], dim=-1)
         y = self.transformer_encoder(x.unsqueeze(1)).squeeze(1)
         z = self.output_layer(y[0])

parity_exact_layernorm.py

Lines changed: 2 additions & 3 deletions
@@ -179,9 +179,8 @@ def __init__(self):
         self.output_layer.bias = torch.nn.Parameter(torch.tensor([0.]))
 
     def forward(self, w):
-        x = torch.cat([self.word_embedding[w] + self.pos_encoding(len(w)),
-                       -(self.word_embedding[w] + self.pos_encoding(len(w)))],
-                      dim=1)
+        x = self.word_embedding[w] + self.pos_encoding(len(w))
+        x = torch.cat([x, -x], dim=-1)
         y = self.transformer_encoder(x.unsqueeze(1)).squeeze(1)
         z = self.output_layer(y[-1])
         return z
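
The parity_exact_layernorm.py change reads as a pure refactor: computing the embedding once and concatenating it with its negation yields the same tensor as the deleted two-term torch.cat, since dim=1 and dim=-1 coincide for this 2-D input. A small check under that assumption:

import torch

# a stands in for word_embedding[w] + pos_encoding(len(w)); it is 2-D (n, d),
# so dim=1 and dim=-1 refer to the same axis.
a = torch.randn(4, 6)
old = torch.cat([a, -a], dim=1)
x = a
new = torch.cat([x, -x], dim=-1)
print(torch.equal(old, new))  # True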
