[Misc] Black auto fix. (dmlc#4642)
* [Misc] Black auto fix.

* sort

Co-authored-by: Steve <[email protected]>
frozenbugs and Steve authored Sep 26, 2022
1 parent a9f2acf commit 23d0905
Showing 99 changed files with 6,243 additions and 3,552 deletions.
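The commit message does not record the exact invocation, but the hunks below are consistent with running Black and then isort over the tree. A plausible reproduction sketch, assuming a project line length of 80 (inferred from where the lines wrap; Black's default is 88, and the actual repo lint config is not shown here):

    # Hypothetical commands; tool versions and config are assumptions.
    black --line-length 80 .
    isort --line-length 80 .

The "sort" entry in the commit message presumably corresponds to the import reordering visible in the first hunk.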
84 changes: 61 additions & 23 deletions examples/pytorch/GATNE-T/src/main.py
@@ -1,24 +1,24 @@
-from collections import defaultdict
 import math
 import os
 import sys
 import time
+from collections import defaultdict
 
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from tqdm.auto import tqdm
 from numpy import random
 from torch.nn.parameter import Parameter
+from tqdm.auto import tqdm
+from utils import *
 
 import dgl
 import dgl.function as fn
 
-from utils import *
 
 
 def get_graph(network_data, vocab):
-    """ Build graph, treat all nodes as the same type
+    """Build graph, treat all nodes as the same type
 
     Parameters
     ----------
@@ -57,7 +57,9 @@ def __init__(self, g, num_fanouts):
 
     def sample(self, pairs):
         heads, tails, types = zip(*pairs)
-        seeds, head_invmap = torch.unique(torch.LongTensor(heads), return_inverse=True)
+        seeds, head_invmap = torch.unique(
+            torch.LongTensor(heads), return_inverse=True
+        )
         blocks = []
         for fanout in reversed(self.num_fanouts):
             sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
@@ -90,7 +92,9 @@ def __init__(
         self.edge_type_count = edge_type_count
         self.dim_a = dim_a
 
-        self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
+        self.node_embeddings = Parameter(
+            torch.FloatTensor(num_nodes, embedding_size)
+        )
         self.node_type_embeddings = Parameter(
             torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)
         )
@@ -100,16 +104,24 @@ def __init__(
         self.trans_weights_s1 = Parameter(
             torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
         )
-        self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))
+        self.trans_weights_s2 = Parameter(
+            torch.FloatTensor(edge_type_count, dim_a, 1)
+        )
 
         self.reset_parameters()
 
     def reset_parameters(self):
         self.node_embeddings.data.uniform_(-1.0, 1.0)
         self.node_type_embeddings.data.uniform_(-1.0, 1.0)
-        self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
-        self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
-        self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
+        self.trans_weights.data.normal_(
+            std=1.0 / math.sqrt(self.embedding_size)
+        )
+        self.trans_weights_s1.data.normal_(
+            std=1.0 / math.sqrt(self.embedding_size)
+        )
+        self.trans_weights_s2.data.normal_(
+            std=1.0 / math.sqrt(self.embedding_size)
+        )
 
     # embs: [batch_size, embedding_size]
     def forward(self, block):
@@ -122,10 +134,16 @@ def forward(self, block):
         with block.local_scope():
             for i in range(self.edge_type_count):
                 edge_type = self.edge_types[i]
-                block.srcdata[edge_type] = self.node_type_embeddings[input_nodes, i]
-                block.dstdata[edge_type] = self.node_type_embeddings[output_nodes, i]
+                block.srcdata[edge_type] = self.node_type_embeddings[
+                    input_nodes, i
+                ]
+                block.dstdata[edge_type] = self.node_type_embeddings[
+                    output_nodes, i
+                ]
                 block.update_all(
-                    fn.copy_u(edge_type, "m"), fn.sum("m", edge_type), etype=edge_type
+                    fn.copy_u(edge_type, "m"),
+                    fn.sum("m", edge_type),
+                    etype=edge_type,
                 )
                 node_type_embed.append(block.dstdata[edge_type])
 
@@ -152,7 +170,9 @@ def forward(self, block):
             attention = (
                 F.softmax(
                     torch.matmul(
-                        torch.tanh(torch.matmul(tmp_node_type_embed, trans_w_s1)),
+                        torch.tanh(
+                            torch.matmul(tmp_node_type_embed, trans_w_s1)
+                        ),
                         trans_w_s2,
                     )
                     .squeeze(2)
@@ -173,7 +193,9 @@ def forward(self, block):
             )
         last_node_embed = F.normalize(node_embed, dim=2)
 
-        return last_node_embed  # [batch_size, edge_type_count, embedding_size]
+        return (
+            last_node_embed  # [batch_size, edge_type_count, embedding_size]
+        )
 
 
 class NSLoss(nn.Module):
@@ -187,7 +209,8 @@ def __init__(self, num_nodes, num_sampled, embedding_size):
         self.sample_weights = F.normalize(
             torch.Tensor(
                 [
-                    (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
+                    (math.log(k + 2) - math.log(k + 1))
+                    / math.log(num_nodes + 1)
                     for k in range(num_nodes)
                 ]
             ),
@@ -257,14 +280,20 @@ def train_model(network_data):
         pin_memory=True,
     )
     model = DGLGATNE(
-        num_nodes, embedding_size, embedding_u_size, edge_types, edge_type_count, dim_a
+        num_nodes,
+        embedding_size,
+        embedding_u_size,
+        edge_types,
+        edge_type_count,
+        dim_a,
     )
     nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
     model.to(device)
    nsloss.to(device)
 
     optimizer = torch.optim.Adam(
-        [{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-3
+        [{"params": model.parameters()}, {"params": nsloss.parameters()}],
+        lr=1e-3,
     )
 
     best_score = 0
@@ -286,7 +315,10 @@ def train_model(network_data):
             block_types = block_types.to(device)
             embs = model(block[0].to(device))[head_invmap]
             embs = embs.gather(
-                1, block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2])
+                1,
+                block_types.view(-1, 1, 1).expand(
+                    embs.shape[0], 1, embs.shape[2]
+                ),
             )[:, 0]
             loss = nsloss(
                 block[0].dstdata[dgl.NID][head_invmap].to(device),
@@ -307,15 +339,19 @@ def train_model(network_data):
 
     model.eval()
     # {'1': {}, '2': {}}
-    final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)]))
+    final_model = dict(
+        zip(edge_types, [dict() for _ in range(edge_type_count)])
+    )
     for i in range(num_nodes):
         train_inputs = (
             torch.tensor([i for _ in range(edge_type_count)])
             .unsqueeze(1)
             .to(device)
         )  # [i, i]
         train_types = (
-            torch.tensor(list(range(edge_type_count))).unsqueeze(1).to(device)
+            torch.tensor(list(range(edge_type_count)))
+            .unsqueeze(1)
+            .to(device)
         )  # [0, 1]
         pairs = torch.cat(
             (train_inputs, train_inputs, train_types), dim=1
@@ -343,7 +379,9 @@ def train_model(network_data):
     valid_aucs, valid_f1s, valid_prs = [], [], []
     test_aucs, test_f1s, test_prs = [], [], []
     for i in range(edge_type_count):
-        if args.eval_type == "all" or edge_types[i] in args.eval_type.split(","):
+        if args.eval_type == "all" or edge_types[i] in args.eval_type.split(
+            ","
+        ):
             tmp_auc, tmp_f1, tmp_pr = evaluate(
                 final_model[edge_types[i]],
                 valid_true_data_by_edge[edge_types[i]],