Skip to content

Commit

Permalink
Merge branch 'master' into gb_cuda_multigpu_example_refinement
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Jan 19, 2024
2 parents ef9d29f + 2e6ded0 commit dcbc768
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions python/dgl/graphbolt/feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def record_stream(tensor):
nodes = input_nodes[type_name]
if nodes is None:
continue
if nodes.is_cuda:
nodes.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
node_features[
(type_name, feature_name)
Expand All @@ -104,6 +106,8 @@ def record_stream(tensor):
)
)
else:
if input_nodes.is_cuda:
input_nodes.record_stream(torch.cuda.current_stream())
for feature_name in self.node_feature_keys:
node_features[feature_name] = record_stream(
self.feature_store.read(
Expand Down Expand Up @@ -134,6 +138,8 @@ def record_stream(tensor):
edges = original_edge_ids.get(type_name, None)
if edges is None:
continue
if edges.is_cuda:
edges.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
edge_features[i][
(type_name, feature_name)
Expand All @@ -143,6 +149,10 @@ def record_stream(tensor):
)
)
else:
if original_edge_ids.is_cuda:
original_edge_ids.record_stream(
torch.cuda.current_stream()
)
for feature_name in self.edge_feature_keys:
edge_features[i][feature_name] = record_stream(
self.feature_store.read(
Expand Down

0 comments on commit dcbc768

Please sign in to comment.