-
Notifications
You must be signed in to change notification settings - Fork 148
/
Copy pathRSHN_sampler.py
95 lines (76 loc) · 3.06 KB
/
RSHN_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
import dgl
import torch as th
from dgl.data.utils import load_graphs, save_graphs
class coarsened_line_graph():
def __init__(self, rw_len, batch_size, n_dataset, symmetric=True):
self.rw_len = rw_len
self.batch_size = batch_size
self.n_dataset = n_dataset
self.symmetric = symmetric # which means the original graph had inverse edges
return
def get_cl_graph(self, hg):
fname = './openhgnn/output/RSHN/{}_cl_graoh_{}_{}.bin'.format(
self.n_dataset, self.rw_len, self.batch_size)
if os.path.exists(fname):
g, _ = load_graphs(fname)
return g[0]
else:
g = self.build_cl_graph(hg)
save_graphs(fname, g)
return g
def init_cl_graph(self, cl_graph):
cl_graph = give_one_hot_feats(cl_graph, 'h')
cl_graph = dgl.remove_self_loop(cl_graph)
edge_attr = cl_graph.edata['w'].type(th.FloatTensor).to(cl_graph.device)
row, col = cl_graph.edges()
for i in range(cl_graph.num_nodes()):
mask = th.eq(row, i)
edge_attr[mask] = th.nn.functional.normalize(edge_attr[mask], p=2, dim=0)
# add_self_loop, set 1 as edge feature
cl_graph = dgl.add_self_loop(cl_graph)
edge_attr = th.cat([edge_attr, th.ones(cl_graph.num_nodes(), device=edge_attr.device)], dim=0)
cl_graph.edata['w'] = edge_attr
return cl_graph
def build_cl_graph(self, hg):
if not hg.is_homogeneous:
self.num_edge_type = len(hg.etypes)
g = dgl.to_homogeneous(hg).to('cpu')
traces = self.random_walks(g)
edge_batch = self.rw_map_edge_type(g, traces)
cl_graph = self.edge2graph(edge_batch)
return cl_graph
def random_walks(self, g):
source_nodes = th.randint(0, g.number_of_nodes(), (self.batch_size,))
traces, _ = dgl.sampling.random_walk(g, source_nodes, length=self.rw_len-1)
return traces
def rw_map_edge_type(self, g, traces):
edge_type = g.edata[dgl.ETYPE].long()
edge_batch = []
first_flag = True
for t in traces:
u = t[:-1]
v = t[1:]
edge_path = edge_type[g.edge_ids(u, v)].unsqueeze(0)
if first_flag == True:
edge_batch = edge_path
first_flag = False
else:
edge_batch = th.cat((edge_batch, edge_path), dim=0)
return edge_batch
def edge2graph(self, edge_batch):
u = edge_batch[:, :-1].reshape(-1)
v = edge_batch[:, 1:].reshape(-1)
if self.symmetric:
tmp = u
u = th.cat((u, v), dim=0)
v = th.cat((v,tmp), dim=0)
g = dgl.graph((u, v))
sg = dgl.to_simple(g, return_counts='w')
return sg
def give_one_hot_feats(g, ntype='h'):
# if the nodes are featureless, the input feature is then the node id.
num_nodes = g.num_nodes()
#g.ndata[ntype] = th.arange(num_nodes, dtype=th.float32, device=g.device)
g.ndata[ntype] = th.eye(num_nodes).to(g.device)
return g