@@ -18,7 +18,8 @@ def __init__(self,
                  word2vec=None,
                  kernel_size=3,
                  padding_size=1,
-                 dropout=0.5):
+                 dropout=0.0,
+                 activation_function=F.relu):
         """
         Args:
             token2id: dictionary of token->idx mapping
@@ -33,12 +34,19 @@ def __init__(self,
         """
         # hyperparameters
         super().__init__(token2id, max_length, hidden_size, word_size, position_size, blank_padding, word2vec)
-        self.dropout = dropout
+        self.drop = nn.Dropout(dropout)
         self.kernel_size = kernel_size
         self.padding_size = padding_size
+        self.act = activation_function
+
+        self.conv = nn.Conv1d(self.input_size, self.hidden_size, self.kernel_size, padding=self.padding_size)
+        self.pool = nn.MaxPool1d(self.max_length)
+        self.mask_embedding = nn.Embedding(4, 3)
+        self.mask_embedding.weight.data.copy_(torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]))
+        self.mask_embedding.weight.requires_grad = False
+        self._minus = -100
 
-        self.conv = CNN(self.input_size, self.hidden_size, self.dropout, self.kernel_size, self.padding_size, activation_function=F.relu)
-        self.pool = MaxPool(self.max_length, 3)
+        self.hidden_size *= 3
 
     def forward(self, token, pos1, pos2, mask):
         """
@@ -55,8 +63,17 @@ def forward(self, token, pos1, pos2, mask):
         x = torch.cat([self.word_embedding(token),
                        self.pos1_embedding(pos1),
                        self.pos2_embedding(pos2)], 2)  # (B, L, EMBED)
-        x = self.conv(x)  # (B, L, EMBED)
-        x = self.pool(x)  # (B, EMBED)
+        x = x.transpose(1, 2)  # (B, EMBED, L)
+        x = self.conv(x)  # (B, H, L)
+
+        mask = 1 - self.mask_embedding(mask).transpose(1, 2)  # (B, L) -> (B, L, 3) -> (B, 3, L)
+        pool1 = self.pool(self.act(x + self._minus * mask[:, 0:1, :]))  # (B, H, 1)
+        pool2 = self.pool(self.act(x + self._minus * mask[:, 1:2, :]))
+        pool3 = self.pool(self.act(x + self._minus * mask[:, 2:3, :]))
+        x = torch.cat([pool1, pool2, pool3], 1)  # (B, 3H, 1)
+        x = x.squeeze(2)  # (B, 3H)
+        x = self.drop(x)
+
         return x
 
     def tokenize(self, item):
@@ -69,25 +86,72 @@ def tokenize(self, item):
         Return:
             Name of the relation of the sentence
         """
-        # Sentence -> token
-        indexed_tokens, pos1, pos2 = super().tokenize(item)
-        sentence = item['text']
+        if 'text' in item:
+            sentence = item['text']
+            is_token = False
+        else:
+            sentence = item['token']
+            is_token = True
         pos_head = item['h']['pos']
-        pos_tail = item['t']['pos']
+        pos_tail = item['t']['pos']
 
-        # Mask
+        # Sentence -> token
+        if not is_token:
+            if pos_head[0] > pos_tail[0]:
+                pos_min, pos_max = [pos_tail, pos_head]
+                rev = True
+            else:
+                pos_min, pos_max = [pos_head, pos_tail]
+                rev = False
+            sent_0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
+            sent_1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
+            sent_2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
+            ent_0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
+            ent_1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
+            tokens = sent_0 + ent_0 + sent_1 + ent_1 + sent_2
+            if rev:
+                pos_tail = [len(sent_0), len(sent_0) + len(ent_0)]
+                pos_head = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
+            else:
+                pos_head = [len(sent_0), len(sent_0) + len(ent_0)]
+                pos_tail = [len(sent_0) + len(ent_0) + len(sent_1), len(sent_0) + len(ent_0) + len(sent_1) + len(ent_1)]
+        else:
+            tokens = sentence
+
+        # Token -> index
+        if self.blank_padding:
+            indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, self.max_length, self.token2id['[PAD]'], self.token2id['[UNK]'])
+        else:
+            indexed_tokens = self.tokenizer.convert_tokens_to_ids(tokens, unk_id=self.token2id['[UNK]'])
+
+        # Position -> index
+        pos1 = []
+        pos2 = []
         pos1_in_index = min(pos_head[0], self.max_length)
         pos2_in_index = min(pos_tail[0], self.max_length)
+        for i in range(len(tokens)):
+            pos1.append(min(i - pos1_in_index + self.max_length, 2 * self.max_length - 1))
+            pos2.append(min(i - pos2_in_index + self.max_length, 2 * self.max_length - 1))
+
+        if self.blank_padding:
+            while len(pos1) < self.max_length:
+                pos1.append(0)
+            while len(pos2) < self.max_length:
+                pos2.append(0)
+            indexed_tokens = indexed_tokens[:self.max_length]
+            pos1 = pos1[:self.max_length]
+            pos2 = pos2[:self.max_length]
 
+        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)
+        pos1 = torch.tensor(pos1).long().unsqueeze(0)  # (1, L)
+        pos2 = torch.tensor(pos2).long().unsqueeze(0)  # (1, L)
+
+        # Mask
         mask = []
-        pos_min = min(pos1_in_index, pos2_in_index)
-        pos_max = max(pos1_in_index, pos2_in_index)
-        for i in range(len(indexed_tokens)):
-            if pos1[0][i] == 0:
-                break
-            if i <= pos_min:
+        for i in range(len(tokens)):
+            if i <= pos_min[0]:
                 mask.append(1)
-            elif i <= pos_max:
+            elif i <= pos_max[0]:
                 mask.append(2)
             else:
                 mask.append(3)
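
The forward() hunk above replaces the old CNN/MaxPool wrappers with an inline piecewise max-pooling step driven by the frozen mask embedding: tokenize() tags each token with a segment id (1 up to the first entity, 2 between the two entities, 3 after the second entity, 0 for padding), and adding -100 outside a segment before max pooling restricts each pool to that segment. The short sketch below is not part of the diff; the shapes and variable names (B, H, L, table, seg) are illustrative assumptions used only to show the trick in isolation.

# Illustrative sketch only -- not part of the PR; it mimics the mask trick in forward().
import torch
import torch.nn.functional as F

B, H, L = 2, 8, 6                                  # assumed batch size, conv channels, sentence length
x = torch.randn(B, H, L)                           # stands in for the Conv1d output
mask = torch.tensor([[1, 1, 2, 2, 3, 3],
                     [1, 2, 2, 3, 3, 0]])          # per-token segment ids; 0 marks padding

# Same frozen lookup table as self.mask_embedding: row k is the one-hot vector of segment k.
table = torch.tensor([[0., 0., 0.], [1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])
seg = 1 - table[mask].transpose(1, 2)              # (B, 3, L): 0 inside the segment, 1 outside

pools = []
for k in range(3):
    # Adding -100 to positions outside segment k makes max pooling ignore them.
    pools.append(F.max_pool1d(F.relu(x + (-100) * seg[:, k:k + 1, :]), L))
out = torch.cat(pools, 1).squeeze(2)               # (B, 3H), i.e. (2, 24) here

Because padding tokens (segment id 0) map to the all-zero row of the table, they end up outside every segment, which is why the encoder triples self.hidden_size after concatenating the three pooled pieces.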