diff --git a/python/dgl/graphbolt/impl/neighbor_sampler.py b/python/dgl/graphbolt/impl/neighbor_sampler.py index ef10d49d7584..605da8ff5ce3 100644 --- a/python/dgl/graphbolt/impl/neighbor_sampler.py +++ b/python/dgl/graphbolt/impl/neighbor_sampler.py @@ -1,9 +1,12 @@ """Neighbor subgraph samplers for GraphBolt.""" +from functools import partial + import torch from torch.utils.data import functional_datapipe from ..internal import compact_csc_format, unique_and_compact_csc_formats +from ..minibatch_transformer import MiniBatchTransformer from ..subgraph_sampler import SubgraphSampler from .sampled_subgraph_impl import SampledSubgraphImpl @@ -12,8 +15,66 @@ __all__ = ["NeighborSampler", "LayerNeighborSampler"] +@functional_datapipe("sample_per_layer") +class SamplePerLayer(MiniBatchTransformer): + """Sample neighbor edges from a graph for a single layer.""" + + def __init__(self, datapipe, sampler, fanout, replace, prob_name): + super().__init__(datapipe, self._sample_per_layer) + self.sampler = sampler + self.fanout = fanout + self.replace = replace + self.prob_name = prob_name + + def _sample_per_layer(self, minibatch): + subgraph = self.sampler( + minibatch._seed_nodes, self.fanout, self.replace, self.prob_name + ) + minibatch.sampled_subgraphs.insert(0, subgraph) + return minibatch + + +@functional_datapipe("compact_per_layer") +class CompactPerLayer(MiniBatchTransformer): + """Compact the sampled edges for a single layer.""" + + def __init__(self, datapipe, deduplicate): + super().__init__(datapipe, self._compact_per_layer) + self.deduplicate = deduplicate + + def _compact_per_layer(self, minibatch): + subgraph = minibatch.sampled_subgraphs[0] + seeds = minibatch._seed_nodes + if self.deduplicate: + ( + original_row_node_ids, + compacted_csc_format, + ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds) + subgraph = SampledSubgraphImpl( + sampled_csc=compacted_csc_format, + original_column_node_ids=seeds, + original_row_node_ids=original_row_node_ids, + original_edge_ids=subgraph.original_edge_ids, + ) + else: + ( + original_row_node_ids, + compacted_csc_format, + ) = compact_csc_format(subgraph.sampled_csc, seeds) + subgraph = SampledSubgraphImpl( + sampled_csc=compacted_csc_format, + original_column_node_ids=seeds, + original_row_node_ids=original_row_node_ids, + original_edge_ids=subgraph.original_edge_ids, + ) + minibatch._seed_nodes = original_row_node_ids + minibatch.sampled_subgraphs[0] = subgraph + return minibatch + + @functional_datapipe("sample_neighbor") class NeighborSampler(SubgraphSampler): + # pylint: disable=abstract-method """Sample neighbor edges from a graph and return a subgraph. Functional name: :obj:`sample_neighbor`. @@ -95,6 +156,7 @@ class NeighborSampler(SubgraphSampler): )] """ + # pylint: disable=useless-super-delegation def __init__( self, datapipe, @@ -103,26 +165,19 @@ def __init__( replace=False, prob_name=None, deduplicate=True, + sampler=None, ): - super().__init__(datapipe) - self.graph = graph - # Convert fanouts to a list of tensors. - self.fanouts = [] - for fanout in fanouts: - if not isinstance(fanout, torch.Tensor): - fanout = torch.LongTensor([int(fanout)]) - self.fanouts.insert(0, fanout) - self.replace = replace - self.prob_name = prob_name - self.deduplicate = deduplicate - self.sampler = graph.sample_neighbors + if sampler is None: + sampler = graph.sample_neighbors + super().__init__( + datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler + ) - def sample_subgraphs(self, seeds, seeds_timestamp): - subgraphs = [] - num_layers = len(self.fanouts) + def _prepare(self, node_type_to_id, minibatch): + seeds = minibatch._seed_nodes # Enrich seeds with all node types. if isinstance(seeds, dict): - ntypes = list(self.graph.node_type_to_id.keys()) + ntypes = list(node_type_to_id.keys()) # Loop over different seeds to extract the device they are on. device = None dtype = None @@ -134,42 +189,37 @@ def sample_subgraphs(self, seeds, seeds_timestamp): seeds = { ntype: seeds.get(ntype, default_tensor) for ntype in ntypes } - for hop in range(num_layers): - subgraph = self.sampler( - seeds, - self.fanouts[hop], - self.replace, - self.prob_name, + minibatch._seed_nodes = seeds + minibatch.sampled_subgraphs = [] + return minibatch + + @staticmethod + def _set_input_nodes(minibatch): + minibatch.input_nodes = minibatch._seed_nodes + return minibatch + + # pylint: disable=arguments-differ + def sampling_stages( + self, datapipe, graph, fanouts, replace, prob_name, deduplicate, sampler + ): + datapipe = datapipe.transform( + partial(self._prepare, graph.node_type_to_id) + ) + for fanout in reversed(fanouts): + # Convert fanout to tensor. + if not isinstance(fanout, torch.Tensor): + fanout = torch.LongTensor([int(fanout)]) + datapipe = datapipe.sample_per_layer( + sampler, fanout, replace, prob_name ) - if self.deduplicate: - ( - original_row_node_ids, - compacted_csc_format, - ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds) - subgraph = SampledSubgraphImpl( - sampled_csc=compacted_csc_format, - original_column_node_ids=seeds, - original_row_node_ids=original_row_node_ids, - original_edge_ids=subgraph.original_edge_ids, - ) - else: - ( - original_row_node_ids, - compacted_csc_format, - ) = compact_csc_format(subgraph.sampled_csc, seeds) - subgraph = SampledSubgraphImpl( - sampled_csc=compacted_csc_format, - original_column_node_ids=seeds, - original_row_node_ids=original_row_node_ids, - original_edge_ids=subgraph.original_edge_ids, - ) - subgraphs.insert(0, subgraph) - seeds = original_row_node_ids - return seeds, subgraphs + datapipe = datapipe.compact_per_layer(deduplicate) + + return datapipe.transform(self._set_input_nodes) @functional_datapipe("sample_layer_neighbor") class LayerNeighborSampler(NeighborSampler): + # pylint: disable=abstract-method """Sample layer neighbor edges from a graph and return a subgraph. Functional name: :obj:`sample_layer_neighbor`. @@ -280,5 +330,5 @@ def __init__( replace, prob_name, deduplicate, + graph.sample_layer_neighbors, ) - self.sampler = graph.sample_layer_neighbors diff --git a/python/dgl/graphbolt/subgraph_sampler.py b/python/dgl/graphbolt/subgraph_sampler.py index 3e3c3d9b507c..b05b8ca30619 100644 --- a/python/dgl/graphbolt/subgraph_sampler.py +++ b/python/dgl/graphbolt/subgraph_sampler.py @@ -22,21 +22,44 @@ class SubgraphSampler(MiniBatchTransformer): Functional name: :obj:`sample_subgraph`. This class is the base class of all subgraph samplers. Any subclass of - SubgraphSampler should implement the :meth:`sample_subgraphs` method. + SubgraphSampler should implement either the :meth:`sample_subgraphs` method + or the :meth:`sampling_stages` method to define the fine-grained sampling + stages to take advantage of optimizations provided by the GraphBolt + DataLoader. Parameters ---------- datapipe : DataPipe The datapipe. + args : Non-Keyword Arguments + Arguments to be passed into sampling_stages. + kwargs : Keyword Arguments + Arguments to be passed into sampling_stages. """ def __init__( self, datapipe, + *args, + **kwargs, ): - super().__init__(datapipe, self._sample) + datapipe = datapipe.transform(self._preprocess) + datapipe = self.sampling_stages(datapipe, *args, **kwargs) + datapipe = datapipe.transform(self._postprocess) + super().__init__(datapipe, self._identity) - def _sample(self, minibatch): + @staticmethod + def _identity(minibatch): + return minibatch + + @staticmethod + def _postprocess(minibatch): + delattr(minibatch, "_seed_nodes") + delattr(minibatch, "_seeds_timestamp") + return minibatch + + @staticmethod + def _preprocess(minibatch): if minibatch.node_pairs is not None: ( seeds, @@ -44,7 +67,7 @@ def _sample(self, minibatch): minibatch.compacted_node_pairs, minibatch.compacted_negative_srcs, minibatch.compacted_negative_dsts, - ) = self._node_pairs_preprocess(minibatch) + ) = SubgraphSampler._node_pairs_preprocess(minibatch) elif minibatch.seed_nodes is not None: seeds = minibatch.seed_nodes seeds_timestamp = ( @@ -55,13 +78,12 @@ def _sample(self, minibatch): f"Invalid minibatch {minibatch}: Either `node_pairs` or " "`seed_nodes` should have a value." ) - ( - minibatch.input_nodes, - minibatch.sampled_subgraphs, - ) = self.sample_subgraphs(seeds, seeds_timestamp) + minibatch._seed_nodes = seeds + minibatch._seeds_timestamp = seeds_timestamp return minibatch - def _node_pairs_preprocess(self, minibatch): + @staticmethod + def _node_pairs_preprocess(minibatch): use_timestamp = hasattr(minibatch, "timestamp") node_pairs = minibatch.node_pairs neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts @@ -191,6 +213,23 @@ def _node_pairs_preprocess(self, minibatch): compacted_negative_dsts if has_neg_dst else None, ) + def _sample(self, minibatch): + ( + minibatch.input_nodes, + minibatch.sampled_subgraphs, + ) = self.sample_subgraphs( + minibatch._seed_nodes, minibatch._seeds_timestamp + ) + return minibatch + + def sampling_stages(self, datapipe): + """The sampling stages are defined here by chaining to the datapipe. The + default implementation expects :meth:`sample_subgraphs` to be + implemented. To define fine-grained stages, this method should be + overridden. + """ + return datapipe.transform(self._sample) + def sample_subgraphs(self, seeds, seeds_timestamp): """Sample subgraphs from the given seeds, possibly with temporal constraints.