diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index 8f026b3c5095..09df1e9e5799 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -34,6 +34,23 @@ def __init__(
     ):
         super().__init__()
         self._c_csc_graph = c_csc_graph
+        self._is_inplace_pinned = set()
+
+    def __del__(self):
+        """Unregister every tensor that `pin_memory_` pinned in-place."""
+        # torch.Tensor.pin_memory() is not an inplace operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        # The attribute may be missing if __init__ raised before setting it,
+        # and the call is kept out of `assert` so that `python -O` cannot
+        # strip the unregistration.
+        for tensor in getattr(self, "_is_inplace_pinned", ()):
+            status = torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr())
+            if status != 0:
+                raise RuntimeError(
+                    f"cudaHostUnregister failed with error code {status}."
+                )
 
     @property
     def total_num_nodes(self) -> int:
@@ -974,9 +991,36 @@ def _pin(x):
 
     def pin_memory_(self):
         """Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
+        # torch.Tensor.pin_memory() is not an inplace operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        cudart = torch.cuda.cudart()
 
         def _pin(x):
-            return x.pin_memory() if hasattr(x, "pin_memory") else x
+            if hasattr(x, "pin_memory_"):
+                x.pin_memory_()
+            elif (
+                isinstance(x, torch.Tensor)
+                and not x.is_pinned()
+                and x.device.type == "cpu"
+            ):
+                assert (
+                    x.is_contiguous()
+                ), "Tensor pinning is only supported for contiguous tensors."
+                # Keep the call out of `assert`: under `python -O` asserts are
+                # stripped and the tensor would never actually be registered.
+                status = cudart.cudaHostRegister(
+                    x.data_ptr(), x.numel() * x.element_size(), 0
+                )
+                if status != 0:
+                    raise RuntimeError(
+                        f"cudaHostRegister failed with error code {status}."
+                    )
+
+                self._is_inplace_pinned.add(x)
+
+            return x
 
         self._apply_to_members(_pin)
 
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index 3952eb0a84b4..af77912ec9d5 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -83,6 +83,23 @@ def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None):
         # Make sure the tensor is contiguous.
         self._tensor = torch_feature.contiguous()
         self._metadata = metadata
+        self._is_inplace_pinned = set()
+
+    def __del__(self):
+        """Unregister every tensor that `pin_memory_` pinned in-place."""
+        # torch.Tensor.pin_memory() is not an inplace operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        # The attribute may be missing if __init__ raised before setting it,
+        # and the call is kept out of `assert` so that `python -O` cannot
+        # strip the unregistration.
+        for tensor in getattr(self, "_is_inplace_pinned", ()):
+            status = torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr())
+            if status != 0:
+                raise RuntimeError(
+                    f"cudaHostUnregister failed with error code {status}."
+                )
 
     def read(self, ids: torch.Tensor = None):
         """Read the feature by index.
@@ -169,14 +186,38 @@ def metadata(self):
 
     def pin_memory_(self):
         """In-place operation to copy the feature to pinned memory."""
-        self._tensor = self._tensor.pin_memory()
+        # torch.Tensor.pin_memory() is not an inplace operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        x = self._tensor
+        if not x.is_pinned() and x.device.type == "cpu":
+            assert (
+                x.is_contiguous()
+            ), "Tensor pinning is only supported for contiguous tensors."
+            # Keep the call out of `assert`: under `python -O` asserts are
+            # stripped and the tensor would never actually be registered.
+            status = torch.cuda.cudart().cudaHostRegister(
+                x.data_ptr(), x.numel() * x.element_size(), 0
+            )
+            if status != 0:
+                raise RuntimeError(
+                    f"cudaHostRegister failed with error code {status}."
+                )
+
+            self._is_inplace_pinned.add(x)
 
     def to(self, device):  # pylint: disable=invalid-name
         """Copy `TorchBasedFeature` to the specified device."""
         # copy.copy is a shallow copy so it does not copy tensor memory.
         self2 = copy.copy(self)
+        # The shallow copy shares the owner's set; give the copy its own empty
+        # one so that its destructor does not unregister memory it never
+        # registered (which would unpin the owner and fail on the second
+        # cudaHostUnregister).
+        self2._is_inplace_pinned = set()
         if device == "pinned":
-            self2.pin_memory_()
+            self2._tensor = self2._tensor.pin_memory()
         else:
             self2._tensor = self2._tensor.to(device)
         return self2
diff --git a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
index b2f240e6279b..cb4035c62ba7 100644
--- a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+++ b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
@@ -1601,10 +1601,14 @@ def test_csc_sampling_graph_to_device(device):
 def test_csc_sampling_graph_to_pinned_memory():
     # Construct FusedCSCSamplingGraph.
     graph = create_fused_csc_sampling_graph()
+    ptr = graph.csc_indptr.data_ptr()
 
     # Copy to pinned_memory in-place.
     graph.pin_memory_()
 
+    # Check if pinning is truly in-place.
+    assert graph.csc_indptr.data_ptr() == ptr
+
     is_graph_on_device_type(graph, "cpu")
     is_graph_pinned(graph)
 
diff --git a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
index be4b43b79461..ff7aa8f912e6 100644
--- a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
+++ b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
@@ -221,6 +221,9 @@ def test_torch_based_pinned_feature(dtype, idtype, shape):
     feature = gb.TorchBasedFeature(tensor)
     feature.pin_memory_()
 
+    # Check if pinning is truly in-place.
+    assert feature._tensor.data_ptr() == tensor.data_ptr()
+
     # Test read entire pinned feature, the result should be on cuda.
     assert torch.equal(feature.read(), test_tensor_cuda)
     assert feature.read().is_cuda