Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Enable multi-partition Select operations containing basic aggregations #17941

Draft
wants to merge 12 commits into
base: branch-25.04
Choose a base branch
from
344 changes: 344 additions & 0 deletions python/cudf_polars/cudf_polars/experimental/expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,344 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
# TODO: remove need for this
# ruff: noqa: D101
"""Multi-partition Expr classes and utilities."""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import pylibcudf as plc

from cudf_polars.containers import Column
from cudf_polars.dsl.expressions.aggregation import Agg
from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
from cudf_polars.dsl.traversal import (
CachingVisitor,
reuse_if_unchanged,
traversal,
)
from cudf_polars.experimental.base import get_key_name

if TYPE_CHECKING:
from collections.abc import Mapping, MutableMapping, Sequence

from cudf_polars.containers import DataFrame
from cudf_polars.dsl.expressions.base import AggInfo
from cudf_polars.typing import ExprTransformer


_SUPPORTED_AGGS = ("count", "min", "max", "sum", "mean")


class FusedExpr(Expr):
"""
A single Expr node representing a fused Expr sub-graph.

Notes
-----
A FusedExpr may not contain more than one non-pointwise
Expr nodes. If the object does contain a non-pointwise
node, that node must be the root of sub_expr.
"""

__slots__ = ("sub_expr",)
_non_child = ("dtype", "sub_expr")

def __init__(
self,
dtype: plc.DataType,
sub_expr: Expr,
*children: Expr,
):
self.dtype = dtype
self.sub_expr = sub_expr
self.children = children
Comment on lines +55 to +56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: what's the difference between a sub_expr and a child? Is the sub_expr just the thing this came from?

Copy link
Member Author

@rjzamora rjzamora Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A FusedExpr node corresponds to multiple fused expression nodes that are effectively "owned" by the FusedExpr object.

For example, let's assume we have an expression like A = B + Sum(C + D). Since A contains a non-pointwise node, we can decompose it into two FusedExpr nodes:

  • F0 = B + F1
  • F1 = Sum(C + D).

The sub_expr attributes of F0 and F1 correspond to the Expr nodes that they "own":

  • F0.sub_expr = B + F1
  • F1.sub_expr = Sum(C + D)

The children of a FusedExpr node can only contain other FusedExpr nodes:

  • F0.children = (F1,)
  • F1.children = ()

The distinction between sub_expr and children is important during evaluation. When we evaluate F0, we must already know the result of evaluating its children (F1), but we don't need to know the result of any other nodes in F0.sub_expr.

self.is_pointwise = sub_expr.is_pointwise

def do_evaluate(
self,
df: DataFrame,
*,
context: ExecutionContext = ExecutionContext.FRAME,
mapping: Mapping[Expr, Column] | None = None,
) -> Column:
"""Evaluate this expression given a dataframe for context."""
return self.sub_expr.evaluate(df, context=context, mapping=mapping)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
return self.sub_expr.collect_agg(depth=depth)


def extract_partition_counts(
exprs: Sequence[Expr],
child_ir_count: int,
*,
update: MutableMapping[Expr, int] | None = None,
skip_fused_exprs: bool = False,
) -> MutableMapping[Expr, int]:
"""
Extract a partition-count mapping for Expr nodes.

Parameters
----------
exprs
Sequence of root expressions to traverse and
get partition counts.
child_ir_count
Partition count for the child-IR node.
update
Existing mapping to update.
skip_fused_exprs
Whether to skip over FusedExpr objects. This
can be used to stay within a local FusedExpr
sub-expression.

Returns
-------
Mapping between Expr nodes and partition counts.
"""
expr_partition_counts: MutableMapping[Expr, int] = update or {}
for expr in exprs:
for node in list(traversal([expr]))[::-1]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you doing this because you want a child-before-parent traversal?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we are transforming the Expr graph to comprise only FusedExpr nodes, and FusedExpr.chilren may only comprise other FusedExpr nodes. Therefore, we must transform the expression using a child-before-parent traversal.

if isinstance(node, FusedExpr):
# Process the fused sub-expression graph first
if skip_fused_exprs:
continue # Stay within the current sub expression
expr_partition_counts = extract_partition_counts(
[node.sub_expr],
child_ir_count,
update=expr_partition_counts,
skip_fused_exprs=True,
)
expr_partition_counts[node] = expr_partition_counts[node.sub_expr]
elif isinstance(node, Agg):
# Assume all aggregations produce 1 partition
expr_partition_counts[node] = 1
Comment on lines +116 to +118
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: This is not right, I think, if we observe the Agg inside a groupby.

