Skip to content

Commit 9addf8e

Browse files
author
Tianyu Gao
committed
fix pcnn bug
1 parent 01ce420 commit 9addf8e

8 files changed

+163
-40
lines changed

.gitignore

+6
Original file line number · Diff line number · Diff line change
@@ -116,6 +116,7 @@ ckpt
116116
benchmark/tacred
117117
*.swp
118118

119+
<<<<<<< HEAD
119120
# data and pretrain
120121
pretrain
121122
benchmark
@@ -127,3 +128,8 @@ benchmark
127128

128129
# package
129130
opennre-egg.info
131+
=======
132+
# debug
133+
benchmark/nyt10-ori
134+
train_nyt10_pcnn_att_ori.py
135+
>>>>>>> ebf8370... fix pcnn bug
(NOTE: this hunk commits unresolved Git merge-conflict markers — `<<<<<<< HEAD`, `=======`, `>>>>>>>` — into .gitignore; they should be resolved and removed in a follow-up commit.)

opennre/encoder/cnn_encoder.py

+11-7
Original file line number · Diff line number · Diff line change
@@ -18,7 +18,8 @@ def __init__(self,
1818
word2vec=None,
1919
kernel_size=3,
2020
padding_size=1,
21-
dropout=0.5):
21+
dropout=0,
22+
activation_function=F.relu):
2223
"""
2324
Args:
2425
token2id: dictionary of token->idx mapping
@@ -33,12 +34,13 @@ def __init__(self,
3334
"""
3435
# Hyperparameters
3536
super(CNNEncoder, self).__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec)
36-
self.dropout = dropout
37+
self.drop = nn.Dropout(dropout)
3738
self.kernel_size = kernel_size
3839
self.padding_size = padding_size
40+
self.act = activation_function
3941

40-
self.conv = CNN(self.input_size, self.hidden_size, self.dropout, self.kernel_size, self.padding_size, activation_function = F.relu)
41-
self.pool = MaxPool(self.max_length)
42+
self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding)
43+
self.pool = nn.MaxPool1d(self.max_length)
4244

