# models.py

import torch
import torch.nn as nn
from torchvision import models


class Identity(nn.Module):
    """A no-op module that returns its input unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x
class RNN_LSTM(nn.Module):
    """LSTM over frozen, pretrained word embeddings.

    `embedding` is expected to be a numpy array of shape (vocab_size, embed_dim).
    """

    def __init__(self, batch_size, embedding, cell_size, num_layers, bidirectional,
                 batch_first=True, dropout_probability=0.5):
        super(RNN_LSTM, self).__init__()
        self.batch_size = batch_size
        # Load the pretrained embedding matrix and freeze it.
        self.embeddings = nn.Embedding(embedding.shape[0], embedding.shape[1])
        self.embeddings.load_state_dict({'weight': torch.from_numpy(embedding)})
        self.embeddings.weight.requires_grad = False
        self.lstm = nn.LSTM(embedding.shape[1], cell_size,
                            num_layers=num_layers, bidirectional=bidirectional,
                            batch_first=batch_first, dropout=dropout_probability)

    def update_batch_size(self, batch_size):
        self.batch_size = batch_size

    def forward(self, hidden_state, cell_state, captions):
        # captions: (batch, seq_len) token indices;
        # states: (num_layers * num_directions, batch, cell_size).
        embeds = self.embeddings(captions)
        outputs, (hidden_state, cell_state) = self.lstm(embeds, (hidden_state, cell_state))
        return outputs, (hidden_state, cell_state)
class Vanilla_Text_Encoder(nn.Module):
    """Image-conditioned LSTM text model.

    A ResNet-18 encoding of the image seeds the LSTM's initial hidden state via a
    conditioning-augmentation step, and each LSTM output is projected to
    vocabulary scores.
    """

    def __init__(self, batch_size, embedding, cell_size, num_layers, bidirectional,
                 GPU, gpu_number, batch_first=True, dropout_probability=0.5):
        super(Vanilla_Text_Encoder, self).__init__()
        self.c_dim = cell_size
        # ResNet-18 backbone with its classifier head replaced to emit (mu, logvar).
        self.cnn = models.resnet18()
        self.cnn.fc = nn.Sequential(
            nn.Linear(512, cell_size * 2, bias=False),
            nn.ReLU())
        self.rnn = RNN_LSTM(batch_size, embedding, cell_size, num_layers, bidirectional,
                            batch_first, dropout_probability)
        # Per-timestep projection from LSTM features to vocabulary scores.
        self.fc = nn.Sequential(
            nn.Linear(cell_size, embedding.shape[0], bias=True),
            nn.BatchNorm1d(embedding.shape[0]),
            nn.ReLU()
        )
        self.GPU = GPU
        self.gpu_number = gpu_number
        self.batch_size = batch_size
        self.cell_size = cell_size

    def cond_aug_network(self, img_encoding):
        # Conditioning augmentation: split the CNN output into (mu, logvar) and
        # sample the initial hidden state with the reparameterization trick.
        mu = img_encoding[:, :self.c_dim]
        logvar = img_encoding[:, self.c_dim:]
        std = logvar.mul(0.5).exp_()
        eps = torch.FloatTensor(std.size()).normal_()
        if self.GPU:
            eps = eps.cuda(self.gpu_number)
        hidden = eps.mul(std).add_(mu)
        return hidden
    def forward(self, images, captions):
        # Encode the images and turn the encoding into the LSTM's initial states
        # (assumes num_layers=1 and a unidirectional LSTM, so each state is
        # (1, batch, cell_size) after unsqueeze).
        cnn_output = self.cnn(images)
        hidden_state = self.cond_aug_network(cnn_output)
        cell_state = torch.FloatTensor(hidden_state.shape[0], hidden_state.shape[1])
        cell_state.normal_(0, 1)
        if self.GPU:
            cell_state = cell_state.cuda(self.gpu_number)
        hidden_state = hidden_state.unsqueeze(0)
        cell_state = cell_state.unsqueeze(0)
        outputs, _ = self.rnn(hidden_state, cell_state, captions)
        # outputs is batch-first: (batch, seq_len, num_features). Transpose to
        # iterate over timesteps (a plain view here would scramble samples across
        # timesteps) and project each timestep to vocabulary scores.
        outputs = outputs.transpose(0, 1).contiguous()
        predictions = []
        for output in outputs:
            predictions.append(self.fc(output).unsqueeze(1))
        return torch.cat(predictions, dim=1)
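

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original model code). Every number
# below -- vocabulary size 1000, 300-d embeddings, cell_size 256, a batch of 4,
# 12-token captions, 224x224 images -- is a hypothetical placeholder chosen only
# to exercise the forward pass on CPU; substitute the real embedding matrix and
# hyperparameters from the training pipeline.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    vocab_size, embed_dim, cell_size, batch_size, seq_len = 1000, 300, 256, 4, 12
    # Stand-in for a pretrained embedding matrix of shape (vocab_size, embed_dim).
    embedding = np.random.randn(vocab_size, embed_dim).astype(np.float32)

    model = Vanilla_Text_Encoder(batch_size, embedding, cell_size,
                                 num_layers=1, bidirectional=False,
                                 GPU=False, gpu_number=0,
                                 dropout_probability=0.0)  # no dropout warning for 1 layer
    model.eval()  # deterministic BatchNorm / no dropout for the shape check

    images = torch.randn(batch_size, 3, 224, 224)
    captions = torch.randint(0, vocab_size, (batch_size, seq_len))
    predictions = model(images, captions)
    # Expected: (batch_size, seq_len, vocab_size) vocabulary scores per timestep.
    print(predictions.shape)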