We should probably attach the execution context of an expression when constructing it. (and remove the ExecutionContext argument from do_evaluate).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you may need to explain this one to me offline :)

elif node.is_pointwise:
# Pointwise expressions should preserve child partition count
if node.children:
# Assume maximum child partition count
expr_partition_counts[node] = max(
[expr_partition_counts[c] for c in node.children]
)
else:
# If no children, we are preserving the child-IR partition count
expr_partition_counts[node] = child_ir_count
else:
raise NotImplementedError(
f"{type(node)} not supported for multiple partitions."
)

return expr_partition_counts


def replace_sub_expr(e: Expr, rec: ExprTransformer):
"""Replace a target expression node."""
mapping = rec.state["mapping"]
if e in mapping:
return mapping[e]
return reuse_if_unchanged(e, rec)
Comment on lines +137 to +142
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could just call this replace, and have replace(expr, mapping) as the public function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm following the CachingVisitor pattern used elsewhere. Don't I need to include the rec: ExprTransformer argument?



def rename_agg(agg: Agg, new_name: str):
"""Modify the name of an aggregation expression."""
return CachingVisitor(
replace_sub_expr,
state={"mapping": {agg: Agg(agg.dtype, new_name, agg.options, *agg.children)}},
)(agg)
Comment on lines +145 to +150
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Renaming" feels like the wrong thing, because the options for one agg might not apply to the options of another.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the options for one agg might not apply to the options of another

Yeah, that's totally true. We definitely need to figure out how the options "plumbing" should work here. I was hoping the options would normally translate in a trivial way, but I don't have a great sense for the range of possibilities.



def decompose_expr_graph(expr: Expr) -> FusedExpr:
"""Transform an Expr into a graph of FusedExpr nodes."""
root = expr
while True:
exprs = [
e
for e in list(traversal([root]))[::-1]
if not (isinstance(e, FusedExpr) or e.is_pointwise)
]
if not exprs:
if isinstance(root, FusedExpr):
break # We are done rewriting root
exprs = [root]
old = exprs[0]

# Check that we can handle old
if not old.is_pointwise and not (
isinstance(old, Agg) and old.name in _SUPPORTED_AGGS
):
raise NotImplementedError(
f"Selection does not support {expr} for multiple partitions."
)

# Rewrite root to replace old with FusedExpr(old)
children = [child for child in traversal([old]) if isinstance(child, FusedExpr)]
new = FusedExpr(old.dtype, old, *children)
mapper = CachingVisitor(replace_sub_expr, state={"mapping": {old: new}})
root = mapper(root)

return root


def evaluate_chunk(
df: DataFrame,
expr: Expr,
children: tuple[Expr, ...],
*references: Column,
) -> Column:
"""Evaluate a single aggregation."""
return expr.evaluate(df, mapping=dict(zip(children, references, strict=False)))


def evaluate_chunk_multi(
df: DataFrame,
exprs: Sequence[Expr],
children: tuple[Expr, ...],
*references: Column,
) -> tuple[Column, ...]:
"""Evaluate multiple aggregations."""
return tuple(
expr.evaluate(df, mapping=dict(zip(children, references, strict=False)))
for expr in exprs
)


def combine_chunks_multi(
column_chunks: Sequence[tuple[Column]],
combine_aggs: Sequence[Agg],
finalize: tuple[plc.DataType, str] | None,
) -> Column:
"""Aggregate Column chunks."""
column_chunk_lists = zip(*column_chunks, strict=False)

combined = [
agg.op(
Column(
plc.concatenate.concatenate([col.obj for col in column_chunk_list]),
name=column_chunk_list[0].name,
)
)
for agg, column_chunk_list in zip(
combine_aggs, column_chunk_lists, strict=False
)
]

if finalize:
# Perform optional BinOp on combined columns
dt, op_name = finalize
op = getattr(plc.binaryop.BinaryOperator, op_name)
cols = [c.obj for c in combined]
return Column(plc.binaryop.binary_operation(*cols, op, dt))

assert len(combined) == 1
return combined[0]


