diff --git a/examples/sampling/graphbolt/lightning/node_classification.py b/examples/sampling/graphbolt/lightning/node_classification.py index 69bc9954dfcf..fdfc633eda18 100644 --- a/examples/sampling/graphbolt/lightning/node_classification.py +++ b/examples/sampling/graphbolt/lightning/node_classification.py @@ -58,8 +58,12 @@ def __init__(self, in_feats, n_hidden, n_classes): self.dropout = nn.Dropout(0.5) self.n_hidden = n_hidden self.n_classes = n_classes - self.train_acc = Accuracy(task="multiclass", num_classes=n_classes) - self.val_acc = Accuracy(task="multiclass", num_classes=n_classes) + self.train_acc = Accuracy( + task="multiclass", num_classes=n_classes, top_k=1 + ) + self.val_acc = Accuracy( + task="multiclass", num_classes=n_classes, top_k=1 + ) def forward(self, blocks, x): h = x @@ -133,13 +137,14 @@ def configure_optimizers(self): class DataModule(LightningDataModule): - def __init__(self, dataset, fanouts, batch_size, num_workers): + def __init__(self, dataset, fanouts, batch_size, num_workers, device): super().__init__() self.fanouts = fanouts self.batch_size = batch_size self.num_workers = num_workers - self.feature_store = dataset.feature - self.graph = dataset.graph + self.feature_store = dataset.feature.to(device) + self.graph = dataset.graph.to(device) + self.device = "cuda" if device != "cpu" else "cpu" self.train_set = dataset.tasks[0].train_set self.valid_set = dataset.tasks[0].validation_set self.num_classes = dataset.tasks[0].metadata["num_classes"] @@ -148,9 +153,10 @@ def create_dataloader(self, node_set, is_train): datapipe = gb.ItemSampler( node_set, batch_size=self.batch_size, - shuffle=True, - drop_last=True, + shuffle=is_train, + drop_last=is_train, ) + datapipe = datapipe.copy_to(self.device, ["seed_nodes"]) sampler = ( datapipe.sample_layer_neighbor if is_train @@ -203,14 +209,25 @@ def val_dataloader(self): default=0, help="number of workers (default: 0)", ) + parser.add_argument( + "--storage_device", + default="pinned", + choices=["cpu", "pinned", "cuda"], + help="Moves the dataset into the selected storage", + ) args = parser.parse_args() + if not torch.cuda.is_available(): + args.num_gpus = 0 + args.storage_device = "cpu" + dataset = gb.BuiltinDataset("ogbn-products").load() datamodule = DataModule( dataset, [10, 10, 10], args.batch_size, args.num_workers, + args.storage_device, ) in_size = dataset.feature.size("node", None, "feat")[0] model = SAGE(in_size, 256, datamodule.num_classes) @@ -225,7 +242,7 @@ def val_dataloader(self): # https://lightning.ai/docs/pytorch/stable/common/trainer.html. ######################################################################## trainer = Trainer( - accelerator="gpu", + accelerator="gpu" if args.num_gpus > 0 else "cpu", devices=args.num_gpus, max_epochs=args.epochs, callbacks=[checkpoint_callback, early_stopping_callback], diff --git a/examples/sampling/pyg/node_classification.py b/examples/sampling/pyg/node_classification.py index a34fbf4abecc..411e1f93e971 100644 --- a/examples/sampling/pyg/node_classification.py +++ b/examples/sampling/pyg/node_classification.py @@ -92,9 +92,9 @@ def create_dataloader(dataset_set, graph, feature, device, is_train): # (HIGHLIGHT) Create a data loader for efficiently loading graph data. # # - 'ItemSampler' samples mini-batches of node IDs from the dataset. + # - 'CopyTo' copies the fetched data to the specified device. # - 'sample_neighbor' performs neighbor sampling on the graph. # - 'FeatureFetcher' fetches node features based on the sampled subgraph. - # - 'CopyTo' copies the fetched data to the specified device. ##################################################################### # Create a datapipe for mini-batch sampling with a specific neighbor fanout. @@ -108,12 +108,12 @@ def create_dataloader(dataset_set, graph, feature, device, is_train): datapipe = gb.ItemSampler( dataset_set, batch_size=1024, shuffle=is_train, drop_last=is_train ) + # Copy the data to the specified device. + datapipe = datapipe.copy_to(device=device, extra_attrs=["seed_nodes"]) # Sample neighbors for each node in the mini-batch. datapipe = datapipe.sample_neighbor(graph, [10, 10, 10]) # Fetch node features for the sampled subgraph. datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"]) - # Copy the data to the specified device. - datapipe = datapipe.copy_to(device=device) # Create and return a DataLoader to handle data loading. dataloader = gb.DataLoader(datapipe, num_workers=0) @@ -195,13 +195,13 @@ def main(): args = parser.parse_args() dataset_name = args.dataset dataset = gb.BuiltinDataset(dataset_name).load() - graph = dataset.graph - feature = dataset.feature + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + graph = dataset.graph.to(device) + feature = dataset.feature.to(device) train_set = dataset.tasks[0].train_set valid_set = dataset.tasks[0].validation_set test_set = dataset.tasks[0].test_set num_classes = dataset.tasks[0].metadata["num_classes"] - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") train_dataloader = create_dataloader( train_set, graph, feature, device, is_train=True