diff --git a/python/dgl/distributed/dist_dataloader.py b/python/dgl/distributed/dist_dataloader.py
index 117379b2b4a3..833fb937a43c 100644
--- a/python/dgl/distributed/dist_dataloader.py
+++ b/python/dgl/distributed/dist_dataloader.py
@@ -311,7 +311,7 @@ def collate(self, items):
         raise NotImplementedError
 
     @staticmethod
-    def add_edge_attribute_to_graph(g, data_name):
+    def add_edge_attribute_to_graph(g, data_name, gb_padding):
         """Add data into the graph as an edge attribute.
 
         For some cases such as prob/mask-based sampling on GraphBolt partitions,
@@ -327,9 +327,11 @@ def add_edge_attribute_to_graph(g, data_name):
             The graph.
         data_name : str
             The name of data that's stored in DistGraph.ndata/edata.
+        gb_padding : int
+            The padding value for the new edge attribute on GraphBolt partitions.
         """
         if g._use_graphbolt and data_name:
-            g.add_edge_attribute(data_name)
+            g.add_edge_attribute(data_name, gb_padding)
 
 
 class NodeCollator(Collator):
@@ -344,6 +346,11 @@ class NodeCollator(Collator):
         The node set to compute outputs.
     graph_sampler : dgl.dataloading.BlockSampler
         The neighborhood sampler.
+    gb_padding : int, optional
+        The padding value used when adding new edge attributes to GraphBolt partitions
+        if the corresponding attributes in DistGraph are None, e.g., for prob/mask-based sampling.
+        An edge is sampled by dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors only when its mask is set to 1.
+        The argument is passed to add_edge_attribute_to_graph to add the new edge attributes to the GraphBolt partition.
 
     Examples
     --------
@@ -366,7 +373,7 @@ class NodeCollator(Collator):
     :doc:`Minibatch Training Tutorials `.
     """
 
-    def __init__(self, g, nids, graph_sampler):
+    def __init__(self, g, nids, graph_sampler, gb_padding=1):
         self.g = g
         if not isinstance(nids, Mapping):
             assert (
@@ -380,7 +387,7 @@ def __init__(self, g, nids, graph_sampler):
         # Add prob/mask into graphbolt partition's edge attributes if needed.
         if hasattr(self.graph_sampler, "prob"):
             Collator.add_edge_attribute_to_graph(
-                self.g, self.graph_sampler.prob
+                self.g, self.graph_sampler.prob, gb_padding
             )
 
     @property
@@ -508,8 +515,13 @@ class EdgeCollator(Collator):
         A set of builtin negative samplers are provided in
         :ref:`the negative sampling module `.
+    gb_padding : int, optional
+        The padding value used when adding new edge attributes to GraphBolt partitions
+        if the corresponding attributes in DistGraph are None, e.g., for prob/mask-based sampling.
+        An edge is sampled by dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors only when its mask is set to 1.
+        The argument is passed to add_edge_attribute_to_graph to add the new edge attributes to the GraphBolt partition.
 
     Examples
     --------
     The following example shows how to train a 3-layer GNN for edge
     classification on a set of edges ``train_eid`` on a homogeneous undirected
     graph. Each node takes
@@ -612,6 +624,7 @@ def __init__(
         reverse_eids=None,
         reverse_etypes=None,
         negative_sampler=None,
+        gb_padding=1,
     ):
         self.g = g
         if not isinstance(eids, Mapping):
@@ -642,7 +655,7 @@ def __init__(
         # Add prob/mask into graphbolt partition's edge attributes if needed.
         if hasattr(self.graph_sampler, "prob"):
             Collator.add_edge_attribute_to_graph(
-                self.g, self.graph_sampler.prob
+                self.g, self.graph_sampler.prob, gb_padding
             )
 
     @property
diff --git a/python/dgl/distributed/dist_graph.py b/python/dgl/distributed/dist_graph.py
index d4e31ce02770..149d5d3fd878 100644
--- a/python/dgl/distributed/dist_graph.py
+++ b/python/dgl/distributed/dist_graph.py
@@ -143,15 +143,16 @@ def _copy_data_from_shared_mem(name, shape):
 class AddEdgeAttributeFromKVRequest(rpc.Request):
     """Add edge attribute from kvstore to local GraphBolt partition."""
 
-    def __init__(self, name, kv_names):
+    def __init__(self, name, kv_names, padding):
         self._name = name
         self._kv_names = kv_names
+        self._padding = padding
 
     def __getstate__(self):
-        return self._name, self._kv_names
+        return self._name, self._kv_names, self._padding
 
     def __setstate__(self, state):
-        self._name, self._kv_names = state
+        self._name, self._kv_names, self._padding = state
 
     def process_request(self, server_state):
         # For now, this is only used to add prob/mask data to the graph.
@@ -169,7 +170,13 @@ def process_request(self, server_state):
         gpb = server_state.partition_book
         # Initialize the edge attribute.
         num_edges = g.total_num_edges
-        attr_data = torch.zeros(num_edges, dtype=data_type)
+
+        # Padding is used to fill missing edge attributes (e.g., 'prob' or 'mask') for certain edge types.
+        # In DGLGraph, some edges may lack these attributes or have them set to None, but DGL will still sample these edges.
+        # In contrast, GraphBolt samples edges based on specific attributes (e.g., 'mask' == 1) and will skip edges with missing attributes.
+        # To ensure consistent sampling behavior in GraphBolt, we pad missing attributes with default values (e.g., 'mask' = 1),
+        # allowing all edges to be sampled, even if their attributes were missing or None in DGLGraph.
+        attr_data = torch.full((num_edges,), self._padding, dtype=data_type)
         # Map data from kvstore to the local partition for inner edges only.
         num_inner_edges = gpb.metadata()[gpb.partid]["num_edges"]
         homo_eids = g.edge_attributes[EID][:num_inner_edges]
@@ -1620,13 +1627,15 @@ def _get_edata_names(self, etype=None):
                 edata_names.append(name)
         return edata_names
 
-    def add_edge_attribute(self, name):
+    def add_edge_attribute(self, name, padding):
         """Add an edge attribute into GraphBolt partition from edge data.
 
         Parameters
         ----------
         name : str
             The name of the edge attribute.
+        padding : int
+            The padding value for the new edge attribute.
         """
         # Sanity checks.
         if not self._use_graphbolt:
@@ -1643,7 +1652,7 @@ def add_edge_attribute(self, name):
         ]
         rpc.send_request(
             self._client._main_server_id,
-            AddEdgeAttributeFromKVRequest(name, kv_names),
+            AddEdgeAttributeFromKVRequest(name, kv_names, padding),
         )
         # Wait for the response.
         assert rpc.recv_response()._name == name
diff --git a/tests/distributed/test_distributed_sampling.py b/tests/distributed/test_distributed_sampling.py
index 4ca8f7b130ac..556e091fbbe3 100644
--- a/tests/distributed/test_distributed_sampling.py
+++ b/tests/distributed/test_distributed_sampling.py
@@ -7,8 +7,9 @@
 import unittest
 from pathlib import Path
 
-import backend as F
 import dgl
+
+import dgl.backend as F
 import numpy as np
 import pytest
 import torch
@@ -1858,6 +1859,81 @@ def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):
     )
 
 
+def check_hetero_dist_edge_dataloader_gb(
+    tmpdir, num_server, use_graphbolt=True
+):
+    generate_ip_config("rpc_ip_config.txt", num_server, num_server)
+
+    g = create_random_hetero()
+    eids = torch.randperm(g.num_edges("r23"))[:10]
+    mask = torch.zeros(g.num_edges("r23"), dtype=torch.bool)
+    mask[eids] = True
+    # Only the "r23" edges carry a mask; the other edge types rely on the
+    # padded default so that all of their edges remain eligible for sampling.
+    g.edges["r23"].data["mask"] = mask
+
+    num_parts = num_server
+
+    orig_nid_map, orig_eid_map = partition_graph(
+        g,
+        "test_sampling",
+        num_parts,
+        tmpdir,
+        num_hops=1,
+        part_method="metis",
+        return_mapping=True,
+        use_graphbolt=use_graphbolt,
+        store_eids=True,
+    )
+
+    part_config = tmpdir / "test_sampling.json"
+
+    pserver_list = []
+    ctx = mp.get_context("spawn")
+    for i in range(num_server):
+        p = ctx.Process(
+            target=start_server,
+            args=(
+                i,
+                tmpdir,
+                num_server > 1,
+                "test_sampling",
+                ["csc", "coo"],
+                True,
+            ),
+        )
+        p.start()
+        time.sleep(1)
+        pserver_list.append(p)
+
+    dgl.distributed.initialize("rpc_ip_config.txt", use_graphbolt=True)
+    dist_graph = DistGraph("test_sampling", part_config=part_config)
+
+    os.environ["DGL_DIST_DEBUG"] = "1"
+
+    edges = {("n2", "r23", "n3"): eids}
+    sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask")
+    loader = dgl.dataloading.DistEdgeDataLoader(
+        dist_graph, edges, sampler, batch_size=64
+    )
+
+    # Sample one minibatch before shutting down the client and the servers.
+    block = next(iter(loader))[2][0]
+    assert block.num_src_nodes("n1") > 0
+    assert block.num_edges("r12") > 0
+    assert block.num_edges("r13") > 0
+    assert block.num_edges("r23") > 0
+
+    dgl.distributed.exit_client()
+    for p in pserver_list:
+        p.join()
+        assert p.exitcode == 0
+
+
+def test_hetero_dist_edge_dataloader_gb(
+    num_server=1,
+):
+    reset_envs()
+    os.environ["DGL_DIST_MODE"] = "distributed"
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        check_hetero_dist_edge_dataloader_gb(Path(tmpdirname), num_server)
+
+
 if __name__ == "__main__":
     import tempfile
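
Note (not part of the patch): the sketch below shows how the new gb_padding default of 1 is exercised from user code when only one edge type carries a "mask" attribute. It is a hypothetical example, not taken from the repository; it assumes an already running DistDGL setup with GraphBolt partitions (an "ip_config.txt", a partition config named "test_sampling.json", and a boolean "mask" stored only for the ("n2", "r23", "n3") edge type), mirroring the test added above.

import torch
import dgl

# Connect to the running servers; the partitions are assumed to be GraphBolt
# partitions created with use_graphbolt=True and store_eids=True.
dgl.distributed.initialize("ip_config.txt", use_graphbolt=True)
g = dgl.distributed.DistGraph("test_sampling", part_config="test_sampling.json")

# Only the "r23" edge type has a "mask" in the distributed edge data. When the
# collator calls add_edge_attribute_to_graph, edge types without the attribute
# are filled with gb_padding (default 1) on the GraphBolt partition, so their
# edges stay eligible for sampling, matching DGL's sampling behavior.
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask")
seed_edges = {("n2", "r23", "n3"): torch.arange(10)}
loader = dgl.dataloading.DistEdgeDataLoader(
    g, seed_edges, sampler, batch_size=64
)

input_nodes, pair_graph, blocks = next(iter(loader))

dgl.distributed.exit_client()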