Skip to content

Commit f5ed8e1

Browse files
author
cyente
authored
Add files via upload
1 parent 5e99db1 commit f5ed8e1

File tree

1 file changed

+190
-0
lines changed

1 file changed

+190
-0
lines changed

model/model_multimodal.py

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
from tensorflow.contrib.rnn import GRUCell
4+
from tensorflow.contrib import layers
5+
6+
7+
######################model##########################
8+
def weights(name, hidden_size, i):
9+
image_stdv = np.sqrt(1. / (2048))
10+
text_stdv = np.sqrt(1. / (2757))
11+
hidden_stdv = np.sqrt(1. / (hidden_size))
12+
if name == 'in_image':
13+
w = tf.get_variable(name='w/in_image_'+ str(i),
14+
shape=[2048, hidden_size],
15+
initializer=tf.random_normal_initializer(stddev=image_stdv))
16+
#w = tf.get_variable(name='gnn/w/in_image_', shape=[2048, hidden_size], initializer=tf.random_normal_initializer)
17+
if name == 'out_image':
18+
w = tf.get_variable(name='w/out_image_' + str(i),
19+
shape=[hidden_size, 2048],
20+
initializer=tf.random_normal_initializer(stddev=image_stdv))
21+
#w = tf.get_variable(name='w/out_image_', shape=[hidden_size, 2048], initializer=tf.random_normal_initializer)
22+
if name == 'in_text':
23+
w = tf.get_variable(name='w/in_text_'+ str(i),
24+
shape=[2757, hidden_size],
25+
initializer=tf.random_normal_initializer(stddev=image_stdv))
26+
#w = tf.get_variable(name='gnn/w/in_image_', shape=[2048, hidden_size], initializer=tf.random_normal_initializer)
27+
if name == 'out_text':
28+
w = tf.get_variable(name='w/out_text_' + str(i),
29+
shape=[hidden_size, 2757],
30+
initializer=tf.random_normal_initializer(stddev=image_stdv))
31+
#w = tf.get_variable(name='w/out_image_', shape=[hidden_size, 2048], initializer=tf.random_normal_initializer)
32+
if name == 'image_hidden_state_out':
33+
w = tf.get_variable(name='w/image_hidden_state_out' + str(i),
34+
shape=[hidden_size, hidden_size],
35+
initializer=tf.random_normal_initializer(stddev=hidden_stdv))
36+
#w = tf.get_variable(name='w/hidden_state_out_' + str(i), shape=[hidden_size, hidden_size], initializer=tf.random_normal_initializer)
37+
if name == 'image_hidden_state_in':
38+
#w = tf.get_variable(name='w/hidden_state_in_', shape=[hidden_size, hidden_size], initializer=tf.random_normal_initializer)
39+
w = tf.get_variable(name='w/image_hidden_state_in_' + str(i),
40+
shape=[hidden_size, hidden_size],
41+
initializer=tf.random_normal_initializer(stddev=hidden_stdv))
42+
if name == 'text_hidden_state_out':
43+
w = tf.get_variable(name='w/text_hidden_state_out' + str(i),
44+
shape=[hidden_size, hidden_size],
45+
initializer=tf.random_normal_initializer(stddev=hidden_stdv))
46+
#w = tf.get_variable(name='w/hidden_state_out_' + str(i), shape=[hidden_size, hidden_size], initializer=tf.random_normal_initializer)
47+
if name == 'text_hidden_state_in':
48+
#w = tf.get_variable(name='w/hidden_state_in_', shape=[hidden_size, hidden_size], initializer=tf.random_normal_initializer)
49+
w = tf.get_variable(name='w/text_hidden_state_in_' + str(i),
50+
shape=[hidden_size, hidden_size],
51+
initializer=tf.random_normal_initializer(stddev=hidden_stdv))
52+
53+
54+
return w
55+
56+
57+
def biases(name, hidden_size, i):
58+
image_stdv = np.sqrt(1. / (2048))
59+
hidden_stdv = np.sqrt(1. / (hidden_size))
60+
if name == 'image_hidden_state_out':
61+
b = tf.get_variable(name='b/image_hidden_state_out' + str(i), shape=[hidden_size],
62+
initializer=tf.random_normal_initializer(stddev=hidden_stdv))
63+
# b = tf.get_variable(name='b/hidden_state_out', shape=[hidden_size],
64+
# initializer=tf.random_normal_initializer)
65+
if name == 'image_hidden_state_in':
66+
b = tf.get_variable(name='b/image_hidden_state_in' + str(i), shape=[hidden_size],
67+
initializer=tf.random_normal_initializer(stddev=hidden_stdv))
68+
# b = tf.get_variable(name='b/hidden_state_in', shape=[hidden_size],
69+
# initializer=tf.random_normal_initializer)
70+
if name == 'text_hidden_state_out':
71+
b = tf.get_variable(name='b/text_hidden_state_out' + str(i), shape=[hidden_size],
72+
initializer=tf.random_normal_initializer(stddev=hidden_stdv))
73+
# b = tf.get_variable(name='b/hidden_state_out', shape=[hidden_size],
74+
# initializer=tf.random_normal_initializer)
75+
if name == 'text_hidden_state_in':
76+
b = tf.get_variable(name='b/text_hidden_state_in' + str(i), shape=[hidden_size],
77+
initializer=tf.random_normal_initializer(stddev=hidden_stdv))
78+
# b = tf.get_variable(name='b/hidden_state_in', shape=[hidden_size],
79+
# initializer=tf.random_normal_initializer)
80+
if name == 'out_image':
81+
# b = tf.get_variable(name='b/out_image_', shape=[2048],
82+
# initializer=tf.random_normal_initializer)
83+
b = tf.get_variable(name='b/out_image_' + str(i), shape=[2048],
84+
initializer=tf.random_normal_initializer(stddev=image_stdv))
85+
if name == 'out_text':
86+
# b = tf.get_variable(name='b/out_image_', shape=[2048],
87+
# initializer=tf.random_normal_initializer)
88+
b = tf.get_variable(name='b/out_text_' + str(i), shape=[2757],
89+
initializer=tf.random_normal_initializer(stddev=image_stdv))
90+
91+
return b
92+
93+
94+
def message_pass(label, x, hidden_size, batch_size, num_category, graph):
95+
96+
w_hidden_state = weights(label + '_hidden_state_out', hidden_size, 0)
97+
#b_hidden_state = biases('hidden_state_out', hidden_size, 0)
98+
x_all = tf.reshape(tf.matmul(
99+
tf.reshape(x[:,0,:], [batch_size, hidden_size]),
100+
w_hidden_state),
101+
[batch_size, hidden_size])
102+
for i in range(1, num_category):
103+
w_hidden_state = weights(label + '_hidden_state_out', hidden_size, i)
104+
#b_hidden_state = biases('hidden_state_out', hidden_size, i)
105+
x_all_ = tf.reshape(tf.matmul(
106+
tf.reshape(x[:, i, :], [batch_size, hidden_size]),
107+
w_hidden_state),
108+
[batch_size, hidden_size])
109+
x_all = tf.concat([x_all, x_all_], 1)
110+
x_all = tf.reshape(x_all, [batch_size, num_category, hidden_size])
111+
x_all = tf.transpose(x_all, (0, 2, 1)) # [batch_size, hidden_size, num_category]
112+
113+
x_ = x_all[0]
114+
graph_ = graph[0]
115+
x = tf.matmul(x_, graph_)
116+
for i in range(1, batch_size):
117+
x_ = x_all[i]
118+
graph_ = graph[i]
119+
x_ = tf.matmul(x_, graph_)
120+
x = tf.concat([x, x_], 0)
121+
x = tf.reshape(x, [batch_size, hidden_size, num_category])
122+
x = tf.transpose(x, (0, 2, 1))
123+
124+
x_ = tf.reshape(tf.matmul(x[:, 0, :], weights(label + '_hidden_state_in', hidden_size, 0)),
125+
[batch_size, hidden_size])
126+
for j in range(1, num_category):
127+
_x = tf.reshape(tf.matmul(x[:, j, :], weights(label + '_hidden_state_in', hidden_size, j)),
128+
[batch_size, hidden_size])
129+
x_ = tf.concat([x_, _x], 1)
130+
x = tf.reshape(x_, [batch_size, num_category, hidden_size])
131+
132+
return x
133+
134+
135+
136+
#def GNN(image, batch_size, hidden_size, keep_prob, n_steps, mask_x, num_category, graph):
137+
def GNN(label, data, batch_size, hidden_size, n_steps, num_category, graph):
138+
139+
gru_cell = GRUCell(hidden_size)
140+
w_in = weights('in_' + label, hidden_size, 0)
141+
h0 = tf.reshape(tf.matmul(data[:,0,:], w_in), [batch_size, hidden_size]) #initialize h0 [batchsize, hidden_state]
142+
for i in range(1, num_category):
143+
w_in = weights('in_' + label, hidden_size, i)
144+
h0 = tf.concat([h0, tf.reshape(
145+
tf.matmul(data[:,i,:], w_in), [batch_size, hidden_size])
146+
], 1)
147+
h0 = tf.reshape(h0, [batch_size, num_category, hidden_size]) # h0: [batchsize, num_category, hidden_state]
148+
ini = h0
149+
h0 = tf.nn.tanh(h0)
150+
151+
state = h0
152+
sum_graph = tf.reduce_sum(graph, reduction_indices=1)
153+
enable_node = tf.cast(tf.cast(sum_graph, dtype=bool), dtype=tf.float32)
154+
155+
with tf.variable_scope("gnn"):
156+
for step in range(n_steps):
157+
if step > 0: tf.get_variable_scope().reuse_variables()
158+
# state = state * mask_x
159+
x = message_pass(label, state, hidden_size, batch_size, num_category, graph)
160+
# x = tf.reshape(x, [batch_size*num_category, hidden_size])
161+
# state = tf.reshape(state, [batch_size*num_category, hidden_size])
162+
(x_new, state_new) = gru_cell(x[0], state[0])
163+
state_new = tf.transpose(state_new, (1, 0))
164+
state_new = tf.multiply(state_new, enable_node[0])
165+
state_new = tf.transpose(state_new, (1, 0))
166+
for i in range(1, batch_size):
167+
(x_, state_) = gru_cell(x[i], state[i]) # #input of GRUCell must be 2 rank, not 3 rank
168+
state_ = tf.transpose(state_, (1, 0))
169+
state_ = tf.multiply(state_, enable_node[i])
170+
state_ = tf.transpose(state_, (1, 0))
171+
state_new = tf.concat([state_new, state_], 0)
172+
# x = tf.reshape(x, [batch_size, num_category, hidden_size])
173+
state = tf.reshape(state_new, [batch_size, num_category, hidden_size]) # #restore: 2 rank to 3 rank
174+
# state = state * mask_x
175+
# state = tf.nn.dropout(state, keep_prob)
176+
177+
# w_out_image = weights('out_image', hidden_size, 0)
178+
# b_out_image = biases('out_image', hidden_size, 0)
179+
# output = tf.reshape(tf.matmul(state[:, 0, :], w_out_image) + b_out_image, [batch_size, 2048]) #initialize output : [batchsize, 2048]
180+
# for i in range(1, num_category):
181+
# w_out_image = weights('out_image', hidden_size, i)
182+
# b_out_image = biases('out_image', hidden_size, i)
183+
# output = tf.concat([output, tf.reshape(
184+
# tf.matmul(state[:, i, :], w_out_image) + b_out_image,
185+
# [batch_size, 2048])], 1)
186+
# output = tf.reshape(output, [batch_size, num_category, 2048])
187+
# output = tf.nn.tanh(output)
188+
189+
return state, ini
190+

0 commit comments

Comments
 (0)