From 9f15737e6a3996a2d6d8463a5cc0b9e45405ca03 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Tue, 16 Jan 2024 22:12:32 -0500
Subject: [PATCH 01/12] initial code: in-place graph pinning via cudaHostRegister

---
 .../sampling/graphbolt/node_classification.py |  4 ++-
 .../impl/fused_csc_sampling_graph.py          | 31 ++++++++++++++++++-
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/examples/sampling/graphbolt/node_classification.py b/examples/sampling/graphbolt/node_classification.py
index f589e667b455..d0854b844968 100644
--- a/examples/sampling/graphbolt/node_classification.py
+++ b/examples/sampling/graphbolt/node_classification.py
@@ -384,8 +384,10 @@ def main(args):
     dataset = gb.BuiltinDataset("ogbn-products").load()
 
     # Move the dataset to the selected storage.
-    graph = dataset.graph.to(args.storage_device)
+    # graph = dataset.graph.to(args.storage_device)
     features = dataset.feature.to(args.storage_device)
+    dataset.graph.pin_memory_()
+    graph = dataset.graph
 
     train_set = dataset.tasks[0].train_set
     valid_set = dataset.tasks[0].validation_set
diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index 8f026b3c5095..462ac8b7e557 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -975,8 +975,37 @@ def _pin(x):
     def pin_memory_(self):
         """Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
 
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        cudart = torch.cuda.cudart()
+
         def _pin(x):
-            return x.pin_memory() if hasattr(x, "pin_memory") else x
+            if hasattr(x, "pin_memory_"):
+                x.pin_memory_()
+            elif (
+                isinstance(x, torch.Tensor)
+                and not x.is_pinned()
+                and not x.is_cuda
+                and x.is_contiguous()
+            ):
+                # x.share_memory_()
+                ptr = x.data_ptr()
+                assert (
+                    cudart.cudaHostRegister(
+                        ptr, x.numel() * x.element_size(), 0
+                    )
+                    == 0
+                )
+                # assert x.is_shared()
+                assert x.is_pinned()
+                print(x.is_pinned(), x.is_shared())
+
+                def new_del(self):
+                    assert cudart.cudaHostUnregister(ptr) == 0
+                    torch.Tensor.__del__(self)
+
+                x.__del__ = new_del
+
+            return x
 
         self._apply_to_members(_pin)
 

From e33ca341993fa134c6f600b55319b9a30c66e945 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Wed, 17 Jan 2024 04:15:11 +0000
Subject: [PATCH 02/12] add implementation: track pinned tensors and unregister in __del__

---
 .../impl/fused_csc_sampling_graph.py          | 23 +++++++++----------
 .../impl/torch_based_feature_store.py         | 21 +++++++++++++++--
 2 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index 462ac8b7e557..a33bad484f4e 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -35,6 +35,13 @@ def __init__(
         super().__init__()
         self._c_csc_graph = c_csc_graph
 
+    def __del__(self):
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        if hasattr(self, "_is_inplace_pinned"):
+            cudart = torch.cuda.cudart()
+            for tensor in self._is_inplace_pinned:
+                assert cudart.cudaHostUnregister(tensor.data_ptr()) == 0
+
     @property
     def total_num_nodes(self) -> int:
         """Returns the number of nodes in the graph.
@@ -974,6 +981,7 @@ def _pin(x):
 
     def pin_memory_(self):
         """Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
+        self._is_inplace_pinned = set()
 
         # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
         cudart = torch.cuda.cudart()
@@ -984,26 +992,17 @@ def _pin(x):
             elif (
                 isinstance(x, torch.Tensor)
                 and not x.is_pinned()
-                and not x.is_cuda
+                and x.device.type == "cpu"
                 and x.is_contiguous()
             ):
-                # x.share_memory_()
-                ptr = x.data_ptr()
                 assert (
                     cudart.cudaHostRegister(
-                        ptr, x.numel() * x.element_size(), 0
+                        x.data_ptr(), x.numel() * x.element_size(), 0
                     )
                     == 0
                 )
-                # assert x.is_shared()
-                assert x.is_pinned()
-                print(x.is_pinned(), x.is_shared())
-
-                def new_del(self):
-                    assert cudart.cudaHostUnregister(ptr) == 0
-                    torch.Tensor.__del__(self)
 
-                x.__del__ = new_del
+                self._is_inplace_pinned.add(x)
 
             return x
 
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index 3952eb0a84b4..3649fc32811b 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -84,6 +84,13 @@ def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None):
         self._tensor = torch_feature.contiguous()
         self._metadata = metadata
 
+    def __del__(self):
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        if hasattr(self, "_is_inplace_pinned"):
+            cudart = torch.cuda.cudart()
+            for tensor in self._is_inplace_pinned:
+                assert cudart.cudaHostUnregister(tensor.data_ptr()) == 0
+
     def read(self, ids: torch.Tensor = None):
         """Read the feature by index.
 
@@ -169,14 +176,24 @@ def metadata(self):
 
     def pin_memory_(self):
         """In-place operation to copy the feature to pinned memory."""
-        self._tensor = self._tensor.pin_memory()
+        # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
+        x = self._tensor
+        if not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous():
+            assert (
+                torch.cuda.cudart().cudaHostRegister(
+                    x.data_ptr(), x.numel() * x.element_size(), 0
+                )
+                == 0
+            )
+
+            self._is_inplace_pinned.add(x)
 
     def to(self, device):  # pylint: disable=invalid-name
         """Copy `TorchBasedFeature` to the specified device."""
         # copy.copy is a shallow copy so it does not copy tensor memory.
         self2 = copy.copy(self)
         if device == "pinned":
-            self2.pin_memory_()
+            self2._tensor = self2.pin_memory()
         else:
             self2._tensor = self2._tensor.to(device)
         return self2
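
The register/track/unregister lifecycle this patch settles on can be
summarized with a hypothetical PinnedHolder class (a sketch, not GraphBolt
API). Keeping the registered tensors in a set both marks them for cleanup
and keeps their storage alive, so the pointers later passed to
cudaHostUnregister remain valid:

    import torch

    class PinnedHolder:  # hypothetical stand-in for the patched classes
        def __init__(self, tensor):
            self._tensor = tensor.contiguous()

        def pin_memory_(self):
            self._is_inplace_pinned = set()
            x = self._tensor
            if not x.is_pinned() and x.device.type == "cpu":
                assert (
                    torch.cuda.cudart().cudaHostRegister(
                        x.data_ptr(), x.numel() * x.element_size(), 0
                    )
                    == 0
                )
                self._is_inplace_pinned.add(x)

        def __del__(self):
            # hasattr guards instances that were never pinned.
            if hasattr(self, "_is_inplace_pinned"):
                cudart = torch.cuda.cudart()
                for tensor in self._is_inplace_pinned:
                    assert cudart.cudaHostUnregister(tensor.data_ptr()) == 0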

From a5b3e9ddfb6ff413eb88a16de9e802757ef08c88 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Wed, 17 Jan 2024 04:19:47 +0000
Subject: [PATCH 03/12] make pin_memory_ return self for convenience

---
 python/dgl/graphbolt/impl/fused_csc_sampling_graph.py  | 2 +-
 python/dgl/graphbolt/impl/torch_based_feature_store.py | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index a33bad484f4e..52db8d8ffc46 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -1006,7 +1006,7 @@ def _pin(x):
 
             return x
 
-        self._apply_to_members(_pin)
+        return self._apply_to_members(_pin)
 
 
 def fused_csc_sampling_graph(
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index 3649fc32811b..3e1fb2f4091b 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -187,6 +187,7 @@ def pin_memory_(self):
             )
 
             self._is_inplace_pinned.add(x)
+        return self
 
     def to(self, device):  # pylint: disable=invalid-name
         """Copy `TorchBasedFeature` to the specified device."""
@@ -282,6 +283,7 @@ def pin_memory_(self):
         """In-place operation to copy the feature store to pinned memory."""
         for feature in self._features.values():
             feature.pin_memory_()
+        return self
 
     def to(self, device):  # pylint: disable=invalid-name
         """Copy `TorchBasedFeatureStore` to the specified device."""

From 3fef383ce83f43f363aedaae362b9a05a3d99d05 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Wed, 17 Jan 2024 04:23:49 +0000
Subject: [PATCH 04/12] fix bug: initialize the pinned-tensor set and restore the example

---
 examples/sampling/graphbolt/node_classification.py     | 4 +---
 python/dgl/graphbolt/impl/torch_based_feature_store.py | 1 +
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/examples/sampling/graphbolt/node_classification.py b/examples/sampling/graphbolt/node_classification.py
index d0854b844968..f589e667b455 100644
--- a/examples/sampling/graphbolt/node_classification.py
+++ b/examples/sampling/graphbolt/node_classification.py
@@ -384,10 +384,8 @@ def main(args):
     dataset = gb.BuiltinDataset("ogbn-products").load()
 
     # Move the dataset to the selected storage.
-    # graph = dataset.graph.to(args.storage_device)
+    graph = dataset.graph.to(args.storage_device)
     features = dataset.feature.to(args.storage_device)
-    dataset.graph.pin_memory_()
-    graph = dataset.graph
 
     train_set = dataset.tasks[0].train_set
     valid_set = dataset.tasks[0].validation_set
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index 3e1fb2f4091b..e5fbc7b68dea 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -176,6 +176,7 @@ def metadata(self):
 
     def pin_memory_(self):
         """In-place operation to copy the feature to pinned memory."""
+        self._is_inplace_pinned = set()
         # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
         x = self._tensor
         if not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous():
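
The feature-store half of this fix matters because pin_memory_() as
introduced in PATCH 02 called self._is_inplace_pinned.add(x) without the
set ever being created. A stripped-down sketch of the failure mode, using
a hypothetical Feature class:

    class Feature:  # hypothetical, for illustration only
        def pin_memory_(self):
            self._is_inplace_pinned.add(object())  # set was never created

    try:
        Feature().pin_memory_()
    except AttributeError as error:
        print(error)  # 'Feature' object has no attribute '_is_inplace_pinned'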

From abf466c5c6b8fd75e98bdaa9e50931e450e0acfd Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Wed, 17 Jan 2024 04:25:24 +0000
Subject: [PATCH 05/12] remove unnecessary check

---
 python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index 52db8d8ffc46..4351c1a3321b 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -990,8 +990,7 @@ def _pin(x):
             if hasattr(x, "pin_memory_"):
                 x.pin_memory_()
             elif (
-                isinstance(x, torch.Tensor)
-                and not x.is_pinned()
+                not x.is_pinned()
                 and x.device.type == "cpu"
                 and x.is_contiguous()
             ):

From 93c8a1851379446ac1e2ab21e17831996dafda04 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Wed, 17 Jan 2024 04:32:39 +0000
Subject: [PATCH 06/12] fix the bug: restore the isinstance check for non-tensor members

---
 python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index 4351c1a3321b..52db8d8ffc46 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -990,7 +990,8 @@ def _pin(x):
             if hasattr(x, "pin_memory_"):
                 x.pin_memory_()
             elif (
-                not x.is_pinned()
+                isinstance(x, torch.Tensor)
+                and not x.is_pinned()
                 and x.device.type == "cpu"
                 and x.is_contiguous()
             ):
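
The guard is needed because _apply_to_members feeds every member of the
graph through _pin, not only tensors; optional members such as
node_type_offset can be None, and None has neither pin_memory_ nor
is_pinned. A sketch of what the restored check tolerates:

    import torch

    def _pin(x):  # simplified copy of the callback above
        if hasattr(x, "pin_memory_"):
            x.pin_memory_()
        elif (
            isinstance(x, torch.Tensor)  # skip non-tensor members entirely
            and not x.is_pinned()
            and x.device.type == "cpu"
            and x.is_contiguous()
        ):
            pass  # register with cudaHostRegister, as in the patch body
        return x

    # Without the isinstance guard, None.is_pinned() would raise
    # AttributeError; with it, non-tensor members pass through untouched.
    assert _pin(None) is None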

From cdfb60cdc4671a13e3759c6e43f10fe96eb1c735 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Wed, 17 Jan 2024 05:03:32 +0000
Subject: [PATCH 07/12] fix bug: call pin_memory() on the tensor, not the feature

---
 python/dgl/graphbolt/impl/torch_based_feature_store.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index e5fbc7b68dea..f0bdf45d4aab 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -195,7 +195,7 @@ def to(self, device):  # pylint: disable=invalid-name
         # copy.copy is a shallow copy so it does not copy tensor memory.
         self2 = copy.copy(self)
         if device == "pinned":
-            self2._tensor = self2.pin_memory()
+            self2._tensor = self2._tensor.pin_memory()
         else:
             self2._tensor = self2._tensor.to(device)
         return self2
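
The fixed line matters because TorchBasedFeature only defines the in-place
pin_memory_(), while torch.Tensor.pin_memory() is the copying variant that
to("pinned") actually wants: the shallow copy rebinds its own _tensor to a
fresh pinned buffer and the original feature stays on unpinned CPU memory.
A sketch, assuming a CUDA build of PyTorch and a hypothetical class F:

    import copy

    import torch

    class F:  # hypothetical stand-in for TorchBasedFeature
        def __init__(self, t):
            self._tensor = t

    f = F(torch.arange(8))
    f2 = copy.copy(f)  # shallow copy: f2._tensor is f._tensor
    f2._tensor = f2._tensor.pin_memory()  # rebind, don't mutate the original
    assert f2._tensor.is_pinned() and not f._tensor.is_pinned()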

From b7ca85203926fee0d1101b03ec490efe3570815c Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Wed, 17 Jan 2024 05:08:06 +0000
Subject: [PATCH 08/12] in-place operation should not return self

---
 python/dgl/graphbolt/impl/fused_csc_sampling_graph.py  | 2 +-
 python/dgl/graphbolt/impl/torch_based_feature_store.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index 52db8d8ffc46..a33bad484f4e 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -1006,7 +1006,7 @@ def _pin(x):
 
             return x
 
-        return self._apply_to_members(_pin)
+        self._apply_to_members(_pin)
 
 
 def fused_csc_sampling_graph(
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index f0bdf45d4aab..4757a4e2ed5b 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -188,7 +188,6 @@ def pin_memory_(self):
             )
 
             self._is_inplace_pinned.add(x)
-        return self
 
     def to(self, device):  # pylint: disable=invalid-name
         """Copy `TorchBasedFeature` to the specified device."""
@@ -284,7 +283,6 @@ def pin_memory_(self):
         """In-place operation to copy the feature store to pinned memory."""
         for feature in self._features.values():
             feature.pin_memory_()
-        return self
 
     def to(self, device):  # pylint: disable=invalid-name
         """Copy `TorchBasedFeatureStore` to the specified device."""

From a28a3c8dba134a908d06cfb3d3d74ad8419f5bcd Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Wed, 17 Jan 2024 05:21:22 +0000
Subject: [PATCH 09/12] add context to the comment

---
 python/dgl/graphbolt/impl/fused_csc_sampling_graph.py  | 6 ++++++
 python/dgl/graphbolt/impl/torch_based_feature_store.py | 6 ++++++
 2 files changed, 12 insertions(+)

diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index a33bad484f4e..7ac3b4d5989a 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -36,6 +36,9 @@ def __init__(
         self._c_csc_graph = c_csc_graph
 
     def __del__(self):
+        # torch.Tensor.pin_memory() is not an in-place operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
         # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
         if hasattr(self, "_is_inplace_pinned"):
             cudart = torch.cuda.cudart()
@@ -983,6 +986,9 @@ def pin_memory_(self):
         """Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
         self._is_inplace_pinned = set()
 
+        # torch.Tensor.pin_memory() is not an in-place operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
         # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
         cudart = torch.cuda.cudart()
 
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index 4757a4e2ed5b..eb4d565ce356 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -85,6 +85,9 @@ def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None):
         self._metadata = metadata
 
     def __del__(self):
+        # torch.Tensor.pin_memory() is not an in-place operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
         # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
         if hasattr(self, "_is_inplace_pinned"):
             cudart = torch.cuda.cudart()
@@ -177,6 +180,9 @@ def metadata(self):
     def pin_memory_(self):
         """In-place operation to copy the feature to pinned memory."""
         self._is_inplace_pinned = set()
+        # torch.Tensor.pin_memory() is not an in-place operation. To make it
+        # truly in-place, we need to use cudaHostRegister. Then, we need to use
+        # cudaHostUnregister to unpin the tensor in the destructor.
         # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
         x = self._tensor
         if not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous():

From 1c88ff738dbad7d8058f7c7e132909d03479e045 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Thu, 18 Jan 2024 00:20:30 +0000
Subject: [PATCH 10/12] refine the implementation

---
 python/dgl/graphbolt/impl/fused_csc_sampling_graph.py | 11 +++++------
 .../dgl/graphbolt/impl/torch_based_feature_store.py   | 10 +++++-----
 2 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index 7ac3b4d5989a..c2d9390c473d 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -34,16 +34,17 @@ def __init__(
     ):
         super().__init__()
         self._c_csc_graph = c_csc_graph
+        self._is_inplace_pinned = set()
 
     def __del__(self):
         # torch.Tensor.pin_memory() is not an in-place operation. To make it
         # truly in-place, we need to use cudaHostRegister. Then, we need to use
         # cudaHostUnregister to unpin the tensor in the destructor.
         # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
-        if hasattr(self, "_is_inplace_pinned"):
-            cudart = torch.cuda.cudart()
-            for tensor in self._is_inplace_pinned:
-                assert cudart.cudaHostUnregister(tensor.data_ptr()) == 0
+        for tensor in self._is_inplace_pinned:
+            assert (
+                torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0
+            )
 
     @property
     def total_num_nodes(self) -> int:
@@ -984,8 +985,6 @@ def _pin(x):
 
     def pin_memory_(self):
         """Copy `FusedCSCSamplingGraph` to the pinned memory in-place."""
-        self._is_inplace_pinned = set()
-
         # torch.Tensor.pin_memory() is not an in-place operation. To make it
         # truly in-place, we need to use cudaHostRegister. Then, we need to use
         # cudaHostUnregister to unpin the tensor in the destructor.
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index eb4d565ce356..4df1be182882 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -83,16 +83,17 @@ def __init__(self, torch_feature: torch.Tensor, metadata: Dict = None):
         # Make sure the tensor is contiguous.
         self._tensor = torch_feature.contiguous()
         self._metadata = metadata
+        self._is_inplace_pinned = set()
 
     def __del__(self):
         # torch.Tensor.pin_memory() is not an in-place operation. To make it
         # truly in-place, we need to use cudaHostRegister. Then, we need to use
         # cudaHostUnregister to unpin the tensor in the destructor.
         # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
-        if hasattr(self, "_is_inplace_pinned"):
-            cudart = torch.cuda.cudart()
-            for tensor in self._is_inplace_pinned:
-                assert cudart.cudaHostUnregister(tensor.data_ptr()) == 0
+        for tensor in self._is_inplace_pinned:
+            assert (
+                torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0
+            )
 
     def read(self, ids: torch.Tensor = None):
         """Read the feature by index.
@@ -179,7 +180,6 @@ def metadata(self):
 
     def pin_memory_(self):
         """In-place operation to copy the feature to pinned memory."""
-        self._is_inplace_pinned = set()
         # torch.Tensor.pin_memory() is not an in-place operation. To make it
         # truly in-place, we need to use cudaHostRegister. Then, we need to use
         # cudaHostUnregister to unpin the tensor in the destructor.
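
Creating the set in __init__ lets __del__ drop its hasattr guard: the
attribute now exists on every constructed instance, including ones that
are garbage-collected without pin_memory_() ever having been called. A
sketch of the hazard the old guard worked around, using a hypothetical
class:

    class Unsafe:  # hypothetical, for illustration only
        def pin_memory_(self):
            self._is_inplace_pinned = set()  # created lazily

        def __del__(self):
            for _ in self._is_inplace_pinned:  # AttributeError if never pinned
                pass

    u = Unsafe()
    del u  # CPython swallows the AttributeError but prints
           # "Exception ignored in: <function Unsafe.__del__ ...>" to stderr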

From cf6a022eb3b88b0c96ec8a463f70af4316038c37 Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Thu, 18 Jan 2024 02:27:08 +0000
Subject: [PATCH 11/12] add pointer check to ensure pinning is truly in-place

---
 .../pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py   | 4 ++++
 .../pytorch/graphbolt/impl/test_torch_based_feature_store.py  | 3 +++
 2 files changed, 7 insertions(+)

diff --git a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
index b2f240e6279b..cb4035c62ba7 100644
--- a/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+++ b/tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
@@ -1601,10 +1601,14 @@ def test_csc_sampling_graph_to_device(device):
 def test_csc_sampling_graph_to_pinned_memory():
     # Construct FusedCSCSamplingGraph.
     graph = create_fused_csc_sampling_graph()
+    ptr = graph.csc_indptr.data_ptr()
 
     # Copy to pinned_memory in-place.
     graph.pin_memory_()
 
+    # Check if pinning is truly in-place.
+    assert graph.csc_indptr.data_ptr() == ptr
+
     is_graph_on_device_type(graph, "cpu")
     is_graph_pinned(graph)
 
diff --git a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
index be4b43b79461..ff7aa8f912e6 100644
--- a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
+++ b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py
@@ -221,6 +221,9 @@ def test_torch_based_pinned_feature(dtype, idtype, shape):
     feature = gb.TorchBasedFeature(tensor)
     feature.pin_memory_()
 
+    # Check if pinning is truly in-place.
+    assert feature._tensor.data_ptr() == tensor.data_ptr()
+
     # Test read entire pinned feature, the result should be on cuda.
     assert torch.equal(feature.read(), test_tensor_cuda)
     assert feature.read().is_cuda
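
The data_ptr() comparison is what distinguishes true in-place pinning from
the copying torch.Tensor.pin_memory(); a sketch of the contrast the new
assertions rely on, assuming a CUDA build of PyTorch:

    import torch

    t = torch.arange(8)
    p = t.pin_memory()  # copying variant allocates a new pinned buffer
    assert p.data_ptr() != t.data_ptr()
    # pin_memory_() registers the existing buffer instead, so data_ptr()
    # must be unchanged afterwards, exactly as the tests above assert.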

From 04372f7242a84c3ee4b03202f6393afead976fad Mon Sep 17 00:00:00 2001
From: Muhammed Fatih Balin <m.f.balin@gmail.com>
Date: Thu, 18 Jan 2024 02:30:13 +0000
Subject: [PATCH 12/12] address reviews: assert contiguity instead of silently skipping

---
 python/dgl/graphbolt/impl/fused_csc_sampling_graph.py  | 4 +++-
 python/dgl/graphbolt/impl/torch_based_feature_store.py | 5 ++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
index c2d9390c473d..09df1e9e5799 100644
--- a/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+++ b/python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
@@ -998,8 +998,10 @@ def _pin(x):
                 isinstance(x, torch.Tensor)
                 and not x.is_pinned()
                 and x.device.type == "cpu"
-                and x.is_contiguous()
             ):
+                assert (
+                    x.is_contiguous()
+                ), "Tensor pinning is only supported for contiguous tensors."
                 assert (
                     cudart.cudaHostRegister(
                         x.data_ptr(), x.numel() * x.element_size(), 0
diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py
index 4df1be182882..af77912ec9d5 100644
--- a/python/dgl/graphbolt/impl/torch_based_feature_store.py
+++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py
@@ -185,7 +185,10 @@ def pin_memory_(self):
         # cudaHostUnregister to unpin the tensor in the destructor.
         # https://github.com/pytorch/pytorch/issues/32167#issuecomment-753551842
         x = self._tensor
-        if not x.is_pinned() and x.device.type == "cpu" and x.is_contiguous():
+        if not x.is_pinned() and x.device.type == "cpu":
+            assert (
+                x.is_contiguous()
+            ), "Tensor pinning is only supported for contiguous tensors."
             assert (
                 torch.cuda.cudart().cudaHostRegister(
                     x.data_ptr(), x.numel() * x.element_size(), 0