Skip to content

Commit

Permalink
add the patch
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jan 18, 2024
1 parent f7e065f commit 85ac333
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions python/dgl/graphbolt/feature_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def record_stream(tensor):
edges = original_edge_ids.get(type_name, None)
if edges is None:
continue
if edges.is_cuda:
edges.record_stream(torch.cuda.current_stream())
for feature_name in feature_names:
edge_features[i][
(type_name, feature_name)
Expand All @@ -143,6 +145,10 @@ def record_stream(tensor):
)
)
else:
if original_edge_ids.is_cuda:
original_edge_ids.record_stream(
torch.cuda.current_stream()
)
for feature_name in self.edge_feature_keys:
edge_features[i][feature_name] = record_stream(
self.feature_store.read(
Expand Down

0 comments on commit 85ac333

Please sign in to comment.