Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implementation of udf and udaf decorator #1040

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/datafusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def lit(value):


udf = ScalarUDF.udf
udf_decorator = ScalarUDF.udf_decorator

udaf = AggregateUDF.udaf
udaf_decorator = AggregateUDF.udaf_decorator

udwf = WindowUDF.udwf
47 changes: 47 additions & 0 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
import functools

import pyarrow

Expand Down Expand Up @@ -148,6 +149,27 @@ def udf(
volatility=volatility,
)

@staticmethod
def udf_decorator(
input_types: list[pyarrow.DataType],
return_type: _R,
volatility: Volatility | str,
name: Optional[str] = None
):
def decorator(func):
udf_caller = ScalarUDF.udf(
func,
input_types,
return_type,
volatility,
name
)

@functools.wraps(func)
def wrapper(*args, **kwargs):
return udf_caller(*args, **kwargs)
return wrapper
return decorator

class Accumulator(metaclass=ABCMeta):
"""Defines how an :py:class:`AggregateUDF` accumulates values."""
Expand Down Expand Up @@ -287,6 +309,31 @@ def sum_bias_10() -> Summarize:
state_type=state_type,
volatility=volatility,
)

@staticmethod
def udaf_decorator(
input_types: pyarrow.DataType | list[pyarrow.DataType],
return_type: pyarrow.DataType,
state_type: list[pyarrow.DataType],
volatility: Volatility | str,
name: Optional[str] = None
):
def decorator(accum: Callable[[], Accumulator]):
udaf_caller = AggregateUDF.udaf(
accum,
input_types,
return_type,
state_type,
volatility,
name
)

@functools.wraps(accum)
def wrapper(*args, **kwargs):
return udaf_caller(*args, **kwargs)
return wrapper
return decorator



class WindowEvaluator(metaclass=ABCMeta):
Expand Down
50 changes: 49 additions & 1 deletion python/tests/test_udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pyarrow as pa
import pyarrow.compute as pc
import pytest
from datafusion import Accumulator, column, udaf
from datafusion import Accumulator, column, udaf, udaf_decorator


class Summarize(Accumulator):
Expand Down Expand Up @@ -116,6 +116,29 @@ def test_udaf_aggregate(df):

assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])

def test_udaf_decorator_aggregate(df):

@udaf_decorator(pa.float64(),
pa.float64(),
[pa.float64()],
"immutable")
def summarize():
return Summarize()

df1 = df.aggregate([], [summarize(column("a"))])

# execute and collect the first (and only) batch
result = df1.collect()[0]

assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])

df2 = df.aggregate([], [summarize(column("a"))])

# Run a second time to ensure the state is properly reset
result = df2.collect()[0]

assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])


def test_udaf_aggregate_with_arguments(df):
bias = 10.0
Expand Down Expand Up @@ -143,6 +166,31 @@ def test_udaf_aggregate_with_arguments(df):
assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])


def test_udaf_decorator_aggregate_with_arguments(df):
bias = 10.0

@udaf_decorator(pa.float64(),
pa.float64(),
[pa.float64()],
"immutable")
def summarize():
return Summarize(bias)

df1 = df.aggregate([], [summarize(column("a"))])

# execute and collect the first (and only) batch
result = df1.collect()[0]

assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])

df2 = df.aggregate([], [summarize(column("a"))])

# Run a second time to ensure the state is properly reset
result = df2.collect()[0]

assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0])


def test_group_by(df):
summarize = udaf(
Summarize,
Expand Down
42 changes: 36 additions & 6 deletions python/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

import pyarrow as pa
import pytest
from datafusion import column, udf
from datafusion import column, udf, udf_decorator


@pytest.fixture
def df(ctx):
# create a RecordBatch and a new DataFrame from it
batch = pa.RecordBatch.from_arrays(
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
[pa.array([1, 2, 3]), pa.array([4, 4, None])],
names=["a", "b"],
)
return ctx.create_dataframe([[batch]], name="test_table")
Expand All @@ -39,10 +39,20 @@ def test_udf(df):
volatility="immutable",
)

df = df.select(is_null(column("a")))
df = df.select(is_null(column("b")))
result = df.collect()[0].column(0)

assert result == pa.array([False, False, False])
assert result == pa.array([False, False, True])


def test_udf_decorator(df):
@udf_decorator([pa.int64()], pa.bool_(), "immutable")
def is_null(x: pa.Array) -> pa.Array:
return x.is_null()

df = df.select(is_null(column("b")))
result = df.collect()[0].column(0)
assert result == pa.array([False, False, True])


def test_register_udf(ctx, df) -> None:
Expand All @@ -56,10 +66,10 @@ def test_register_udf(ctx, df) -> None:

ctx.register_udf(is_null)

df_result = ctx.sql("select is_null(a) from test_table")
df_result = ctx.sql("select is_null(b) from test_table")
result = df_result.collect()[0].column(0)

assert result == pa.array([False, False, False])
assert result == pa.array([False, False, True])


class OverThresholdUDF:
Expand Down Expand Up @@ -94,3 +104,23 @@ def test_udf_with_parameters(df) -> None:
result = df2.collect()[0].column(0)

assert result == pa.array([False, True, True])


def test_udf_with_parameters(df) -> None:
@udf_decorator([pa.int64()], pa.bool_(), "immutable")
def udf_no_param(values: pa.Array) -> pa.Array:
return OverThresholdUDF()(values)

df1 = df.select(udf_no_param(column("a")))
result = df1.collect()[0].column(0)

assert result == pa.array([True, True, True])

@udf_decorator([pa.int64()], pa.bool_(), "immutable")
def udf_with_param(values: pa.Array) -> pa.Array:
return OverThresholdUDF(2)(values)

df2 = df.select(udf_with_param(column("a")))
result = df2.collect()[0].column(0)

assert result == pa.array([False, True, True])