From 9f15737e6a3996a2d6d8463a5cc0b9e45405ca03 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Tue, 16 Jan 2024 22:12:32 -0500 Subject: [PATCH 01/12] initial code --- .../sampling/graphbolt/node_classification.py | 4 ++- .../impl/fused_csc_sampling_graph.py | 31 ++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/examples/sampling/graphbolt/node_classification.py b/examples/sampling/graphbolt/node_classification.py index f589e667b455..d0854b844968 100644 --- a/examples/sampling/graphbolt/node_classification.py +++ b/examples/sampling/graphbolt/node_classification.py @@ -384,8 +384,10 @@ def main(args): dataset = gb.BuiltinDataset("ogbn-products").load() # Move the dataset to the selected storage. - graph = dataset.graph.to(args.storage_device) + # graph = dataset.graph.to(args.storage_device) features = dataset.feature.to(args.storage_device) + dataset.graph.pin_memory_() + graph = dataset.graph train_set = dataset.tasks[0].train_set valid_set = dataset.tasks[0].validation_set diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 8f026b3c5095..462ac8b7e557 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -975,8 +975,37 @@ def _pin(x): def pin_memory_(self): """Copy `FusedCSCSamplingGraph` to the pinned memory in-place.""" + # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 + cudart = torch.cuda.cudart() + def _pin(x): - return x.pin_memory() if hasattr(x, "pin_memory") else x + if hasattr(x, "pin_memory_"): + x.pin_memory_() + elif ( + isinstance(x, torch.Tensor) + and not x.is_pinned() + and not x.is_cuda + and x.is_contiguous() + ): + # x.share_memory_() + ptr = x.data_ptr() + assert ( + cudart.cudaHostRegister( + ptr, x.numel() * x.element_size(), 0 + ) + == 0 + ) + # assert x.is_shared() + assert x.is_pinned() + print(x.is_pinned(), x.is_shared()) + + def new_del(self): + assert cudart.cudaHostUnregister(ptr) == 0 + torch.Tensor.__del__(self) + + x.__del__ = new_del + + return x self._apply_to_members(_pin) From e33ca341993fa134c6f600b55319b9a30c66e945 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Wed, 17 Jan 2024 04:15:11 +0000 Subject: [PATCH 02/12] add implementation --- .../impl/fused_csc_sampling_graph.py | 23 +++++++++---------- .../impl/torch_based_feature_store.py | 21 +++++++++++++++-- 2 files changed, 30 insertions(+), 14 deletions(-) diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 462ac8b7e557..a33bad484f4e 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -35,6 +35,13 @@ def __init__( super().__init__() self._c_csc_graph = c_csc_graph + def __del__(self): + # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 + if hasattr(self, "_is_inplace_pinned"): + cudart = torch.cuda.cudart() + for tensor in self._is_inplace_pinned: + assert cudart.cudaHostUnregister(tensor.data_ptr()) == 0 + @property def total_num_nodes(self) -> int: """Returns the number of nodes in the graph. 
@@ -974,6 +981,7 @@ def _pin(x): def pin_memory_(self): """Copy `FusedCSCSamplingGraph` to the pinned memory in-place.""" + self._is_inplace_pinned = set() # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 cudart = torch.cuda.cudart() @@ -984,26 +992,17 @@ def _pin(x): elif ( isinstance(x, torch.Tensor) and not x.is_pinned() - and not x.is_cuda + and x.device.type == "cpu" and x.is_contiguous() ): - # x.share_memory_() - ptr = x.data_ptr() assert ( cudart.cudaHostRegister( - ptr, x.numel() * x.element_size(), 0 + x.data_ptr(), x.numel() * x.element_size(), 0 ) == 0 ) - # assert x.is_shared() - assert x.is_pinned() - print(x.is_pinned(), x.is_shared()) - - def new_del(self): - assert cudart.cudaHostUnregister(ptr) == 0 - torch.Tensor.__del__(self) - x.__del__ = new_del + self._is_inplace_pinned.add(x) return x diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index 3952eb0a84b4..3649fc32811b 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -84,6 +84,13 @@ def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None): self._tensor = torch_feature.contiguous() self._metadata = metadata + def __del__(self): + # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 + if hasattr(self, "_is_inplace_pinned"): + cudart = torch.cuda.cudart() + for tensor in self._is_inplace_pinned: + assert cudart.cudaHostUnregister(tensor.data_ptr()) == 0 + def read(self, ids: torch.Tensor = None): """Read the feature by index. @@ -169,14 +176,24 @@ def metadata(self): def pin_memory_(self): """In-place operation to copy the feature to pinned memory.""" - self._tensor = self._tensor.pin_memory() + # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 + x = self._tensor + if not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous(): + assert ( + torch.cuda.cudart().cudaHostRegister( + x.data_ptr(), x.numel() * x.element_size(), 0 + ) + == 0 + ) + + self._is_inplace_pinned.add(x) def to(self, device): # pylint: disable=invalid-name """Copy `TorchBasedFeature` to the specified device.""" # copy.copy is a shallow copy so it does not copy tensor memory. 
self2 = copy.copy(self) if device == "pinned": - self2.pin_memory_() + self2._tensor = self2.pin_memory() else: self2._tensor = self2._tensor.to(device) return self2 From a5b3e9ddfb6ff413eb88a16de9e802757ef08c88 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Wed, 17 Jan 2024 04:19:47 +0000 Subject: [PATCH 03/12] make pin_memory_ return self for convenience --- python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 2 +- python/dgl/graphbolt/impl/torch_based_feature_store.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index a33bad484f4e..52db8d8ffc46 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -1006,7 +1006,7 @@ def _pin(x): return x - self._apply_to_members(_pin) + return self._apply_to_members(_pin) def fused_csc_sampling_graph( diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index 3649fc32811b..3e1fb2f4091b 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -187,6 +187,7 @@ def pin_memory_(self): ) self._is_inplace_pinned.add(x) + return self def to(self, device): # pylint: disable=invalid-name """Copy `TorchBasedFeature` to the specified device.""" @@ -282,6 +283,7 @@ def pin_memory_(self): """In-place operation to copy the feature store to pinned memory.""" for feature in self._features.values(): feature.pin_memory_() + return self def to(self, device): # pylint: disable=invalid-name """Copy `TorchBasedFeatureStore` to the specified device.""" From 3fef383ce83f43f363aedaae362b9a05a3d99d05 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Wed, 17 Jan 2024 04:23:49 +0000 Subject: [PATCH 04/12] fix bug --- examples/sampling/graphbolt/node_classification.py | 4 +--- python/dgl/graphbolt/impl/torch_based_feature_store.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/sampling/graphbolt/node_classification.py b/examples/sampling/graphbolt/node_classification.py index d0854b844968..f589e667b455 100644 --- a/examples/sampling/graphbolt/node_classification.py +++ b/examples/sampling/graphbolt/node_classification.py @@ -384,10 +384,8 @@ def main(args): dataset = gb.BuiltinDataset("ogbn-products").load() # Move the dataset to the selected storage. 
- # graph = dataset.graph.to(args.storage_device) + graph = dataset.graph.to(args.storage_device) features = dataset.feature.to(args.storage_device) - dataset.graph.pin_memory_() - graph = dataset.graph train_set = dataset.tasks[0].train_set valid_set = dataset.tasks[0].validation_set diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index 3e1fb2f4091b..e5fbc7b68dea 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -176,6 +176,7 @@ def metadata(self): def pin_memory_(self): """In-place operation to copy the feature to pinned memory.""" + self._is_inplace_pinned = set() # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 x = self._tensor if not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous(): From abf466c5c6b8fd75e98bdaa9e50931e450e0acfd Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Wed, 17 Jan 2024 04:25:24 +0000 Subject: [PATCH 05/12] remove unnecessary check --- python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 52db8d8ffc46..4351c1a3321b 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -990,8 +990,7 @@ def _pin(x): if hasattr(x, "pin_memory_"): x.pin_memory_() elif ( - isinstance(x, torch.Tensor) - and not x.is_pinned() + not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous() ): From 93c8a1851379446ac1e2ab21e17831996dafda04 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Wed, 17 Jan 2024 04:32:39 +0000 Subject: [PATCH 06/12] fix the bug --- python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 4351c1a3321b..52db8d8ffc46 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -990,7 +990,8 @@ def _pin(x): if hasattr(x, "pin_memory_"): x.pin_memory_() elif ( - not x.is_pinned() + isinstance(x, torch.Tensor) + and not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous() ): From cdfb60cdc4671a13e3759c6e43f10fe96eb1c735 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Wed, 17 Jan 2024 05:03:32 +0000 Subject: [PATCH 07/12] fix bug --- python/dgl/graphbolt/impl/torch_based_feature_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index e5fbc7b68dea..f0bdf45d4aab 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -195,7 +195,7 @@ def to(self, device): # pylint: disable=invalid-name # copy.copy is a shallow copy so it does not copy tensor memory. 
self2 = copy.copy(self) if device == "pinned": - self2._tensor = self2.pin_memory() + self2._tensor = self2._tensor.pin_memory() else: self2._tensor = self2._tensor.to(device) return self2 From b7ca85203926fee0d1101b03ec490efe3570815c Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Wed, 17 Jan 2024 05:08:06 +0000 Subject: [PATCH 08/12] in place operation should not return self --- python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 2 +- python/dgl/graphbolt/impl/torch_based_feature_store.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 52db8d8ffc46..a33bad484f4e 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -1006,7 +1006,7 @@ def _pin(x): return x - return self._apply_to_members(_pin) + self._apply_to_members(_pin) def fused_csc_sampling_graph( diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index f0bdf45d4aab..4757a4e2ed5b 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -188,7 +188,6 @@ def pin_memory_(self): ) self._is_inplace_pinned.add(x) - return self def to(self, device): # pylint: disable=invalid-name """Copy `TorchBasedFeature` to the specified device.""" @@ -284,7 +283,6 @@ def pin_memory_(self): """In-place operation to copy the feature store to pinned memory.""" for feature in self._features.values(): feature.pin_memory_() - return self def to(self, device): # pylint: disable=invalid-name """Copy `TorchBasedFeatureStore` to the specified device.""" From a28a3c8dba134a908d06cfb3d3d74ad8419f5bcd Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Wed, 17 Jan 2024 05:21:22 +0000 Subject: [PATCH 09/12] add context to the comment --- python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 6 ++++++ python/dgl/graphbolt/impl/torch_based_feature_store.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index a33bad484f4e..7ac3b4d5989a 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -36,6 +36,9 @@ def __init__( self._c_csc_graph = c_csc_graph def __del__(self): + # torch.Tensor.pin_memory() is not an inplace operation. To make it + # truly in-place, we need to use cudaHostRegister. Then, we need to use + # cudaHostUnregister to unpin the tensor in the destructor. # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 if hasattr(self, "_is_inplace_pinned"): cudart = torch.cuda.cudart() @@ -983,6 +986,9 @@ def pin_memory_(self): """Copy `FusedCSCSamplingGraph` to the pinned memory in-place.""" self._is_inplace_pinned = set() + # torch.Tensor.pin_memory() is not an inplace operation. To make it + # truly in-place, we need to use cudaHostRegister. Then, we need to use + # cudaHostUnregister to unpin the tensor in the destructor. 
# https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 cudart = torch.cuda.cudart() diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index 4757a4e2ed5b..eb4d565ce356 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -85,6 +85,9 @@ def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None): self._metadata = metadata def __del__(self): + # torch.Tensor.pin_memory() is not an inplace operation. To make it + # truly in-place, we need to use cudaHostRegister. Then, we need to use + # cudaHostUnregister to unpin the tensor in the destructor. # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 if hasattr(self, "_is_inplace_pinned"): cudart = torch.cuda.cudart() @@ -177,6 +180,9 @@ def metadata(self): def pin_memory_(self): """In-place operation to copy the feature to pinned memory.""" self._is_inplace_pinned = set() + # torch.Tensor.pin_memory() is not an inplace operation. To make it + # truly in-place, we need to use cudaHostRegister. Then, we need to use + # cudaHostUnregister to unpin the tensor in the destructor. # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 x = self._tensor if not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous(): From 1c88ff738dbad7d8058f7c7e132909d03479e045 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Thu, 18 Jan 2024 00:20:30 +0000 Subject: [PATCH 10/12] refine the implementation --- python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 11 +++++------ .../dgl/graphbolt/impl/torch_based_feature_store.py | 10 +++++----- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index 7ac3b4d5989a..c2d9390c473d 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -34,16 +34,17 @@ def __init__( ): super().__init__() self._c_csc_graph = c_csc_graph + self._is_inplace_pinned = set() def __del__(self): # torch.Tensor.pin_memory() is not an inplace operation. To make it # truly in-place, we need to use cudaHostRegister. Then, we need to use # cudaHostUnregister to unpin the tensor in the destructor. # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 - if hasattr(self, "_is_inplace_pinned"): - cudart = torch.cuda.cudart() - for tensor in self._is_inplace_pinned: - assert cudart.cudaHostUnregister(tensor.data_ptr()) == 0 + for tensor in self._is_inplace_pinned: + assert ( + torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0 + ) @property def total_num_nodes(self) -> int: @@ -984,8 +985,6 @@ def _pin(x): def pin_memory_(self): """Copy `FusedCSCSamplingGraph` to the pinned memory in-place.""" - self._is_inplace_pinned = set() - # torch.Tensor.pin_memory() is not an inplace operation. To make it # truly in-place, we need to use cudaHostRegister. Then, we need to use # cudaHostUnregister to unpin the tensor in the destructor. 
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index eb4d565ce356..4df1be182882 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -83,16 +83,17 @@ def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None): # Make sure the tensor is contiguous. self._tensor = torch_feature.contiguous() self._metadata = metadata + self._is_inplace_pinned = set() def __del__(self): # torch.Tensor.pin_memory() is not an inplace operation. To make it # truly in-place, we need to use cudaHostRegister. Then, we need to use # cudaHostUnregister to unpin the tensor in the destructor. # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 - if hasattr(self, "_is_inplace_pinned"): - cudart = torch.cuda.cudart() - for tensor in self._is_inplace_pinned: - assert cudart.cudaHostUnregister(tensor.data_ptr()) == 0 + for tensor in self._is_inplace_pinned: + assert ( + torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0 + ) def read(self, ids: torch.Tensor = None): """Read the feature by index. @@ -179,7 +180,6 @@ def metadata(self): def pin_memory_(self): """In-place operation to copy the feature to pinned memory.""" - self._is_inplace_pinned = set() # torch.Tensor.pin_memory() is not an inplace operation. To make it # truly in-place, we need to use cudaHostRegister. Then, we need to use # cudaHostUnregister to unpin the tensor in the destructor. From cf6a022eb3b88b0c96ec8a463f70af4316038c37 Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Thu, 18 Jan 2024 02:27:08 +0000 Subject: [PATCH 11/12] add pointer check to ensure pinning is truly in-place. --- .../pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py | 4 ++++ .../pytorch/graphbolt/impl/test_torch_based_feature_store.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py index b2f240e6279b..cb4035c62ba7 100644 --- a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py +++ b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py @@ -1601,10 +1601,14 @@ def test_csc_sampling_graph_to_device(device): def test_csc_sampling_graph_to_pinned_memory(): # Construct FusedCSCSamplingGraph. graph = create_fused_csc_sampling_graph() + ptr = graph.csc_indptr.data_ptr() # Copy to pinned_memory in-place. graph.pin_memory_() + # Check if pinning is truly in-place. + assert graph.csc_indptr.data_ptr() == ptr + is_graph_on_device_type(graph, "cpu") is_graph_pinned(graph) diff --git a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py index be4b43b79461..ff7aa8f912e6 100644 --- a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py +++ b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py @@ -221,6 +221,9 @@ def test_torch_based_pinned_feature(dtype, idtype, shape): feature = gb.TorchBasedFeature(tensor) feature.pin_memory_() + # Check if pinning is truly in-place. + assert feature._tensor.data_ptr() == tensor.data_ptr() + # Test read entire pinned feature, the result should be on cuda. 
assert torch.equal(feature.read(), test_tensor_cuda) assert feature.read().is_cuda From 04372f7242a84c3ee4b03202f6393afead976fad Mon Sep 17 00:00:00 2001 From: Muhammed Fatih Balin <m.f.balin@gmail.com> Date: Thu, 18 Jan 2024 02:30:13 +0000 Subject: [PATCH 12/12] address reviews. --- python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 4 +++- python/dgl/graphbolt/impl/torch_based_feature_store.py | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py index c2d9390c473d..09df1e9e5799 100644 --- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py +++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py @@ -998,8 +998,10 @@ def _pin(x): isinstance(x, torch.Tensor) and not x.is_pinned() and x.device.type == "cpu" - and x.is_contiguous() ): + assert ( + x.is_contiguous() + ), "Tensor pinning is only supported for contiguous tensors." assert ( cudart.cudaHostRegister( x.data_ptr(), x.numel() * x.element_size(), 0 diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index 4df1be182882..af77912ec9d5 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -185,7 +185,10 @@ def pin_memory_(self): # cudaHostUnregister to unpin the tensor in the destructor. # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842 x = self._tensor - if not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous(): + if not x.is_pinned() and x.device.type == "cpu": + assert ( + x.is_contiguous() + ), "Tensor pinning is only supported for contiguous tensors." assert ( torch.cuda.cudart().cudaHostRegister( x.data_ptr(), x.numel() * x.element_size(), 0
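The pattern this series converges on is worth stating standalone: torch.Tensor.pin_memory() is not in-place (it copies the data into a fresh pinned allocation), so the patches instead register the tensor's existing CPU allocation with cudaHostRegister — leaving data_ptr() unchanged — and unregister it with cudaHostUnregister when the owning object is destroyed. Below is a minimal self-contained sketch of that pattern, assuming a CUDA-enabled PyTorch build; the InplacePinner class is a hypothetical name introduced here for illustration and is not part of GraphBolt.

import torch


class InplacePinner:
    """Pins CPU tensors in place via cudaHostRegister; unpins on destruction."""

    def __init__(self):
        # Initialized eagerly (mirroring patch 10) so __del__ is safe to run
        # even if pin_ was never called.
        self._is_inplace_pinned = set()

    def pin_(self, x):
        # pin_memory() would copy; cudaHostRegister instead page-locks the
        # pages backing the existing allocation, so x.data_ptr() is unchanged.
        if (
            isinstance(x, torch.Tensor)
            and not x.is_pinned()
            and x.device.type == "cpu"
        ):
            assert (
                x.is_contiguous()
            ), "Tensor pinning is only supported for contiguous tensors."
            # Flag 0 is cudaHostRegisterDefault; a zero return means success.
            assert (
                torch.cuda.cudart().cudaHostRegister(
                    x.data_ptr(), x.numel() * x.element_size(), 0
                )
                == 0
            )
            self._is_inplace_pinned.add(x)
        return x

    def __del__(self):
        # Unregister every tensor pinned above so the page-locked memory is
        # released, as patches 02 and 10 do in the graph and feature classes.
        for x in self._is_inplace_pinned:
            assert torch.cuda.cudart().cudaHostUnregister(x.data_ptr()) == 0


if torch.cuda.is_available():
    t = torch.arange(1024)
    ptr = t.data_ptr()
    pinner = InplacePinner()
    pinner.pin_(t)
    # Truly in-place: same storage, now page-locked, as the pointer checks
    # added in patch 11 verify.
    assert t.is_pinned() and t.data_ptr() == ptr

Keeping the registered tensors in a set mirrors patches 09-10: each pointer is unregistered exactly once, and moving the set's initialization into __init__ (rather than creating it lazily in pin_memory_) removes the hasattr guard the earlier revisions needed in __del__.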