def make_agg_graph(
expr: FusedExpr,
child_name: str,
expr_partition_counts: MutableMapping[Expr, int],
) -> MutableMapping[Any, Any]:
"""Build a FusedExpr aggregation graph."""
agg = expr.sub_expr
assert isinstance(agg, Agg), f"Expected Agg, got {agg}"
assert agg.name in _SUPPORTED_AGGS, f"Agg {agg} not supported"

# NOTE: This algorithm assumes we are doing nested
# aggregations, or we are only aggregating a single
# column. If we are performing aligned aggregations
# across multiple columns at once, we should perform
# our reduction at the DataFrame level instead.

key_name = get_key_name(expr)
expr_child_names = [get_key_name(c) for c in expr.children]
expr_bcast = [expr_partition_counts[c] == 1 for c in expr.children]
input_count = max(expr_partition_counts[c] for c in agg.children)

graph: MutableMapping[Any, Any] = {}

# Define operations for each aggregation stage
agg_name = agg.name
if agg_name == "count":
chunk_aggs = [agg]
combine_aggs = [rename_agg(agg, "sum")]
finalize = None
elif agg_name == "mean":
sum_agg = rename_agg(agg, "sum")
count_agg = rename_agg(agg, "count")
chunk_aggs = [sum_agg, count_agg]
combine_aggs = [sum_agg, sum_agg]
finalize = (agg.dtype, "DIV")
else:
chunk_aggs = [agg]
combine_aggs = [agg]
finalize = None

# Pointwise stage
chunk_name = f"chunk-{key_name}"
for i in range(input_count):
graph[(chunk_name, i)] = (
evaluate_chunk_multi,
(child_name, i),
chunk_aggs,
expr.children,
*[
(name, 0) if bcast else (name, i)
for name, bcast in zip(expr_child_names, expr_bcast, strict=False)
],
)

# Combine and finalize
graph[(key_name, 0)] = (
combine_chunks_multi,
list(graph.keys()),
combine_aggs,
finalize,
)

return graph


def make_pointwise_graph(
expr: FusedExpr,
child_name: str,
expr_partition_counts: MutableMapping[Expr, int],
) -> MutableMapping[Any, Any]:
"""Build simple pointwise FusedExpr graph."""
key_name = get_key_name(expr)
expr_child_names = [get_key_name(c) for c in expr.children]
expr_bcast = [expr_partition_counts[c] == 1 for c in expr.children]
count = expr_partition_counts[expr]
return {
(key_name, i): (
evaluate_chunk,
(child_name, i),
expr.sub_expr,
expr.children,
*[
(name, 0) if bcast else (name, i)
for name, bcast in zip(expr_child_names, expr_bcast, strict=False)
],
)
for i in range(count)
}


def make_fusedexpr_graph(
expr: FusedExpr,
child_name: str,
expr_partition_counts: MutableMapping[Expr, int],
) -> MutableMapping[Any, Any]:
"""Build task graph for a FusedExpr node."""
sub_expr = expr.sub_expr
if isinstance(sub_expr, Agg) and sub_expr.name in _SUPPORTED_AGGS:
return make_agg_graph(expr, child_name, expr_partition_counts)
elif expr.is_pointwise:
return make_pointwise_graph(expr, child_name, expr_partition_counts)
else:
# TODO: Implement "complex" aggs (e.g. mean, std, etc)
raise NotImplementedError(
f"{type(sub_expr)} not supported for multiple partitions."
)
3 changes: 1 addition & 2 deletions python/cudf_polars/cudf_polars/experimental/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import cudf_polars.experimental.io
import cudf_polars.experimental.select
import cudf_polars.experimental.shuffle # noqa: F401
from cudf_polars.dsl.ir import IR, Cache, Filter, HStack, Projection, Select, Union
from cudf_polars.dsl.ir import IR, Cache, Filter, HStack, Projection, Union
from cudf_polars.dsl.traversal import CachingVisitor, traversal
from cudf_polars.experimental.base import PartitionInfo, _concat, get_key_name
from cudf_polars.experimental.dispatch import (
Expand Down Expand Up @@ -251,4 +251,3 @@ def _generate_ir_tasks_pwise(
generate_ir_tasks.register(Cache, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Filter, _generate_ir_tasks_pwise)
generate_ir_tasks.register(HStack, _generate_ir_tasks_pwise)
generate_ir_tasks.register(Select, _generate_ir_tasks_pwise)
Loading
Loading