# context_query_attention.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class StructuredAttention(nn.Module):
    """Use each word in the context to attend to words in the query.
    Here, the context is the question-answer pair and the query is the set of
    object-level features in an image.

    Note that the values in S are cosine similarity scores in [-1, 1]. They are
    scaled before the softmax so that the maximum value can receive a very high
    probability:
        S_ = F.softmax(S * self.scale, dim=-1)
    Consider a softmax over 25 scores, one at the maximum and the other 24 at
    the minimum, i.e. f(m) = exp(m) / [24 * exp(-m) + exp(m)].
    Without scaling, S is in [-1, 1], so the maximum value gets a weight of only
        exp(1) / [24 * exp(-1) + exp(1)] ≈ 0.24 .
    With scale = 100, S * scale is in [-100, 100] and the maximum value gets
        exp(100) / [24 * exp(-100) + exp(100)] ≈ 1.0 .
    """

    def __init__(self, dropout=0.1, scale=100, add_void=False):
        """
        Args:
            dropout: dropout probability applied to the L2-normalized features
                in `similarity`.
            scale: multiplier applied to the cosine similarities before the
                softmax (see the class docstring).
            add_void: if True, the appended void vectors are kept when computing
                the attended output; otherwise they are dropped before the final
                matmul.
        """
        super(StructuredAttention, self).__init__()
        self.dropout = dropout
        self.scale = scale
        self.add_void = add_void

    def forward(self, C, Q, c_mask, q_mask, noun_mask=None, void_vector=None):
        """
        The '*' dimensions must match; singleton dimensions are broadcast.
        Args:
            C: (N, 5, Li, Lqa, D)
            Q: (N, 1, Li, Lr, D)
            c_mask: (N, 5, Li, Lqa)
            q_mask: (N, 1, Li, Lr)
            noun_mask: (N, 5, Lqa), where 1 indicates that the position is a noun,
                or (N, 5, Li, Lqa), where each entry is the probability of the
                current image being a positive bag for the word
            void_vector: (num_void, D)
        Returns:
            A: (N, 5, Li, Lqa, D), the attended query features
            S: (N, 5, Li, Lqa, Lr), masked similarity scores
            S_mask: (N, 5, Li, Lqa, Lr), the similarity mask
            S_: (N, 5, Li, Lqa, Lr), the attention weights
        """
        bsz, _, num_img, num_region, hsz = Q.shape
        if void_vector is not None:
            num_void = len(void_vector)
            Q_void = void_vector.view(1, 1, 1, num_void, hsz).repeat(bsz, 1, num_img, 1, 1)
            Q = torch.cat([Q, Q_void], dim=-2)  # (N, 1, Li, Lr+num_void, D)
            q_mask_void = q_mask.new_ones(bsz, 1, num_img, num_void)  # void positions are never masked
            q_mask = torch.cat([q_mask, q_mask_void], dim=-1)  # (N, 1, Li, Lr+num_void)
        S, S_mask = self.similarity(C, Q, c_mask, q_mask)  # (N, 5, Li, Lqa, Lr+num_void)
        # the weight of each query word for a given context word
        S_ = F.softmax(S * self.scale, dim=-1)  # (N, 5, Li, Lqa, Lr+num_void)
        S_ = S_ * S_mask  # zero out padded positions, including rows where every score was masked
        if noun_mask is not None:
            if len(noun_mask.shape) == 3:
                bsz, num_qa, lqa = noun_mask.shape
                S_ = S_ * noun_mask.view(bsz, num_qa, 1, lqa, 1)
            elif len(noun_mask.shape) == 4:
                S_ = S_ * noun_mask.unsqueeze(-1)
            else:
                raise NotImplementedError
        if void_vector is not None:
            if self.add_void:
                # attend over the regions and the appended void vectors
                A = torch.matmul(S_, Q)  # (N, 5, Li, Lqa, D)
            # drop the void columns from the returned scores and weights
            S = S[:, :, :, :, :-num_void]
            S_mask = S_mask[:, :, :, :, :-num_void]
            S_ = S_[:, :, :, :, :-num_void]
            if not self.add_void:
                Q = Q[:, :, :, :-num_void, :]  # (N, 1, Li, Lr, D)
                A = torch.matmul(S_, Q)  # (N, 5, Li, Lqa, D)
        else:
            A = torch.matmul(S_, Q)  # (N, 5, Li, Lqa, D)
        return A, S, S_mask, S_

    def similarity(self, C, Q, c_mask, q_mask):
        """
        Word-to-word cosine similarity (dot product of L2-normalized features).
        Args:
            C: (N, 5, Li, Lqa, D)
            Q: (N, 1, Li, Lr, D)
            c_mask: (N, 5, Li, Lqa)
            q_mask: (N, 1, Li, Lr)
        Returns:
            masked_S: (N, 5, Li, Lqa, Lr), similarities with padded positions
                pushed to a large negative value
            S_mask: (N, 5, Li, Lqa, Lr)
        """
        C = F.dropout(F.normalize(C, p=2, dim=-1), p=self.dropout, training=self.training)
        Q = F.dropout(F.normalize(Q, p=2, dim=-1), p=self.dropout, training=self.training)
        S_mask = torch.matmul(c_mask.unsqueeze(-1), q_mask.unsqueeze(-2))  # (N, 5, Li, Lqa, Lr)
        S = torch.matmul(C, Q.transpose(-2, -1))  # (N, 5, Li, Lqa, Lr)
        masked_S = S - 1e10 * (1 - S_mask)  # (N, 5, Li, Lqa, Lr)
        return masked_S, S_mask


class ContextQueryAttention(nn.Module):
    """
    sub-a attention
    """
    def __init__(self):
        super(ContextQueryAttention, self).__init__()

    def forward(self, C, Q, c_mask, q_mask):
        """
        The '*' dimensions must match; singleton dimensions are broadcast.
        :param C: (N, *, Lc, D)
        :param Q: (N, *, Lq, D)
        :param c_mask: (N, *, Lc)
        :param q_mask: (N, *, Lq)
        :return: A, (N, *, Lc, D), each context position attended over the query
        """
        S = self.similarity(C, Q, c_mask, q_mask)  # (N, *, Lc, Lq)
        S_ = F.softmax(S, dim=-1)  # (N, *, Lc, Lq)
        A = torch.matmul(S_, Q)  # (N, *, Lc, D)
        return A

    def similarity(self, C, Q, c_mask, q_mask):
        """
        Word-to-word scaled dot-product similarity.
        :param C: (N, *, Lc, D)
        :param Q: (N, *, Lq, D)
        :param c_mask: (N, *, Lc)
        :param q_mask: (N, *, Lq)
        :return: (N, *, Lc, Lq)
        """
        C = F.dropout(C, p=0.1, training=self.training)
        Q = F.dropout(Q, p=0.1, training=self.training)
        hsz_root = math.sqrt(C.shape[-1])
        S_mask = torch.matmul(c_mask.unsqueeze(-1), q_mask.unsqueeze(-2))  # (N, *, Lc, Lq)
        S = torch.matmul(C, Q.transpose(-2, -1)) / hsz_root  # (N, *, Lc, Lq)
        masked_S = S - 1e10 * (1 - S_mask)  # (N, *, Lc, Lq)
        return masked_S
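

# Added sketch: a small numerical check of the scaling argument in the
# StructuredAttention docstring. With one cosine score at the maximum (+1) and
# 24 at the minimum (-1), an unscaled softmax leaves the maximum with only
# ~0.24 of the mass, while scale=100 concentrates nearly all of it there.
# The function name and the toy sizes are illustrative additions, not part of
# the original module.
def check_softmax_scale(scale=100, num_scores=25):
    scores = torch.full((num_scores,), -1.0)
    scores[0] = 1.0  # one maximally similar query word, the rest minimally similar
    w_unscaled = F.softmax(scores, dim=-1)[0].item()
    w_scaled = F.softmax(scores * scale, dim=-1)[0].item()
    print("max weight without scaling: {:.4f}".format(w_unscaled))    # ~0.2354
    print("max weight with scale={}: {:.4f}".format(scale, w_scaled))  # ~1.0000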


def test():
    # C: (N, Lc, D), Q: (N, Lq, D)
    c2q = ContextQueryAttention()
    hsz = 128
    bsz = 10
    lc = 20
    lq = 10
    context = torch.randn(bsz, lc, hsz).float()
    context_mask = torch.ones(bsz, lc).float()
    query = torch.randn(bsz, lq, hsz).float()
    query_mask = torch.ones(bsz, lq).float()
    a = c2q(context, query, context_mask, query_mask)
    print("input size", context.shape, context_mask.shape, query.shape, query_mask.shape)
    print("output size", a.shape)
if __name__ == '__main__':
test()
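    # also run the added sketches defined above
    check_softmax_scale()
    test_structured_attention()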