-
Notifications
You must be signed in to change notification settings - Fork 928
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Remove cudf.Scalar from shift/fillna #17922
base: branch-25.04
Are you sure you want to change the base?
Changes from all commits
388824f
2a677b3
79d066f
f3cccc8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,6 +46,7 @@ | |
|
||
from cudf._typing import ( | ||
ColumnBinaryOperand, | ||
ColumnLike, | ||
DatetimeLikeScalar, | ||
Dtype, | ||
ScalarLike, | ||
|
@@ -269,6 +270,21 @@ def __contains__(self, item: ScalarLike) -> bool: | |
"cudf.core.column.NumericalColumn", self.astype("int64") | ||
) | ||
|
||
def _validate_fillna_value( | ||
self, fill_value: ScalarLike | ColumnLike | ||
) -> plc.Scalar | ColumnBase: | ||
"""Align fill_value for .fillna based on column type.""" | ||
if ( | ||
isinstance(fill_value, np.datetime64) | ||
and self.time_unit != np.datetime_data(fill_value)[0] | ||
): | ||
# TODO: Disallow this cast | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I feel like a lot of your PRs have had these kinds of comments. Do they all fall into similar buckets? Should we open some issues for tracking? |
||
fill_value = fill_value.astype(self.dtype) | ||
elif isinstance(fill_value, str) and fill_value.lower() == "nat": | ||
# TODO: Disallow this casting; user should be explicit | ||
fill_value = np.datetime64(fill_value, self.time_unit) | ||
return super()._validate_fillna_value(fill_value) | ||
|
||
@functools.cached_property | ||
def time_unit(self) -> str: | ||
return np.datetime_data(self.dtype)[0] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,6 +25,8 @@ | |
DecimalDtype, | ||
) | ||
from cudf.core.mixins import BinaryOperand | ||
from cudf.core.scalar import pa_scalar_to_plc_scalar | ||
from cudf.utils.dtypes import cudf_dtype_to_pa_type | ||
from cudf.utils.utils import pa_mask_buffer_to_mask | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -168,16 +170,35 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str): | |
|
||
return result | ||
|
||
def _scalar_to_plc_scalar(self, scalar: ScalarLike) -> plc.Scalar: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Now that #17422 is merged I think we can stop special-casing this and see if anything breaks. WDYT? It does mean that decimal conversions in tests will fail if run with an older version of pyarrow, but I think that's an OK tradeoff. We might have to put some conditional xfails into our test suite for the "oldest" test runs. |
||
"""Return a pylibcudf.Scalar that matches the type of self.dtype""" | ||
if not isinstance(scalar, pa.Scalar): | ||
# e.g. casting int to decimal type isn't allowed, but OK in the constructor? | ||
pa_scalar = pa.scalar( | ||
scalar, type=cudf_dtype_to_pa_type(self.dtype) | ||
) | ||
else: | ||
pa_scalar = scalar.cast(cudf_dtype_to_pa_type(self.dtype)) | ||
plc_scalar = pa_scalar_to_plc_scalar(pa_scalar) | ||
if isinstance(self.dtype, (Decimal32Dtype, Decimal64Dtype)): | ||
# pyarrow.Scalar only supports Decimal128 so conversion | ||
# from pyarrow would only return a pylibcudf.Scalar with Decimal128 | ||
col = ColumnBase.from_pylibcudf( | ||
plc.Column.from_scalar(plc_scalar, 1) | ||
).astype(self.dtype) | ||
return plc.copying.get_element(col.to_pylibcudf(mode="read"), 0) | ||
return plc_scalar | ||
|
||
def _validate_fillna_value( | ||
self, fill_value: ScalarLike | ColumnLike | ||
) -> cudf.Scalar | ColumnBase: | ||
) -> plc.Scalar | ColumnBase: | ||
"""Align fill_value for .fillna based on column type.""" | ||
if isinstance(fill_value, (int, Decimal)): | ||
return cudf.Scalar(fill_value, dtype=self.dtype) | ||
return super()._validate_fillna_value(fill_value) | ||
elif isinstance(fill_value, ColumnBase) and ( | ||
isinstance(self.dtype, DecimalDtype) or self.dtype.kind in "iu" | ||
): | ||
return fill_value.astype(self.dtype) | ||
return super()._validate_fillna_value(fill_value) | ||
raise TypeError( | ||
"Decimal columns only support using fillna with decimal and " | ||
"integer values" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This feels a bit odd as a class method. I feel like a free function that accepts a dtype would be more appropriate, then we could call that with `col.dtype`. Scoping-wise this doesn't feel like a Column method. Plus then it would directly mirror `pa_scalar_to_plc_scalar`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess there is currently a small benefit because we can override this method for decimal columns to get the specialized behavior that we need, but I think that we don't need that any more (see my comment on that class).