model.py
from torch import nn

from modules import (ResidualModuleWrapper, FeedForwardModule, GCNModule, SAGEModule, GATModule, GATSepModule,
                     TransformerAttentionModule, TransformerAttentionSepModule)


# Building blocks used by each architecture; the graph transformer variants ('GT', 'GT-sep')
# interleave an attention module with a feed-forward module in every layer.
MODULES = {
    'ResNet': [FeedForwardModule],
    'GCN': [GCNModule],
    'SAGE': [SAGEModule],
    'GAT': [GATModule],
    'GAT-sep': [GATSepModule],
    'GT': [TransformerAttentionModule, FeedForwardModule],
    'GT-sep': [TransformerAttentionSepModule, FeedForwardModule]
}

# Normalization layers selectable by name.
NORMALIZATION = {
    'None': nn.Identity,
    'LayerNorm': nn.LayerNorm,
    'BatchNorm': nn.BatchNorm1d
}


class Model(nn.Module):
    def __init__(self, model_name, num_layers, input_dim, hidden_dim, output_dim, hidden_dim_multiplier, num_heads,
                 normalization, dropout):
        super().__init__()

        normalization = NORMALIZATION[normalization]

        # Input projection applied to raw node features.
        self.input_linear = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.act = nn.GELU()

        # Stack of residual blocks: each layer contains every module listed for the chosen architecture.
        self.residual_modules = nn.ModuleList()
        for _ in range(num_layers):
            for module in MODULES[model_name]:
                residual_module = ResidualModuleWrapper(module=module,
                                                        normalization=normalization,
                                                        dim=hidden_dim,
                                                        hidden_dim_multiplier=hidden_dim_multiplier,
                                                        num_heads=num_heads,
                                                        dropout=dropout)

                self.residual_modules.append(residual_module)

        # Output head: final normalization followed by a linear prediction layer.
        self.output_normalization = normalization(hidden_dim)
        self.output_linear = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, graph, x):
        x = self.input_linear(x)
        x = self.dropout(x)
        x = self.act(x)

        for residual_module in self.residual_modules:
            x = residual_module(graph, x)

        x = self.output_normalization(x)
        # Squeeze the class dimension so single-target outputs come out as a 1D tensor.
        x = self.output_linear(x).squeeze(1)

        return x
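

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how Model might be instantiated, assuming the companion
# modules.py from this repository is importable. All argument values below are
# hypothetical placeholders, not settings taken from the repository.
if __name__ == '__main__':
    model = Model(model_name='GT-sep',        # any key of MODULES
                  num_layers=3,               # each layer adds every module listed for the architecture
                  input_dim=128,              # node feature dimension
                  hidden_dim=256,
                  output_dim=1,               # forward() squeezes dim 1, so a single target yields a 1D output
                  hidden_dim_multiplier=2,
                  num_heads=4,
                  normalization='LayerNorm',  # key of NORMALIZATION
                  dropout=0.2)
    # Inference would then be: predictions = model(graph, node_features), where `graph`
    # is whatever graph object the modules in modules.py expect.
    print(model)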