Skip to content

Commit

Permalink
Avoid cudf.dtype internally in favor of pre-defined, supported types (
Browse files Browse the repository at this point in the history
#17839)

xref #12494 and #12495

`cudf.dtype` is useful when cudf is passed a `dtype` argument from a user to perform inference on the input to make it cudf-compatable. Internally, we don't need this inference because we know the exact types to be passed & that are supported by cudf (columns), so this PR avoids calling `cudf.dtype` internally.

Generally:

* Define `CUDF_STRING_DTYPE` as a definitive cudf Python string type instead of `cudf/np.dtype("O"/"object", "str")`
* Prefer using `np.<type>` instead of `"<type>"` (using `np.` like an enum namespace)

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #17839
  • Loading branch information
mroeschke authored Feb 4, 2025
1 parent 36b0f3a commit a7e0257
Show file tree
Hide file tree
Showing 15 changed files with 101 additions and 97 deletions.
3 changes: 2 additions & 1 deletion python/cudf/cudf/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
is_list_dtype,
is_struct_dtype,
)
from cudf.utils.dtypes import CUDF_STRING_DTYPE


def is_numeric_dtype(obj):
Expand Down Expand Up @@ -113,7 +114,7 @@ def is_string_dtype(obj):
return (
(
isinstance(obj, (cudf.Index, cudf.Series))
and obj.dtype == cudf.dtype("O")
and obj.dtype == CUDF_STRING_DTYPE
)
or (isinstance(obj, cudf.core.column.StringColumn))
or (
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/_internals/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

from typing import TYPE_CHECKING, Literal

import numpy as np
from numba.np import numpy_support

import pylibcudf as plc

import cudf
from cudf.api.types import is_scalar
from cudf.utils import cudautils
from cudf.utils.dtypes import SUPPORTED_NUMPY_TO_PYLIBCUDF_TYPES
Expand Down Expand Up @@ -234,7 +234,7 @@ def from_udf(cls, op, *args, **kwargs) -> Self:
nb_type = numpy_support.from_dtype(kwargs["dtype"])
type_signature = (nb_type[:],)
ptx_code, output_dtype = cudautils.compile_udf(op, type_signature)
output_np_dtype = cudf.dtype(output_dtype)
output_np_dtype = np.dtype(output_dtype)
if output_np_dtype not in SUPPORTED_NUMPY_TO_PYLIBCUDF_TYPES:
raise TypeError(
f"Result of window function has unsupported dtype {op[1]}"
Expand Down
16 changes: 7 additions & 9 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ def slice(self, start: int, stop: int, stride: int | None = None) -> Self:
# Need to create a gather map for given slice with stride
gather_map = as_column(
range(start, stop, stride),
dtype=cudf.dtype(np.int32),
dtype=np.dtype(np.int32),
)
return self.take(gather_map)

Expand Down Expand Up @@ -692,7 +692,7 @@ def _scatter_by_slice(
cudf.core.column.NumericalColumn,
as_column(
rng,
dtype=cudf.dtype(np.int32),
dtype=np.dtype(np.int32),
),
)

Expand Down Expand Up @@ -1796,9 +1796,7 @@ def column_empty(
children = (
cudf.core.column.NumericalColumn(
data=as_buffer(
rmm.DeviceBuffer(
size=row_count * cudf.dtype(SIZE_TYPE_DTYPE).itemsize
)
rmm.DeviceBuffer(size=row_count * SIZE_TYPE_DTYPE.itemsize)
),
size=None,
dtype=SIZE_TYPE_DTYPE,
Expand Down Expand Up @@ -2046,7 +2044,7 @@ def as_column(
)
)
if cudf.get_option("default_integer_bitwidth") and dtype is None:
dtype = cudf.dtype(
dtype = np.dtype(
f"i{cudf.get_option('default_integer_bitwidth') // 8}"
)
if dtype is not None:
Expand Down Expand Up @@ -2263,7 +2261,7 @@ def as_column(
and np.isnan(arbitrary)
):
if dtype is None:
dtype = getattr(arbitrary, "dtype", cudf.dtype("float64"))
dtype = getattr(arbitrary, "dtype", np.dtype(np.float64))
arbitrary = None
if isinstance(arbitrary, pa.Scalar):
col = ColumnBase.from_pylibcudf(
Expand Down Expand Up @@ -2480,7 +2478,7 @@ def as_column(
and pa.types.is_floating(arbitrary.type)
):
dtype = _maybe_convert_to_default_type(
cudf.dtype(arbitrary.type.to_pandas_dtype())
np.dtype(arbitrary.type.to_pandas_dtype())
)
except (pa.ArrowInvalid, pa.ArrowTypeError, TypeError):
arbitrary = pd.Series(arbitrary)
Expand Down Expand Up @@ -2562,7 +2560,7 @@ def deserialize_columns(headers: list[dict], frames: list) -> list[ColumnBase]:
def concat_columns(objs: "MutableSequence[ColumnBase]") -> ColumnBase:
"""Concatenate a sequence of columns."""
if len(objs) == 0:
dtype = cudf.dtype(None)
dtype = np.dtype(np.float64)
return column_empty(0, dtype=dtype)

# If all columns are `NumericalColumn` with different dtypes,
Expand Down
21 changes: 13 additions & 8 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
from cudf.core.column.column import ColumnBase, as_column
from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
from cudf.core.scalar import pa_scalar_to_plc_scalar
from cudf.utils.dtypes import _get_base_dtype, cudf_dtype_to_pa_type
from cudf.utils.dtypes import (
CUDF_STRING_DTYPE,
_get_base_dtype,
cudf_dtype_to_pa_type,
)
from cudf.utils.utils import (
_all_bools_with_nulls,
_datetime_timedelta_find_and_replace,
Expand Down Expand Up @@ -182,7 +186,7 @@ def _resolve_mixed_dtypes(
lhs_unit = units.index(lhs_time_unit)
rhs_time_unit = cudf.utils.dtypes.get_time_unit(rhs)
rhs_unit = units.index(rhs_time_unit)
return cudf.dtype(f"{base_type}[{units[max(lhs_unit, rhs_unit)]}]")
return np.dtype(f"{base_type}[{units[max(lhs_unit, rhs_unit)]}]")


class DatetimeColumn(column.ColumnBase):
Expand Down Expand Up @@ -757,7 +761,7 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
}
and other_is_datetime64
):
out_dtype = cudf.dtype(np.bool_)
out_dtype = np.dtype(np.bool_)
elif op == "__add__" and other_is_timedelta:
# The only thing we can add to a datetime is a timedelta. This
# operation is symmetric, i.e. we allow `datetime + timedelta` or
Expand All @@ -778,7 +782,7 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
"NULL_EQUALS",
"NULL_NOT_EQUALS",
}:
out_dtype = cudf.dtype(np.bool_)
out_dtype = np.dtype(np.bool_)
if isinstance(other, ColumnBase) and not isinstance(
other, DatetimeColumn
):
Expand Down Expand Up @@ -823,13 +827,14 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool:
to_res, _ = np.datetime_data(to_dtype)
self_res, _ = np.datetime_data(self.dtype)

max_int = np.iinfo(cudf.dtype("int64")).max
int64 = np.dtype(np.int64)
max_int = np.iinfo(int64).max

max_dist = np.timedelta64(
self.max().astype(cudf.dtype("int64"), copy=False), self_res
self.max().astype(int64, copy=False), self_res
)
min_dist = np.timedelta64(
self.min().astype(cudf.dtype("int64"), copy=False), self_res
self.min().astype(int64, copy=False), self_res
)

self_delta_dtype = np.timedelta64(0, self_res).dtype
Expand All @@ -842,7 +847,7 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool:
return True
else:
return False
elif to_dtype == cudf.dtype("int64") or to_dtype == cudf.dtype("O"):
elif to_dtype == np.dtype(np.int64) or to_dtype == CUDF_STRING_DTYPE:
# can safely cast to representation, or string
return True
else:
Expand Down
14 changes: 7 additions & 7 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from cudf.core.scalar import pa_scalar_to_plc_scalar
from cudf.errors import MixedTypeError
from cudf.utils.dtypes import (
CUDF_STRING_DTYPE,
find_common_type,
min_column_type,
min_signed_type,
Expand Down Expand Up @@ -265,13 +266,13 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
and tmp.dtype.kind != "b"
):
if isinstance(tmp, NumericalColumn) and 0 in tmp:
out_dtype = cudf.dtype("float64")
out_dtype = np.dtype(np.float64)
elif isinstance(tmp, cudf.Scalar):
if tmp.is_valid() and tmp == 0:
# tmp == 0 can return NA
out_dtype = cudf.dtype("float64")
out_dtype = np.dtype(np.float64)
elif is_scalar(tmp) and tmp == 0:
out_dtype = cudf.dtype("float64")
out_dtype = np.dtype(np.float64)

if op in {"__and__", "__or__", "__xor__"}:
if self.dtype.kind == "f" or other.dtype.kind == "f":
Expand Down Expand Up @@ -362,7 +363,7 @@ def as_string_column(self) -> cudf.core.column.StringColumn:
if len(self) == 0:
return cast(
cudf.core.column.StringColumn,
column.column_empty(0, dtype="object"),
column.column_empty(0, dtype=CUDF_STRING_DTYPE),
)
elif self.dtype.kind == "b":
conv_func = functools.partial(
Expand All @@ -386,7 +387,7 @@ def as_datetime_column(
self, dtype: Dtype
) -> cudf.core.column.DatetimeColumn:
return cudf.core.column.DatetimeColumn(
data=self.astype("int64").base_data, # type: ignore[arg-type]
data=self.astype(np.dtype(np.int64)).base_data, # type: ignore[arg-type]
dtype=dtype,
mask=self.base_mask,
offset=self.offset,
Expand All @@ -397,7 +398,7 @@ def as_timedelta_column(
self, dtype: Dtype
) -> cudf.core.column.TimeDeltaColumn:
return cudf.core.column.TimeDeltaColumn(
data=self.astype("int64").base_data, # type: ignore[arg-type]
data=self.astype(np.dtype(np.int64)).base_data, # type: ignore[arg-type]
dtype=dtype,
mask=self.base_mask,
offset=self.offset,
Expand All @@ -408,7 +409,6 @@ def as_decimal_column(self, dtype: Dtype) -> DecimalBaseColumn:
return self.cast(dtype=dtype) # type: ignore[return-value]

def as_numerical_column(self, dtype: Dtype) -> NumericalColumn:
dtype = cudf.dtype(dtype)
if dtype == self.dtype:
return self
return self.cast(dtype=dtype) # type: ignore[return-value]
Expand Down
9 changes: 5 additions & 4 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cudf.core._internals import binaryop
from cudf.core.buffer import Buffer, acquire_spill_lock
from cudf.core.column.column import ColumnBase
from cudf.utils.dtypes import np_to_pa_dtype
from cudf.utils.dtypes import CUDF_STRING_DTYPE, np_to_pa_dtype
from cudf.utils.utils import (
_all_bools_with_nulls,
_datetime_timedelta_find_and_replace,
Expand Down Expand Up @@ -192,7 +192,7 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str) -> ColumnBase:
"NULL_EQUALS",
"NULL_NOT_EQUALS",
}:
out_dtype = cudf.dtype(np.bool_)
out_dtype = np.dtype(np.bool_)
elif op == "__mod__":
out_dtype = determine_out_dtype(self.dtype, other.dtype)
elif op in {"__truediv__", "__floordiv__"}:
Expand Down Expand Up @@ -374,7 +374,7 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool:
return True
else:
return False
elif to_dtype == cudf.dtype("int64") or to_dtype == cudf.dtype("O"):
elif to_dtype == np.dtype(np.int64) or to_dtype == CUDF_STRING_DTYPE:
# can safely cast to representation, or string
return True
else:
Expand All @@ -383,7 +383,8 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool:
def mean(self, skipna=None) -> pd.Timedelta:
return pd.Timedelta(
cast(
"cudf.core.column.NumericalColumn", self.astype("int64")
"cudf.core.column.NumericalColumn",
self.astype(np.dtype(np.int64)),
).mean(skipna=skipna),
unit=self.time_unit,
).as_unit(self.time_unit)
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6728,7 +6728,7 @@ def _apply_cupy_method_axis_1(self, method, *args, **kwargs):
prepared._data[col]
)
if common_dtype.kind != "M"
else cudf.dtype("float64")
else np.dtype(np.float64)
)
.fillna(np.nan)
)
Expand Down
27 changes: 10 additions & 17 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from cudf.core._compat import PANDAS_GE_210, PANDAS_LT_300
from cudf.core.abc import Serializable
from cudf.utils.docutils import doc_apply
from cudf.utils.dtypes import CUDF_STRING_DTYPE

if PANDAS_GE_210:
PANDAS_NUMPY_DTYPE = pd.core.dtypes.dtypes.NumpyEADtype
Expand All @@ -45,7 +46,7 @@ def dtype(arbitrary):
dtype: the cuDF-supported dtype that best matches `arbitrary`
"""
# first, check if `arbitrary` is one of our extension types:
if isinstance(arbitrary, cudf.core.dtypes._BaseDtype):
if isinstance(arbitrary, (_BaseDtype, pd.DatetimeTZDtype)):
return arbitrary

# next, try interpreting arbitrary as a NumPy dtype that we support:
Expand All @@ -55,7 +56,7 @@ def dtype(arbitrary):
pass
else:
if np_dtype.kind in set("OU"):
return np.dtype("object")
return CUDF_STRING_DTYPE
elif (
np_dtype
not in cudf.utils.dtypes.SUPPORTED_NUMPY_TO_PYLIBCUDF_TYPES
Expand All @@ -78,7 +79,7 @@ def dtype(arbitrary):
"Nullable types not supported in pandas compatibility mode"
)
elif isinstance(pd_dtype, pd.StringDtype):
return np.dtype("object")
return CUDF_STRING_DTYPE
else:
return dtype(pd_dtype.numpy_dtype)
elif isinstance(pd_dtype, PANDAS_NUMPY_DTYPE):
Expand Down Expand Up @@ -368,7 +369,7 @@ def __init__(self, element_type: Any) -> None:
self._typ = pa.list_(element_type._typ)
else:
element_type = cudf.utils.dtypes.cudf_dtype_to_pa_type(
element_type
cudf.dtype(element_type)
)
self._typ = pa.list_(element_type)

Expand Down Expand Up @@ -579,7 +580,7 @@ class StructDtype(_BaseDtype):

def __init__(self, fields):
pa_fields = {
k: cudf.utils.dtypes.cudf_dtype_to_pa_type(v)
k: cudf.utils.dtypes.cudf_dtype_to_pa_type(cudf.dtype(v))
for k, v in fields.items()
}
self._typ = pa.struct(pa_fields)
Expand Down Expand Up @@ -1036,21 +1037,13 @@ def _is_categorical_dtype(obj):
return False
if isinstance(obj, str) and obj == "category":
return True
if isinstance(obj, cudf.core.index.BaseIndex):
return obj._is_categorical()
if isinstance(
obj,
(
cudf.Series,
cudf.core.column.ColumnBase,
pd.Index,
pd.Series,
),
(cudf.core.index.BaseIndex, cudf.core.column.ColumnBase, cudf.Series),
):
try:
return isinstance(cudf.dtype(obj.dtype), cudf.CategoricalDtype)
except TypeError:
return False
return isinstance(obj.dtype, cudf.CategoricalDtype)
if isinstance(obj, (pd.Series, pd.Index)):
return isinstance(obj.dtype, pd.CategoricalDtype)
if hasattr(obj, "type"):
if obj.type is pd.CategoricalDtype.type:
return True
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.mixins import BinaryOperand, Scannable
from cudf.utils import ioutils
from cudf.utils.dtypes import find_common_type
from cudf.utils.dtypes import CUDF_STRING_DTYPE, find_common_type
from cudf.utils.performance_tracking import _performance_tracking
from cudf.utils.utils import _array_ufunc, _warn_no_dask_cudf

Expand Down Expand Up @@ -1008,7 +1008,7 @@ def from_arrow(cls, data: pa.Table) -> Self:
# is specified as 'empty' and np_dtypes as 'object',
# hence handling this special case to type-cast the empty
# float column to str column.
result[name] = result[name].astype(cudf.dtype("str"))
result[name] = result[name].astype(CUDF_STRING_DTYPE)
elif name in data.column_names and isinstance(
data[name].type,
(
Expand Down
5 changes: 3 additions & 2 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from cudf.core.single_column_frame import SingleColumnFrame
from cudf.utils.docutils import copy_docstring
from cudf.utils.dtypes import (
CUDF_STRING_DTYPE,
SIZE_TYPE_DTYPE,
_maybe_convert_to_default_type,
cudf_dtype_from_pa_type,
Expand Down Expand Up @@ -1728,12 +1729,12 @@ def append(self, other):
if is_mixed_with_object_dtype(this, other):
got_dtype = (
other.dtype
if this.dtype == cudf.dtype("object")
if this.dtype == CUDF_STRING_DTYPE
else this.dtype
)
raise TypeError(
f"cudf does not support appending an Index of "
f"dtype `{cudf.dtype('object')}` with an Index "
f"dtype `{CUDF_STRING_DTYPE}` with an Index "
f"of dtype `{got_dtype}`, please type-cast "
f"either one of them to same dtypes."
)
Expand Down
Loading

0 comments on commit a7e0257

Please sign in to comment.