-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmodel.py
67 lines (55 loc) · 2.47 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import torch.nn as nn
import torch.nn.functional as F
def normalize_adj(adj):
    """Row-normalize a batched adjacency matrix.

    Args:
        adj: tensor of shape (batch, n, n); assumed non-negative
            (adjacency weights) — TODO confirm with callers.

    Returns:
        Tensor of the same shape where each row sums to 1. Rows that sum
        to zero (isolated nodes) are returned as all-zeros — the previous
        implementation divided by zero there and produced NaN.
    """
    # keepdim + broadcasting replaces the original explicit .repeat(),
    # avoiding a fully materialized (batch, n, n) copy of the row sums.
    rowsum = adj.sum(2, keepdim=True)
    # clamp guards the zero-row case; positive sums are unaffected.
    return torch.div(adj, rowsum.clamp(min=1e-12))
def graph_pooling(inputs, num_vertices):
    """Mean-pool node features over each graph's true vertex count.

    Args:
        inputs: (batch, n, d) node features; padding rows are presumably
            zero so they do not affect the sum — TODO confirm.
        num_vertices: (batch,) number of real vertices in each graph.

    Returns:
        (batch, d) tensor: summed features divided by the vertex count.
    """
    totals = inputs.sum(1)
    denom = num_vertices.unsqueeze(-1).expand_as(totals)
    return torch.div(totals, denom)
class DirectedGraphConvolution(nn.Module):
    """Graph convolution layer that aggregates along both edge directions.

    Two independent weight matrices project the node features; one result
    is propagated over the row-normalized adjacency, the other over its
    transpose (reversed edges). The two directions are averaged, then
    dropout is applied.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # One projection per propagation direction; initialized below.
        self.weight1 = nn.Parameter(torch.zeros((in_features, out_features)))
        self.weight2 = nn.Parameter(torch.zeros((in_features, out_features)))
        self.dropout = nn.Dropout(0.1)
        self.reset_parameters()

    def reset_parameters(self):
        # Re-draw both directional weight matrices with Xavier-uniform values.
        for weight in (self.weight1, self.weight2):
            nn.init.xavier_uniform_(weight.data)

    def forward(self, inputs, adj):
        # Message passing along the original edge direction.
        forward_msg = F.relu(
            torch.matmul(normalize_adj(adj), torch.matmul(inputs, self.weight1))
        )
        # Message passing along reversed edges (transposed adjacency).
        reverse_msg = F.relu(
            torch.matmul(normalize_adj(adj.transpose(1, 2)), torch.matmul(inputs, self.weight2))
        )
        # Average the two directions, then regularize with dropout.
        combined = (forward_msg + reverse_msg) / 2
        return self.dropout(combined)

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"
class NeuralPredictor(nn.Module):
    """GCN-based predictor: stacked directed graph convolutions, mean
    pooling over real vertices, then a two-layer linear head producing a
    single scalar per graph.
    """

    def __init__(self, initial_hidden=5, gcn_hidden=144, gcn_layers=3, linear_hidden=128):
        super().__init__()
        # First layer maps the raw operation encoding to gcn_hidden;
        # subsequent layers keep the width constant.
        convs = []
        width = initial_hidden
        for _ in range(gcn_layers):
            convs.append(DirectedGraphConvolution(width, gcn_hidden))
            width = gcn_hidden
        self.gcn = nn.ModuleList(convs)
        self.dropout = nn.Dropout(0.1)
        self.fc1 = nn.Linear(gcn_hidden, linear_hidden, bias=False)
        self.fc2 = nn.Linear(linear_hidden, 1, bias=False)

    def forward(self, inputs):
        """Expects a dict with keys "num_vertices", "adjacency", "operations"."""
        num_vertices = inputs["num_vertices"]
        adj = inputs["adjacency"]
        features = inputs["operations"]
        node_count = adj.size(1)  # graph node number
        # Add self-loops before normalizing — per the original author's
        # note, the incoming diagonal is assumed NOT to already be 1.
        adj_with_diag = normalize_adj(adj + torch.eye(node_count, device=adj.device))
        for conv in self.gcn:
            features = conv(features, adj_with_diag)
        pooled = graph_pooling(features, num_vertices)
        hidden = self.dropout(self.fc1(pooled))
        return self.fc2(hidden).view(-1)