model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GATConv


class SemanticAttention(nn.Module):
    """Semantic-level attention that fuses the metapath-specific embeddings."""

    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()

        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        )

    def forward(self, z):
        w = self.project(z).mean(0)  # (M, 1)
        beta = torch.softmax(w, dim=0)  # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)

        return (beta * z).sum(1)  # (N, D * K)
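

# Illustrative shape walk-through (the sizes below are made up, not taken
# from the original example): with N = 3 nodes, M = 2 metapaths and a fused
# GAT output of size D * K = 16:
#
#     att = SemanticAttention(in_size=16)
#     z = torch.randn(3, 2, 16)   # (N, M, D * K)
#     out = att(z)                # (N, D * K) -> torch.Size([3, 16])
#
# The projection scores each metapath per node, the scores are averaged over
# nodes, softmax-normalized into beta, and used to weight-sum the M embeddings.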
class HANLayer(nn.Module):
    """HAN layer.

    Arguments
    ---------
    num_meta_paths : number of homogeneous graphs generated from the metapaths
    in_size : input feature dimension
    out_size : output feature dimension
    layer_num_heads : number of attention heads
    dropout : dropout probability

    Inputs
    ------
    gs : list[DGLGraph]
        List of metapath-based homogeneous graphs
    h : tensor
        Input features

    Outputs
    -------
    tensor
        The output feature
    """

    def __init__(
        self, num_meta_paths, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # One GAT layer for each metapath-based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(num_meta_paths):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,
                    dropout,
                    activation=F.elu,
                )
            )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.num_meta_paths = num_meta_paths

    def forward(self, gs, h):
        semantic_embeddings = []

        for i, g in enumerate(gs):
            semantic_embeddings.append(self.gat_layers[i](g, h).flatten(1))
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)
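

# Illustrative shape sketch for a single HANLayer (assumed sizes, not from the
# original example): with N nodes, M metapath graphs, out_size = 8 and 8 heads,
# each GATConv returns (N, 8, 8); flatten(1) gives (N, 64), the stack gives
# (N, M, 64), and semantic attention fuses it back to (N, 64).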
class HAN(nn.Module):
    def __init__(
        self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout
    ):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(
            HANLayer(
                num_meta_paths, in_size, hidden_size, num_heads[0], dropout
            )
        )
        for l in range(1, len(num_heads)):
            self.layers.append(
                HANLayer(
                    num_meta_paths,
                    hidden_size * num_heads[l - 1],
                    hidden_size,
                    num_heads[l],
                    dropout,
                )
            )
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
        for gnn in self.layers:
            h = gnn(g, h)

        return self.predict(h)
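

if __name__ == "__main__":
    # Minimal smoke test, not part of the original example: it assumes `dgl`
    # is installed and builds two random metapath-based homogeneous graphs
    # over 100 nodes just to check tensor shapes end to end.
    import dgl

    num_nodes, in_size = 100, 64
    gs = []
    for _ in range(2):
        src = torch.randint(0, num_nodes, (500,))
        dst = torch.randint(0, num_nodes, (500,))
        # Self-loops avoid GATConv errors on nodes with zero in-degree.
        gs.append(dgl.add_self_loop(dgl.graph((src, dst), num_nodes=num_nodes)))

    model = HAN(
        num_meta_paths=len(gs),
        in_size=in_size,
        hidden_size=8,
        out_size=3,
        num_heads=[8],
        dropout=0.6,
    )
    h = torch.randn(num_nodes, in_size)
    logits = model(gs, h)
    print(logits.shape)  # expected: torch.Size([100, 3])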