[GraphBolt] Refactor NeighborSampler and expose fine-grained datapipes. #6983

Merged
Merged 40 commits into master from gb_refactor_neighbor_sampler on Feb 1, 2024
Changes from 1 commit
Commits
40 commits
b4c045f
prototyping
mfbalin Jan 19, 2024
a26e354
fix bug
mfbalin Jan 20, 2024
58a1190
remove print expressions, works now
mfbalin Jan 20, 2024
eedb8d1
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 23, 2024
3d2ed98
add tests
mfbalin Jan 24, 2024
105924a
use seeds_timestamp in preprocess
mfbalin Jan 24, 2024
bfb28ec
add docstring for linting
mfbalin Jan 24, 2024
e4becc9
fix linting
mfbalin Jan 24, 2024
428ff24
fix argument bug
mfbalin Jan 24, 2024
85b0601
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 24, 2024
866316e
fix the bug
mfbalin Jan 24, 2024
e2793fd
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 24, 2024
2473722
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 27, 2024
2d1dda9
address reviews
mfbalin Jan 29, 2024
fad7c50
add docstring to the new `MinibatchTransformer`.
mfbalin Jan 29, 2024
a8fdfc6
address review properly.
mfbalin Jan 29, 2024
933246f
remove unused `Mapper` import for linting.
mfbalin Jan 29, 2024
cd68728
NeighborSampler2 now derives from `MinibatchTransformer`.
mfbalin Jan 30, 2024
c3a903d
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 30, 2024
dcbfb4e
Final refactoring of NeighborSampler.
mfbalin Jan 30, 2024
21fe633
Fix not only preprocess but also postprocess issue.
mfbalin Jan 30, 2024
29861f1
take back test changes.
mfbalin Jan 30, 2024
232f2f3
fix in_subgraph_sampler
mfbalin Jan 30, 2024
03bea25
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 30, 2024
86d9c43
add docstring for `append_sampling_step`.
mfbalin Jan 30, 2024
f995d20
Address reviews, minimize changes, keep API exactly the same.
mfbalin Jan 31, 2024
a64d34e
remove leftover changes.
mfbalin Jan 31, 2024
e46b8c7
minor change.
mfbalin Jan 31, 2024
8cc858c
Make the function into a proper one so that it can be pickled.
mfbalin Jan 31, 2024
02ca357
make the lambda into a proper function so that it can be pickled.
mfbalin Jan 31, 2024
19b4367
linting.
mfbalin Jan 31, 2024
67d6f71
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 31, 2024
144134c
final linting.
mfbalin Jan 31, 2024
96bac52
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Jan 31, 2024
718cab8
Cleanup NeighborSampler as it does not need to store anything itself.
mfbalin Jan 31, 2024
ee3a7d7
linting
mfbalin Jan 31, 2024
1d906e7
address reviews by not passing sampler as string argument.
mfbalin Feb 1, 2024
6ab2e75
Merge branch 'master' into gb_refactor_neighbor_sampler
mfbalin Feb 1, 2024
6f880c0
Talk about `sampling_stages` in the SubgraphSampler API.
mfbalin Feb 1, 2024
5d907ee
add more documentation for `sampling_stages`.
mfbalin Feb 1, 2024
prototyping
mfbalin committed Jan 19, 2024
commit b4c045f5683af428cafa67ef530ac4bc635be35e
2 changes: 1 addition & 1 deletion examples/sampling/graphbolt/node_classification.py
@@ -117,7 +117,7 @@ def create_dataloader(
     # [Role]:
     # Initialize a neighbor sampler for sampling the neighborhoods of nodes.
     ############################################################################
-    datapipe = datapipe.sample_neighbor(
+    datapipe = datapipe.sample_neighbor2(
         graph, fanout if job != "infer" else [-1]
     )

182 changes: 181 additions & 1 deletion python/dgl/graphbolt/impl/neighbor_sampler.py
@@ -1,16 +1,196 @@
"""Neighbor subgraph samplers for GraphBolt."""

import torch
from torch.utils.data import functional_datapipe

from ..internal import compact_csc_format, unique_and_compact_csc_formats

from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import SampledSubgraphImpl
from torchdata.datapipes.iter import IterDataPipe, Mapper


__all__ = ["NeighborSampler", "LayerNeighborSampler"]
__all__ = ["NeighborSampler", "LayerNeighborSampler", "NeighborSampler2"]

@functional_datapipe("sample_per_layer")
class SamplePerLayer(Mapper):

def __init__(self, datapipe, sampler, fanout, replace, prob_name):
super().__init__(datapipe, self._sample_per_layer)
self.sampler = sampler
self.fanout = fanout
self.replace = replace
self.prob_name = prob_name

def _sample_per_layer(self, minibatch):
print("sample_per_layer", minibatch)
return self.sampler(minibatch.input_nodes, self.fanout, self.replace, self.prob_name), minibatch

@functional_datapipe("compact_per_layer")
class CompactPerLayer(Mapper):

def __init__(self, datapipe, deduplicate):
super().__init__(datapipe, self._compact_per_layer)
self.deduplicate = deduplicate

def _compact_per_layer(self, subgraph_minibatch):
subgraph, minibatch = subgraph_minibatch
seeds = minibatch.input_nodes
if self.deduplicate:
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
minibatch.input_nodes = original_row_node_ids
minibatch.sampled_subgraphs.insert(0, subgraph)
print("compact_per_layer", minibatch)
return minibatch
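
# Illustrative sketch (not from the diff): composing the two fine-grained
# datapipes above by hand for a single layer. It assumes `graph` is a
# FusedCSCSamplingGraph and the upstream `datapipe` yields minibatches whose
# `input_nodes` and `sampled_subgraphs` fields are already initialized:
#
#     fanout = torch.LongTensor([10])
#     datapipe = datapipe.sample_per_layer(
#         graph.sample_neighbors, fanout, False, None
#     )
#     datapipe = datapipe.compact_per_layer(True)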


@functional_datapipe("sample_neighbor2")
class NeighborSampler2(IterDataPipe):
"""Sample neighbor edges from a graph and return a subgraph.

Functional name: :obj:`sample_neighbor`.

Neighbor sampler is responsible for sampling a subgraph from given data. It
returns an induced subgraph along with compacted information. In the
context of a node classification task, the neighbor sampler directly
utilizes the nodes provided as seed nodes. However, in scenarios involving
link prediction, the process needs another pre-peocess operation. That is,
gathering unique nodes from the given node pairs, encompassing both
positive and negative node pairs, and employs these nodes as the seed nodes
for subsequent steps.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    graph : FusedCSCSamplingGraph
        The graph on which to perform subgraph sampling.
    fanouts: list[torch.Tensor] or list[int]
        The number of edges to be sampled for each node with or without
        considering edge types. The length of this parameter implicitly
        signifies the number of layers of sampling being conducted.
        Note: The fanout order is from the outermost layer to the innermost
        layer. For example, the fanout '[15, 10, 5]' means 15 for the
        outermost layer, 10 for the intermediate layer and 5 for the
        innermost layer.
    replace: bool
        Boolean indicating whether the sampling is performed with or
        without replacement. If True, a value can be selected multiple
        times. Otherwise, each value can be selected only once.
    prob_name: str, optional
        The name of an edge attribute used as the weights of sampling for
        each node. This attribute tensor should contain (unnormalized)
        probabilities corresponding to each neighboring edge of a node.
        It must be a 1D floating-point or boolean tensor, with the number
        of elements equalling the total number of edges.
    deduplicate: bool
        Boolean indicating whether seeds between hops will be deduplicated.
        If True, duplicate elements among the seeds are reduced to a single
        occurrence. Otherwise, duplicates are kept.

    Examples
    --------
    >>> import torch
    >>> import dgl.graphbolt as gb
    >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7, 8])
    >>> indices = torch.LongTensor([1, 2, 0, 3, 5, 4, 3, 5])
    >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
    >>> node_pairs = torch.LongTensor([[0, 1], [1, 2]])
    >>> item_set = gb.ItemSet(node_pairs, names="node_pairs")
    >>> datapipe = gb.ItemSampler(item_set, batch_size=1)
    >>> datapipe = datapipe.sample_uniform_negative(graph, 2)
    >>> datapipe = datapipe.sample_neighbor2(graph, [5, 10, 15])
    >>> next(iter(datapipe)).sampled_subgraphs
    [SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
        indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    ),
    SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6, 7, 8]),
        indices=tensor([1, 4, 0, 5, 5, 3, 3, 2]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    ),
    SampledSubgraphImpl(sampled_csc=CSCFormatBase(
        indptr=tensor([0, 2, 4, 5, 6]),
        indices=tensor([1, 4, 0, 5, 5, 3]),
    ),
    original_row_node_ids=tensor([0, 1, 4, 5, 2, 3]),
    original_edge_ids=None,
    original_column_node_ids=tensor([0, 1, 4, 5]),
    )]
    """

    def __init__(
        self,
        datapipe,
        graph,
        fanouts,
        replace=False,
        prob_name=None,
        deduplicate=True,
        sampler="sample_neighbors",
    ):
        self.graph = graph
        datapipe = datapipe.sample_subgraph_preprocess()

        def helper(minibatch):
            seeds = minibatch.input_nodes
            # Enrich seeds with all node types.
            if isinstance(seeds, dict):
                ntypes = list(self.graph.node_type_to_id.keys())
                # Loop over different seeds to extract the device they are on.
                device = None
                dtype = None
                for _, seed in seeds.items():
                    device = seed.device
                    dtype = seed.dtype
                    break
                default_tensor = torch.tensor([], dtype=dtype, device=device)
                seeds = {
                    ntype: seeds.get(ntype, default_tensor) for ntype in ntypes
                }
            minibatch.input_nodes = seeds
            minibatch.sampled_subgraphs = []
            print("helper_end", minibatch)
            return minibatch

        datapipe = datapipe.map(helper)
        sampler = getattr(graph, sampler)
        for fanout in reversed(fanouts):
            # Convert fanout to tensor.
            if not isinstance(fanout, torch.Tensor):
                fanout = torch.LongTensor([int(fanout)])
            datapipe = datapipe.sample_per_layer(
                sampler, fanout, replace, prob_name
            )
            datapipe = datapipe.compact_per_layer(deduplicate)
        self.datapipe = datapipe

    def __iter__(self):
        yield from self.datapipe
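
# Illustrative sketch: with fanouts [15, 10, 5] (outermost to innermost), the
# constructor above composes the pipeline below. `reversed(fanouts)` means the
# layer closest to the seeds (fanout 5) is sampled first, and each compacted
# subgraph is inserted at the front of `sampled_subgraphs`, so the final list
# matches the fanout order. The `helper` name follows the code above:
#
#     datapipe.sample_subgraph_preprocess()
#         .map(helper)
#         .sample_per_layer(graph.sample_neighbors, torch.LongTensor([5]), ...)
#         .compact_per_layer(deduplicate)
#         .sample_per_layer(graph.sample_neighbors, torch.LongTensor([10]), ...)
#         .compact_per_layer(deduplicate)
#         .sample_per_layer(graph.sample_neighbors, torch.LongTensor([15]), ...)
#         .compact_per_layer(deduplicate)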

@functional_datapipe("sample_neighbor")
class NeighborSampler(SubgraphSampler):
177 changes: 176 additions & 1 deletion python/dgl/graphbolt/subgraph_sampler.py
@@ -4,15 +4,190 @@
from typing import Dict

from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper

from .base import etype_str_to_tuple
from .internal import compact_temporal_nodes, unique_and_compact
from .minibatch_transformer import MiniBatchTransformer

__all__ = [
-    "SubgraphSampler",
+    "SubgraphSampler", "SubgraphSamplerPreprocess",
]

@functional_datapipe("sample_subgraph_preprocess")
class SubgraphSamplerPreprocess(Mapper):
"""A subgraph sampler used to sample a subgraph from a given set of nodes
from a larger graph.

Functional name: :obj:`sample_subgraph`.

This class is the base class of all subgraph samplers. Any subclass of
SubgraphSampler should implement the :meth:`sample_subgraphs` method.

Parameters
----------
datapipe : DataPipe
The datapipe.
"""

    def __init__(
        self,
        datapipe,
    ):
        super().__init__(datapipe, self._preprocess)

    def _preprocess(self, minibatch):
        if minibatch.node_pairs is not None:
            (
                seeds,
                seeds_timestamp,
                minibatch.compacted_node_pairs,
                minibatch.compacted_negative_srcs,
                minibatch.compacted_negative_dsts,
            ) = self._node_pairs_preprocess(minibatch)
        elif minibatch.seed_nodes is not None:
            seeds = minibatch.seed_nodes
            seeds_timestamp = (
                minibatch.timestamp if hasattr(minibatch, "timestamp") else None
            )
        else:
            raise ValueError(
                f"Invalid minibatch {minibatch}: Either `node_pairs` or "
                "`seed_nodes` should have a value."
            )
        minibatch.input_nodes = seeds
        return minibatch

    def _node_pairs_preprocess(self, minibatch):
        use_timestamp = hasattr(minibatch, "timestamp")
        node_pairs = minibatch.node_pairs
        neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
        has_neg_src = neg_src is not None
        has_neg_dst = neg_dst is not None
        is_heterogeneous = isinstance(node_pairs, Dict)
        if is_heterogeneous:
            has_neg_src = has_neg_src and all(
                item is not None for item in neg_src.values()
            )
            has_neg_dst = has_neg_dst and all(
                item is not None for item in neg_dst.values()
            )
            # Collect nodes from all types of input.
            nodes = defaultdict(list)
            nodes_timestamp = None
            if use_timestamp:
                nodes_timestamp = defaultdict(list)
            for etype, (src, dst) in node_pairs.items():
                src_type, _, dst_type = etype_str_to_tuple(etype)
                nodes[src_type].append(src)
                nodes[dst_type].append(dst)
                if use_timestamp:
                    nodes_timestamp[src_type].append(minibatch.timestamp[etype])
                    nodes_timestamp[dst_type].append(minibatch.timestamp[etype])
            if has_neg_src:
                for etype, src in neg_src.items():
                    src_type, _, _ = etype_str_to_tuple(etype)
                    nodes[src_type].append(src.view(-1))
                    if use_timestamp:
                        nodes_timestamp[src_type].append(
                            minibatch.timestamp[etype].repeat_interleave(
                                src.shape[-1]
                            )
                        )
            if has_neg_dst:
                for etype, dst in neg_dst.items():
                    _, _, dst_type = etype_str_to_tuple(etype)
                    nodes[dst_type].append(dst.view(-1))
                    if use_timestamp:
                        nodes_timestamp[dst_type].append(
                            minibatch.timestamp[etype].repeat_interleave(
                                dst.shape[-1]
                            )
                        )
            # Unique and compact the collected nodes.
            if use_timestamp:
                seeds, nodes_timestamp, compacted = compact_temporal_nodes(
                    nodes, nodes_timestamp
                )
            else:
                seeds, compacted = unique_and_compact(nodes)
                nodes_timestamp = None
            (
                compacted_node_pairs,
                compacted_negative_srcs,
                compacted_negative_dsts,
            ) = ({}, {}, {})
            # Map back in the same order as collected.
            for etype, _ in node_pairs.items():
                src_type, _, dst_type = etype_str_to_tuple(etype)
                src = compacted[src_type].pop(0)
                dst = compacted[dst_type].pop(0)
                compacted_node_pairs[etype] = (src, dst)
            if has_neg_src:
                for etype, _ in neg_src.items():
                    src_type, _, _ = etype_str_to_tuple(etype)
                    compacted_negative_srcs[etype] = compacted[src_type].pop(0)
                    compacted_negative_srcs[etype] = compacted_negative_srcs[
                        etype
                    ].view(neg_src[etype].shape)
            if has_neg_dst:
                for etype, _ in neg_dst.items():
                    _, _, dst_type = etype_str_to_tuple(etype)
                    compacted_negative_dsts[etype] = compacted[dst_type].pop(0)
                    compacted_negative_dsts[etype] = compacted_negative_dsts[
                        etype
                    ].view(neg_dst[etype].shape)
        else:
            # Collect nodes from all types of input.
            nodes = list(node_pairs)
            nodes_timestamp = None
            if use_timestamp:
                # Timestamps for source and destination nodes are the same.
                nodes_timestamp = [minibatch.timestamp, minibatch.timestamp]
            if has_neg_src:
                nodes.append(neg_src.view(-1))
                if use_timestamp:
                    nodes_timestamp.append(
                        minibatch.timestamp.repeat_interleave(neg_src.shape[-1])
                    )
            if has_neg_dst:
                nodes.append(neg_dst.view(-1))
                if use_timestamp:
                    nodes_timestamp.append(
                        minibatch.timestamp.repeat_interleave(neg_dst.shape[-1])
                    )
            # Unique and compact the collected nodes.
            if use_timestamp:
                seeds, nodes_timestamp, compacted = compact_temporal_nodes(
                    nodes, nodes_timestamp
                )
            else:
                seeds, compacted = unique_and_compact(nodes)
                nodes_timestamp = None
            # Map back in the same order as collected.
            compacted_node_pairs = tuple(compacted[:2])
            compacted = compacted[2:]
            if has_neg_src:
                compacted_negative_srcs = compacted.pop(0)
                # Since we need to calculate the neg_ratio according to the
                # compacted_negative_srcs shape, we need to reshape it back.
                compacted_negative_srcs = compacted_negative_srcs.view(
                    neg_src.shape
                )
            if has_neg_dst:
                compacted_negative_dsts = compacted.pop(0)
                # Same as above.
                compacted_negative_dsts = compacted_negative_dsts.view(
                    neg_dst.shape
                )
        return (
            seeds,
            nodes_timestamp,
            compacted_node_pairs,
            compacted_negative_srcs if has_neg_src else None,
            compacted_negative_dsts if has_neg_dst else None,
        )
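
# Illustrative sketch: running the preprocess stage in isolation, assuming an
# upstream ItemSampler whose minibatches carry `seed_nodes` (as in the
# NeighborSampler2 docstring example):
#
#     datapipe = gb.ItemSampler(item_set, batch_size=16)
#     datapipe = datapipe.sample_subgraph_preprocess()
#     minibatch = next(iter(datapipe))
#     minibatch.input_nodes  # the seeds gathered from `seed_nodes`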

@functional_datapipe("sample_subgraph")
class SubgraphSampler(MiniBatchTransformer):