From 85ac333c5fd24282b5fb91dce66ae0d3b48383a6 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin Date: Thu, 18 Jan 2024 21:03:11 +0000 Subject: [PATCH] add the patch --- python/dgl/graphbolt/feature_fetcher.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/dgl/graphbolt/feature_fetcher.py b/python/dgl/graphbolt/feature_fetcher.py index 7b9ed6817afb..b830f0ec1c8a 100644 --- a/python/dgl/graphbolt/feature_fetcher.py +++ b/python/dgl/graphbolt/feature_fetcher.py @@ -134,6 +134,8 @@ def record_stream(tensor): edges = original_edge_ids.get(type_name, None) if edges is None: continue + if edges.is_cuda: + edges.record_stream(torch.cuda.current_stream()) for feature_name in feature_names: edge_features[i][ (type_name, feature_name) @@ -143,6 +145,10 @@ def record_stream(tensor): ) ) else: + if original_edge_ids.is_cuda: + original_edge_ids.record_stream( + torch.cuda.current_stream() + ) for feature_name in self.edge_feature_keys: edge_features[i][feature_name] = record_stream( self.feature_store.read(