4345
def forward(self, token, pos1, pos2):
4446
"""
@@ -55,9 +57,11 @@ def forward(self, token, pos1, pos2):
5557
x = torch.cat([self.word_embedding(token),
5658
self.pos1_embedding(pos1),
5759
self.pos2_embedding(pos2)], 2) # (B, L, EMBED)
58-
x = self.conv(x) # (B, L, EMBED)
59-
x = self.pool(x) # (B, EMBED)
60+
x = x.transpose(1, 2) # (B, EMBED, L)
61+
x = self.act(self.conv(x)) # (B, H, L)
62+
x = self.pool(x).squeeze(-1)
63+
x = self.drop(x)
6064
return x
6165

6266
def tokenize(self, item):
63-
return super().tokenize(item)
67+
return super().tokenize(item)

opennre/encoder/pcnn_encoder.py

+82-18
Original file line number · Diff line number · Diff line change
@@ -18,7 +18,8 @@ def __init__(self,
1818
word2vec=None,
1919
kernel_size=3,
2020
padding_size=1,
21-
dropout=0.5):
21+
dropout=0.0,
22+
activation_function=F.relu):
2223
"""
2324
Args:
2425
token2id: dictionary of token->idx mapping
@@ -33,12 +34,19 @@ def __init__(self,
3334
"""
3435
# hyperparameters
3536
super().__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec)
36-
self.dropout = dropout
37+
self.drop = nn.Dropout(dropout)
3738
self.kernel_size = kernel_size
3839
self.padding_size = padding_size
40+
self.act = activation_function
41+
42+
self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size)
43+
self.pool = nn.MaxPool1d(self.max_length)
44+
self.mask_embedding = nn.Embedding(4, 3)
45+
self.mask_embedding.weight.data.copy_(torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]))
46+
self.mask_embedding.weight.requires_grad = False
47+
self._minus = -100
3948

40-
self.conv = CNN(self.input_size, self.hidden_size, self.dropout, self.kernel_size, self.padding_size, activation_function=F.relu)
41-
self.pool = MaxPool(self.max_length, 3)
49+
self.hidden_size *= 3
4250

4351
def forward(self, token, pos1, pos2, mask):
4452
"""
@@ -55,8 +63,17 @@ def forward(self, token, pos1, pos2, mask):
5563
x = torch.cat([self.word_embedding(token),
5664
self.pos1_embedding(pos1),
5765
self.pos2_embedding(pos2)], 2) # (B, L, EMBED)
58-
x = self.conv(x) # (B, L, EMBED)
59-
x = self.pool(x) # (B, EMBED)
66+
x = x.transpose(1, 2) # (B, EMBED, L)
67+
x = self.conv(x) # (B, H, L)
68+
69+
mask = 1 - self.mask_embedding(mask).transpose(1, 2) # (B, L) -> (B, L, 3) -> (B, 3, L)
70+
pool1 = self.pool(self.act(x + self._minus * mask[:, 0:1, :])) # (B, H, 1)
71+
pool2 = self.pool(self.act(x + self._minus * mask[:, 1:2, :]))
72+
pool3 = self.pool(self.act(x + self._minus * mask[:, 2:3, :]))
73+
x = torch.cat([pool1, pool2, pool3], 1) # (B, 3H, 1)
74+
x = x.squeeze(2) # (B, 3H)
75+
x = self.drop(x)
76+
6077
return x
6178

6279
def tokenize(self, item):
@@ -69,25 +86,72 @@ def tokenize(self, item):
6986
Return:
7087
Name of the relation of the sentence
7188
"""
72-
# Sentence -> token
73-
indexed_tokens, pos1, pos2 = super().tokenize(item)
74-
sentence = item['text']
89+
if 'text' in item:
90+
sentence = item['text']
91+
is_token = False
92+
else:
93+
sentence = item['token']
94+
is_token = True
7595
pos_head = item['h']['pos']
76-
pos_tail = item['t']['pos']
96+
pos_tail = item['t']['pos']
7797

78-
# Mask
98+
# Sentence -> token
99+
if not is_token:
100+
if pos_head[0] > pos_tail[0]:
101+
pos_min, pos_max = [pos_tail, pos_head]
102+
rev = True
103+
else:
104+
pos_min, pos_max = [pos_head, pos_tail]
105+
rev = False
106+
sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
107+
sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
108+
sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
109+
ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
110+
ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
111+
tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2
112+
if rev:
113+
pos_tail = [len(sent_0), len(sent_0) + len(ent_0)]
114+
pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
115+
else:
116+
pos_head = [len(sent_0), len(sent_0) + len(ent_0)]
117+
pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
118+
else:
119+
tokens = sentence
120+
121+
# Token -> index
122+
if self.blank_padding:
123+
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]'])
124+
else:
125+
indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id = self.token2id['[UNK]'])
126+
127+
# Position -> index
128+
pos1 = []
129+
pos2 = []
79130
pos1_in_index = min(pos_head[0], self.max_length)
80131
pos2_in_index = min(pos_tail[0], self.max_length)
132+
for i in range(len(tokens)):
133+
pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1))
134+
pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1))
135+
136+
if self.blank_padding:
137+
while len(pos1) < self.max_length:
138+
pos1.append(0)
139+
while len(pos2) < self.max_length:
140+
pos2.append(0)
141+
indexed_tokens = indexed_tokens[:self.max_length]
142+
pos1 = pos1[:self.max_length]
143+
pos2 = pos2[:self.max_length]
81144

145+
indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0) # (1, L)
146+
pos1 = torch.tensor(pos1).long().unsqueeze(0) # (1, L)
147+
pos2 = torch.tensor(pos2).long().unsqueeze(0) # (1, L)
148+
149+
# Mask
82150
mask = []
83-
pos_min = min(pos1_in_index, pos2_in_index)
84-
pos_max = max(pos1_in_index, pos2_in_index)
85-
for i in range(len(indexed_tokens)):
86-
if pos1[0][i] == 0:
87-
break
88-
if i <= pos_min:
151+
for i in range(len(tokens)):
152+
if i <= pos_min[0]:
89153
mask.append(1)
90-
elif i <= pos_max:
154+
elif i <= pos_max[0]:
91155
mask.append(2)
92156
else:
93157
mask.append(3)

opennre/module/pool/max_pool.py

+17-8
Original file line number · Diff line number · Diff line change
@@ -30,15 +30,24 @@ def forward(self, x, mask=None):
3030
output features: (B, H_EMBED)
3131
"""
3232
# Check size of tensors
33-
if mask == None or self.segment_num == None or self.segment_num == 1:
33+
if mask is None or self.segment_num is None or self.segment_num == 1:
3434
x = x.transpose(1, 2) # (B, L, I_EMBED) -> (B, I_EMBED, L)
3535
x = self.pool(x).squeeze(-1) # (B, I_EMBED, 1) -> (B, I_EMBED)
3636
return x
3737
else:
38-
B, L, I_EMBED = x.size()[:2]
39-
mask = 1 - self.mask_embedding(mask).transpose(1, 2).unsqueeze(2) # (B, L) -> (B, L, S) -> (B, S, L) -> (B, S, 1, L)
40-
x = x.transpose(1, 2).unsqueeze(1) # (B, L, I_EMBED) -> (B, I_EMBED, L) -> (B, 1, I_EMBED, L)
41-
x = (x + self._minus * mask).view([-1, I_EMBED, L]) # (B, S, I_EMBED, L) -> (B * S, I_EMBED, L)
42-
x = self.pool(x).squeeze(-1) # (B * S, I_EMBED, 1) -> (B * S, I_EMBED)
43-
x = x.view([B, -1]) # (B, S * I_EMBED)
44-
return x
38+
B, L, I_EMBED = x.size()[:3]
39+
# mask = 1 - self.mask_embedding(mask).transpose(1, 2).unsqueeze(2) # (B, L) -> (B, L, S) -> (B, S, L) -> (B, S, 1, L)
40+
# x = x.transpose(1, 2).unsqueeze(1) # (B, L, I_EMBED) -> (B, I_EMBED, L) -> (B, 1, I_EMBED, L)
41+
# x = (x + self._minus * mask).contiguous().view([-1, I_EMBED, L]) # (B, S, I_EMBED, L) -> (B * S, I_EMBED, L)
42+
# x = self.pool(x).squeeze(-1) # (B * S, I_EMBED, 1) -> (B * S, I_EMBED)
43+
# x = x.view([B, -1]) # (B, S * I_EMBED)
44+
# return x
45+
mask = 1 - self.mask_embedding(mask).transpose(1, 2)
46+
x = x.transpose(1, 2)
47+
pool1 = self.pool(x + self._minus * mask[:, 0:1, :])
48+
pool2 = self.pool(x + self._minus * mask[:, 1:2, :])
49+
pool3 = self.pool(x + self._minus * mask[:, 2:3, :])
50+
51+
x = torch.cat([pool1, pool2, pool3], 1)
52+
# x = x.squeeze(-1)
53+
return x

train_nyt10_cnn_att.py

+2-2
Original file line number · Diff line number · Diff line change
@@ -18,7 +18,7 @@
1818
kernel_size=3,
1919
padding_size=1,
2020
word2vec=word2vec,
21-
dropout=0.5)
21+
dropout=0.0)
2222
model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id)
2323
framework = opennre.framework.BagRE(
2424
train_path='benchmark/nyt10/nyt10_train.txt',
@@ -32,7 +32,7 @@
3232
weight_decay=0,
3333
opt='sgd')
3434
# Train
35-
framework.train_model()
35+
# framework.train_model()
3636
# Test
3737
framework.load_state_dict(torch.load(ckpt)['state_dict'])
3838
result = framework.eval_model(framework.test_loader)

