Remove cudf.Scalar from shift/fillna #17922

Open · wants to merge 4 commits into base: branch-25.04
9 changes: 7 additions & 2 deletions python/cudf/cudf/core/column/categorical.py
@@ -20,6 +20,7 @@
from cudf.core.scalar import pa_scalar_to_plc_scalar
from cudf.utils.dtypes import (
SIZE_TYPE_DTYPE,
cudf_dtype_to_pa_type,
find_common_type,
is_mixed_with_object_dtype,
min_signed_type,
@@ -1047,7 +1048,7 @@ def notnull(self) -> ColumnBase:

def _validate_fillna_value(
self, fill_value: ScalarLike | ColumnLike
) -> cudf.Scalar | ColumnBase:
) -> plc.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if cudf.api.types.is_scalar(fill_value):
if fill_value != _DEFAULT_CATEGORICAL_VALUE:
@@ -1057,7 +1058,11 @@
raise ValueError(
f"{fill_value=} must be in categories"
) from err
return cudf.Scalar(fill_value, dtype=self.codes.dtype)
return pa_scalar_to_plc_scalar(
pa.scalar(
fill_value, type=cudf_dtype_to_pa_type(self.codes.dtype)
)
)
else:
fill_value = column.as_column(fill_value, nan_as_null=False)
if isinstance(fill_value.dtype, CategoricalDtype):
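A usage-level sketch of the validation this hunk keeps while switching the scalar return type to plc.Scalar; the example data is hypothetical:

import cudf

s = cudf.Series(["a", "b", None], dtype="category")
s.fillna("a")  # "a" is an existing category, so it is encoded to the codes dtype
s.fillna("z")  # raises ValueError: fill_value='z' must be in categories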
23 changes: 15 additions & 8 deletions python/cudf/cudf/core/column/column.py
@@ -462,12 +462,11 @@ def _fill(

@acquire_spill_lock()
def shift(self, offset: int, fill_value: ScalarLike) -> Self:
if not isinstance(fill_value, cudf.Scalar):
fill_value = cudf.Scalar(fill_value, dtype=self.dtype)
plc_fill_value = self._scalar_to_plc_scalar(fill_value)
plc_col = plc.copying.shift(
self.to_pylibcudf(mode="read"),
offset,
fill_value.device_value,
plc_fill_value,
)
return type(self).from_pylibcudf(plc_col) # type: ignore[return-value]

@@ -761,13 +760,21 @@ def _check_scatter_key_length(
f"{num_keys}"
)

def _scalar_to_plc_scalar(self, scalar: ScalarLike) -> plc.Scalar:
Contributor: This feels a bit odd as a class method. I feel like a free function that accepts a dtype would be more appropriate, then we could call that with col.dtype. Scoping-wise this doesn't feel like a Column method. Plus then it would directly mirror pa_scalar_to_plc_scalar.

Contributor: I guess there is currently a small benefit because we can override this method for decimal columns to get the specialized behavior that we need, but I think that we don't need that any more (see my comment on that class).

"""Return a pylibcudf.Scalar that matches the type of self.dtype"""
if not isinstance(scalar, pa.Scalar):
scalar = pa.scalar(scalar)
return pa_scalar_to_plc_scalar(
scalar.cast(cudf_dtype_to_pa_type(self.dtype))
)

def _validate_fillna_value(
self, fill_value: ScalarLike | ColumnLike
) -> cudf.Scalar | ColumnBase:
) -> plc.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if is_scalar(fill_value):
return cudf.Scalar(fill_value, dtype=self.dtype)
return as_column(fill_value)
return self._scalar_to_plc_scalar(fill_value)
return as_column(fill_value).astype(self.dtype)

@acquire_spill_lock()
def replace(
Expand Down Expand Up @@ -813,8 +820,8 @@ def fillna(
if method == "ffill"
else plc.replace.ReplacePolicy.FOLLOWING
)
elif is_scalar(fill_value):
plc_replace = cudf.Scalar(fill_value).device_value
elif isinstance(fill_value, plc.Scalar):
plc_replace = fill_value
else:
plc_replace = fill_value.to_pylibcudf(mode="read")
plc_column = plc.replace.replace_nulls(
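A minimal sketch of the free-function shape the first review comment above proposes; the name host_scalar_to_plc_scalar and its placement are hypothetical, not part of this PR:

import pyarrow as pa
import pylibcudf as plc
from cudf.core.scalar import pa_scalar_to_plc_scalar
from cudf.utils.dtypes import cudf_dtype_to_pa_type

def host_scalar_to_plc_scalar(scalar, dtype) -> plc.Scalar:
    """Return a pylibcudf.Scalar that matches `dtype`."""
    if not isinstance(scalar, pa.Scalar):
        scalar = pa.scalar(scalar)
    return pa_scalar_to_plc_scalar(scalar.cast(cudf_dtype_to_pa_type(dtype)))

Call sites would then pass the column dtype explicitly, e.g. host_scalar_to_plc_scalar(fill_value, col.dtype), directly mirroring pa_scalar_to_plc_scalar as the reviewer suggests.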
16 changes: 16 additions & 0 deletions python/cudf/cudf/core/column/datetime.py
@@ -46,6 +46,7 @@

from cudf._typing import (
ColumnBinaryOperand,
ColumnLike,
DatetimeLikeScalar,
Dtype,
ScalarLike,
@@ -269,6 +270,21 @@ def __contains__(self, item: ScalarLike) -> bool:
"cudf.core.column.NumericalColumn", self.astype("int64")
)

def _validate_fillna_value(
self, fill_value: ScalarLike | ColumnLike
) -> plc.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if (
isinstance(fill_value, np.datetime64)
and self.time_unit != np.datetime_data(fill_value)[0]
):
# TODO: Disallow this cast
Contributor: I feel like a lot of your PRs have had these kinds of comments. Do they all fall into similar buckets? Should we open some issues for tracking?

fill_value = fill_value.astype(self.dtype)
elif isinstance(fill_value, str) and fill_value.lower() == "nat":
# TODO: Disallow this casting; user should be explicit
fill_value = np.datetime64(fill_value, self.time_unit)
return super()._validate_fillna_value(fill_value)

@functools.cached_property
def time_unit(self) -> str:
return np.datetime_data(self.dtype)[0]
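A numpy-only sketch of the unit handling in the new datetime _validate_fillna_value; the dates are hypothetical:

import numpy as np

fill_value = np.datetime64("2024-01-01", "s")
np.datetime_data(fill_value.dtype)[0]  # 's', which differs from a 'ns' column's time_unit
fill_value.astype("datetime64[ns]")    # the cast the TODO above still allows
np.datetime64("NaT", "ns")             # what a "NaT" string is promoted to for now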
27 changes: 24 additions & 3 deletions python/cudf/cudf/core/column/decimal.py
@@ -25,6 +25,8 @@
DecimalDtype,
)
from cudf.core.mixins import BinaryOperand
from cudf.core.scalar import pa_scalar_to_plc_scalar
from cudf.utils.dtypes import cudf_dtype_to_pa_type
from cudf.utils.utils import pa_mask_buffer_to_mask

if TYPE_CHECKING:
@@ -168,16 +170,35 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str):

return result

def _scalar_to_plc_scalar(self, scalar: ScalarLike) -> plc.Scalar:
Contributor: Now that #17422 is merged I think we can stop special-casing this and see if anything breaks. WDYT? It does mean that decimal conversions in tests will fail if run with an older version of pyarrow, but I think that's an OK tradeoff. We might have to put some conditional xfails into our test suite for the "oldest" test runs.

"""Return a pylibcudf.Scalar that matches the type of self.dtype"""
if not isinstance(scalar, pa.Scalar):
# e.g. casting int to decimal type isn't allowed, but is OK in the constructor?
pa_scalar = pa.scalar(
scalar, type=cudf_dtype_to_pa_type(self.dtype)
)
else:
pa_scalar = scalar.cast(cudf_dtype_to_pa_type(self.dtype))
plc_scalar = pa_scalar_to_plc_scalar(pa_scalar)
if isinstance(self.dtype, (Decimal32Dtype, Decimal64Dtype)):
# pyarrow.Scalar only supports Decimal128 so conversion
# from pyarrow would only return a pylibcudf.Scalar with Decimal128
col = ColumnBase.from_pylibcudf(
plc.Column.from_scalar(plc_scalar, 1)
).astype(self.dtype)
return plc.copying.get_element(col.to_pylibcudf(mode="read"), 0)
return plc_scalar

def _validate_fillna_value(
self, fill_value: ScalarLike | ColumnLike
) -> cudf.Scalar | ColumnBase:
) -> plc.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if isinstance(fill_value, (int, Decimal)):
return cudf.Scalar(fill_value, dtype=self.dtype)
return super()._validate_fillna_value(fill_value)
elif isinstance(fill_value, ColumnBase) and (
isinstance(self.dtype, DecimalDtype) or self.dtype.kind in "iu"
):
return fill_value.astype(self.dtype)
return super()._validate_fillna_value(fill_value)
raise TypeError(
"Decimal columns only support using fillna with decimal and "
"integer values"
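For context on the Decimal32/Decimal64 special case above, a small sketch of why the extra round-trip exists under the diff's stated assumption that pyarrow scalars only carry Decimal128; the value is hypothetical:

import pyarrow as pa
from decimal import Decimal

# A decimal fill value arrives as a 128-bit pyarrow scalar...
pa_scalar = pa.scalar(Decimal("1.23"), type=pa.decimal128(9, 2))
# ...so for a Decimal32/Decimal64 column the diff builds a one-row pylibcudf
# column from it, casts that column to the narrower decimal dtype, and pulls
# element 0 back out with plc.copying.get_element.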
13 changes: 9 additions & 4 deletions python/cudf/cudf/core/column/numerical.py
@@ -559,15 +559,20 @@ def find_and_replace(

def _validate_fillna_value(
self, fill_value: ScalarLike | ColumnLike
) -> cudf.Scalar | ColumnBase:
) -> plc.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if is_scalar(fill_value):
cudf_obj: cudf.Scalar | ColumnBase = cudf.Scalar(fill_value)
if not as_column(cudf_obj).can_cast_safely(self.dtype):
cudf_obj = ColumnBase.from_pylibcudf(
plc.Column.from_scalar(
pa_scalar_to_plc_scalar(pa.scalar(fill_value)), 1
)
)
if not cudf_obj.can_cast_safely(self.dtype):
raise TypeError(
f"Cannot safely cast non-equivalent "
f"{type(fill_value).__name__} to {self.dtype.name}"
)
return super()._validate_fillna_value(fill_value)
else:
cudf_obj = as_column(fill_value, nan_as_null=False)
if not cudf_obj.can_cast_safely(self.dtype): # type: ignore[attr-defined]
@@ -576,7 +581,7 @@
f"{cudf_obj.dtype.type.__name__} to "
f"{self.dtype.type.__name__}"
)
return cudf_obj.astype(self.dtype)
return cudf_obj.astype(self.dtype)

def can_cast_safely(self, to_dtype: DtypeObj) -> bool:
"""
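A usage-level sketch of the safe-cast check the rewritten scalar path above still enforces; the values are hypothetical:

import cudf

s = cudf.Series([1, 2, None], dtype="int8")
s.fillna(3)     # 3 round-trips safely to int8, so the fill proceeds
s.fillna(1000)  # expected to raise TypeError: cannot safely cast non-equivalent int to int8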
23 changes: 22 additions & 1 deletion python/cudf/cudf/core/column/timedelta.py
@@ -28,7 +28,13 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from cudf._typing import ColumnBinaryOperand, DatetimeLikeScalar, Dtype
from cudf._typing import (
ColumnBinaryOperand,
ColumnLike,
DatetimeLikeScalar,
Dtype,
ScalarLike,
)

_unit_to_nanoseconds_conversion = {
"ns": 1,
@@ -137,6 +143,21 @@ def __contains__(self, item: DatetimeLikeScalar) -> bool:
"cudf.core.column.NumericalColumn", self.astype("int64")
)

def _validate_fillna_value(
self, fill_value: ScalarLike | ColumnLike
) -> plc.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if (
isinstance(fill_value, np.timedelta64)
and self.time_unit != np.datetime_data(fill_value)[0]
):
# TODO: Disallow this cast
fill_value = fill_value.astype(self.dtype)
elif isinstance(fill_value, str) and fill_value.lower() == "nat":
# TODO: Disallow this casting; user should be explicit
fill_value = np.timedelta64(fill_value, self.time_unit)
return super()._validate_fillna_value(fill_value)

@property
def values(self):
"""
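The timedelta path mirrors the datetime one; a short numpy sketch with hypothetical values:

import numpy as np

np.timedelta64(5, "s").astype("timedelta64[ns]")  # mismatched unit cast to the column dtype
np.timedelta64("NaT", "ns")                        # the still-allowed "NaT" string shortcut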