ggnn_ns.py
"""
Gated Graph Neural Network module for node selection tasks
"""
import dgl
import torch
from dgl.nn.pytorch import GatedGraphConv
from torch import nn
class NodeSelectionGGNN(nn.Module):
def __init__(self, annotation_size, out_feats, n_steps, n_etypes):
super(NodeSelectionGGNN, self).__init__()
self.annotation_size = annotation_size
self.out_feats = out_feats
self.ggnn = GatedGraphConv(
in_feats=out_feats,
out_feats=out_feats,
n_steps=n_steps,
n_etypes=n_etypes,
)
self.output_layer = nn.Linear(annotation_size + out_feats, 1)
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, graph, labels=None):
etypes = graph.edata.pop("type")
annotation = graph.ndata.pop("annotation").float()
assert annotation.size()[-1] == self.annotation_size
node_num = graph.num_nodes()
zero_pad = torch.zeros(
[node_num, self.out_feats - self.annotation_size],
dtype=torch.float,
device=annotation.device,
)
h1 = torch.cat([annotation, zero_pad], -1)
out = self.ggnn(graph, h1, etypes)
all_logits = self.output_layer(
torch.cat([out, annotation], -1)
).squeeze(-1)
graph.ndata["logits"] = all_logits
batch_g = dgl.unbatch(graph)
preds = []
if labels is not None:
loss = 0.0
for i, g in enumerate(batch_g):
logits = g.ndata["logits"]
preds.append(torch.argmax(logits))
if labels is not None:
logits = logits.unsqueeze(0)
y = labels[i].unsqueeze(0)
loss += self.loss_fn(logits, y)
if labels is not None:
loss /= float(len(batch_g))
return loss, preds
return preds
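

# For reference, a minimal usage sketch (not part of the original file).
# The two toy graphs, edge-type ids, feature sizes, and labels below are
# invented for illustration; in the actual example the graphs come from the
# dataset pipeline that accompanies this module. The sketch shows one
# training-style call that returns a loss and one inference-style call that
# returns only per-graph predictions.

import dgl
import torch

# Two hypothetical toy graphs with one edge-type id per edge and a
# one-hot-style "annotation" marking the candidate node.
g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=3)
g1.edata["type"] = torch.tensor([0, 1])
g1.ndata["annotation"] = torch.tensor([[1.0], [0.0], [0.0]])

g2 = dgl.graph(([0, 1, 2], [1, 2, 3]), num_nodes=4)
g2.edata["type"] = torch.tensor([0, 0, 1])
g2.ndata["annotation"] = torch.tensor([[0.0], [1.0], [0.0], [0.0]])

model = NodeSelectionGGNN(annotation_size=1, out_feats=16, n_steps=5, n_etypes=2)

# Training-style call: labels[i] is the index of the target node in graph i.
labels = torch.tensor([2, 3])
loss, preds = model(dgl.batch([g1, g2]), labels)
loss.backward()

# Inference-style call: rebatch first, since forward() pops the input
# features off the graph it receives.
preds = model(dgl.batch([g1, g2]))

# Because the loss averages a per-graph cross entropy in which each graph's
# own nodes act as the class set, graphs of different sizes can be mixed
# freely in one batch.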