train_nyt10_cnn_att_fixbag.py

+1-1
Original file line number · Diff line number · Diff line change
@@ -18,7 +18,7 @@
1818
kernel_size=3,
1919
padding_size=1,
2020
word2vec=word2vec,
21-
dropout=0.5)
21+
dropout=0.0)
2222
model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id)
2323
framework = opennre.framework.BagRE(
2424
train_path='benchmark/nyt10/nyt10_train.txt',

train_nyt10_pcnn_att.py

+4-4
Original file line number · Diff line number · Diff line change
@@ -18,16 +18,16 @@
1818
kernel_size=3,
1919
padding_size=1,
2020
word2vec=word2vec,
21-
dropout=0.5)
22-
model = nrekit.model.BagAttention(sentence_encoder, len(rel2id), rel2id)
23-
framework = nrekit.framework.BagRE(
21+
dropout=0.0)
22+
model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id)
23+
framework = opennre.framework.BagRE(
2424
train_path='benchmark/nyt10/nyt10_train.txt',
2525
val_path='benchmark/nyt10/nyt10_val.txt',
2626
test_path='benchmark/nyt10/nyt10_test.txt',
2727
model=model,
2828
ckpt=ckpt,
2929
batch_size=160,
30-
max_epoch=60,
30+
max_epoch=25,
3131
lr=0.5,
3232
weight_decay=0,
3333
opt='sgd')

train_nyt10_pcnn_att_fixbag.py

+40
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,40 @@
1+
# coding:utf-8
2+
import torch
3+
import numpy as np
4+
import json
5+
import opennre
6+
from opennre import encoder, model, framework
7+
8+
ckpt = 'ckpt/nyt10_pcnn_att_fixbag.pth.tar'
9+
word2id = json.load(open('pretrain/glove/glove.6B.50d_word2id.json'))
10+
word2vec = np.load('pretrain/glove/glove.6B.50d_mat.npy')
11+
rel2id = json.load(open('benchmark/nyt10/nyt10_rel2id.json'))
12+
sentence_encoder = opennre.encoder.PCNNEncoder(token2id=word2id,
13+
max_length=120,
14+
word_size=50,
15+
position_size=5,
16+
hidden_size=230,
17+
blank_padding=True,
18+
kernel_size=3,
19+
padding_size=1,
20+
word2vec=word2vec,
21+
dropout=0.0)
22+
model = opennre.model.BagAttention(sentence_encoder, len(rel2id), rel2id)
23+
framework = opennre.framework.BagRE(
24+
train_path='benchmark/nyt10/nyt10_train.txt',
25+
val_path='benchmark/nyt10/nyt10_val.txt',
26+
test_path='benchmark/nyt10/nyt10_test.txt',
27+
model=model,
28+
ckpt=ckpt,
29+
batch_size=160,
30+
max_epoch=60,
31+
lr=0.5,
32+
weight_decay=0,
33+
opt='sgd',
34+
bag_size=3)
35+
# Train
36+
framework.train_model()
37+
# Test
38+
framework.load_state_dict(torch.load(ckpt)['state_dict'])
39+
result = framework.eval_model(framework.test_loader)
40+
print('AUC on test set: {}'.format(result['auc']))

0 commit comments

Comments (0)