-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model] Add HGAT #213
base: main
Are you sure you want to change the base?
[Model] Add HGAT #213
Conversation
gammagl/models/hgat.py
Outdated
# out_dict={} | ||
# for node_type, _ in x_dict.items(): | ||
# out_dict[node_type]=[] | ||
# for edge_type, edge_index in edge_index_dict.items(): | ||
# src_type, _, dst_type = edge_type | ||
# src = edge_index[0,:] | ||
# dst = edge_index[1,:] | ||
# message = unsorted_segment_sum(tlx.gather(x_dict[src_type],src),dst,num_nodes_dict[dst_type]) | ||
# out_dict[dst_type].append(message) | ||
# for node_type, outs in out_dict.items(): | ||
# aggr_out = tlx.reduce_sum(outs,axis=0) | ||
# out_dict[node_type]=tlx.relu(aggr_out) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
多余注释也删除
data = HeteroGraph() | ||
|
||
node_types = ['documents', 'topics', 'words'] | ||
for i, node_type in enumerate(node_types): | ||
x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz')) | ||
data[node_type].x = tlx.convert_to_tensor(x.todense(), dtype=tlx.float32) | ||
|
||
y = np.load(osp.join(self.raw_dir, 'labels.npy')) | ||
y = np.argmax(y,axis=1) | ||
data['documents'].y = tlx.convert_to_tensor(y, dtype=tlx.int64) | ||
|
||
split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz')) | ||
for name in ['train', 'val', 'test']: | ||
idx = split[f'{name}_idx'] | ||
mask = np.zeros(data['documents'].num_nodes, dtype=np.bool_) | ||
mask[idx] = True | ||
data['documents'][f'{name}_mask'] = tlx.convert_to_tensor(mask, dtype=tlx.bool) | ||
|
||
|
||
s = {} | ||
N_m = data['documents'].num_nodes | ||
N_d = data['topics'].num_nodes | ||
N_a = data['words'].num_nodes | ||
s['documents'] = (0, N_m) | ||
s['topics'] = (N_m, N_m + N_d) | ||
s['words'] = (N_m + N_d, N_m + N_d + N_a) | ||
|
||
A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz')).tocsr() | ||
for src, dst in product(node_types, node_types): | ||
A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo() | ||
if A_sub.nnz > 0: | ||
row = tlx.convert_to_tensor(A_sub.row, dtype=tlx.int64) | ||
col = tlx.convert_to_tensor(A_sub.col, dtype=tlx.int64) | ||
data[src, dst].edge_index = tlx.stack([row, col], axis=0) | ||
print(src+"____"+dst) | ||
|
||
if self.pre_transform is not None: | ||
data = self.pre_transform(data) | ||
|
||
self.save_data(self.collate([data]), self.processed_paths[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修改成从作者提供的url下载数据集,然后处理成heterograph
from gammagl.data import (HeteroGraph, InMemoryDataset, download_url, | ||
extract_zip) | ||
|
||
class OHSUMED(InMemoryDataset): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
写一个test文件,参考gammagl/datasets/路径下的其他测试文件
Add a new model: HGAT
corresponding with two conv layer, HINConv and HGATConv
and a new method to convert the Heterogeneous graph dict to homogeneous graph matrix: heter_homo_mutual_convert