Skip to content

POC: PDEP16 default to masked nullable dtypes #61716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2786,17 +2786,21 @@ def maybe_convert_objects(ndarray[object] objects,
seen.object_ = True

elif seen.str_:
if convert_to_nullable_dtype and is_string_array(objects, skipna=True):
from pandas.core.arrays.string_ import StringDtype
if is_string_array(objects, skipna=True):
if convert_to_nullable_dtype:
from pandas.core.arrays.string_ import StringDtype

dtype = StringDtype()
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)
if using_string_dtype():
dtype = StringDtype(na_value=np.nan)
else:
dtype = StringDtype()
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)

elif using_string_dtype() and is_string_array(objects, skipna=True):
from pandas.core.arrays.string_ import StringDtype
elif using_string_dtype():
from pandas.core.arrays.string_ import StringDtype

dtype = StringDtype(na_value=np.nan)
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)
dtype = StringDtype(na_value=np.nan)
return dtype.construct_array_type()._from_sequence(objects, dtype=dtype)

seen.object_ = True
elif seen.interval_:
Expand Down
5 changes: 5 additions & 0 deletions pandas/_testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@
NpDtype,
)

# Alias so we can update old `assert obj.dtype == np_dtype` checks to PDEP16
# behavior.
to_dtype = pd.core.dtypes.common.pandas_dtype

UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"]
UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]
Expand Down Expand Up @@ -304,6 +307,8 @@ def box_expected(expected, box_cls, transpose: bool = True):
expected = pd.concat([expected] * 2, ignore_index=True)
elif box_cls is np.ndarray or box_cls is np.array:
expected = np.array(expected)
if expected.dtype.kind in "iufb" and pd.get_option("mode.pdep16_data_types"):
expected = pd.array(expected, copy=False)
elif box_cls is to_array:
expected = to_array(expected)
else:
Expand Down
18 changes: 12 additions & 6 deletions pandas/_testing/asserters.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,13 +323,19 @@ def _check_types(left, right, obj: str = "Index") -> None:
elif check_exact and check_categorical:
if not left.equals(right):
mismatch = left._values != right._values
if isinstance(left, RangeIndex) and not mismatch.any():
# TODO: probably need to fix RangeIndex.equals?
pass
elif isinstance(right, RangeIndex) and not mismatch.any():
# TODO: probably need to fix some other equals method?
pass
else:
if not isinstance(mismatch, np.ndarray):
mismatch = cast("ExtensionArray", mismatch).fillna(True)

if not isinstance(mismatch, np.ndarray):
mismatch = cast("ExtensionArray", mismatch).fillna(True)

diff = np.sum(mismatch.astype(int)) * 100.0 / len(left)
msg = f"{obj} values are different ({np.round(diff, 5)} %)"
raise_assert_detail(obj, msg, left, right)
diff = np.sum(mismatch.astype(int)) * 100.0 / len(left)
msg = f"{obj} values are different ({np.round(diff, 5)} %)"
raise_assert_detail(obj, msg, left, right)
else:
# if we have "equiv", this becomes True
exact_bool = bool(exact)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
# pass those through to the underlying ndarray
return self._ndarray.view(dtype)

dtype = pandas_dtype(dtype)
dtype = pandas_dtype(dtype, allow_numpy_dtypes=True)
arr = self._ndarray

if isinstance(dtype, PeriodDtype):
Expand Down
13 changes: 12 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

import numpy as np

from pandas._config import get_option

from pandas._libs import (
algos as libalgos,
lib,
Expand Down Expand Up @@ -2420,7 +2422,12 @@ def _where(self, mask: npt.NDArray[np.bool_], value) -> Self:
result = self.copy()

if is_list_like(value):
val = value[~mask]
if np.ndim(value) == 1 and len(value) == 1:
# test_where.test_broadcast if we change to use nullable...
# maybe this should be handled at a higher level?
val = value[0]
else:
val = value[~mask]
else:
val = value

Expand Down Expand Up @@ -2655,6 +2662,10 @@ def _groupby_op(
if op.how in op.cast_blocklist:
# i.e. how in ["rank"], since other cast_blocklist methods don't go
# through cython_operation
if get_option("mode.pdep16_data_types"):
from pandas import array as pd_array

return pd_array(res_values)
return res_values

if isinstance(self.dtype, StringDtype):
Expand Down
20 changes: 19 additions & 1 deletion pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,13 @@ def _cmp_method(self, other, op):
try:
other = self._validate_comparison_value(other)
except InvalidComparison:
return invalid_comparison(self, other, op)
res = invalid_comparison(self, other, op)
if get_option("mode.pdep16_data_types"):
res = pd_array(res)
o_mask = isna(other)
mask = self._isnan | o_mask
res[mask] = res.dtype.na_value
return res

dtype = getattr(other, "dtype", None)
if is_object_dtype(dtype):
Expand All @@ -982,12 +988,18 @@ def _cmp_method(self, other, op):
result = ops.comp_method_OBJECT_ARRAY(
op, np.asarray(self.astype(object)), other
)
if get_option("mode.pdep16_data_types"):
result = pd_array(result)
result[self.isna()] = result.dtype.na_value
return result
if other is NaT:
if op is operator.ne:
result = np.ones(self.shape, dtype=bool)
else:
result = np.zeros(self.shape, dtype=bool)
if get_option("mode.pdep16_data_types"):
result = pd_array(result)
result[self.isna()] = result.dtype.na_value
return result

if not isinstance(self.dtype, PeriodDtype):
Expand Down Expand Up @@ -1018,6 +1030,10 @@ def _cmp_method(self, other, op):
nat_result = op is operator.ne
np.putmask(result, mask, nat_result)

if get_option("mode.pdep16_data_types"):
result = pd_array(result)
if mask.any():
result[mask] = result.dtype.na_value
return result

# pow is invalid for all three subclasses; TimedeltaArray will override
Expand Down Expand Up @@ -1702,6 +1718,8 @@ def _groupby_op(
if op.how in op.cast_blocklist:
# i.e. how in ["rank"], since other cast_blocklist methods don't go
# through cython_operation
# if get_option("mode.pdep16_data_types"):
# return pd_array(res_values) # breaks because they don't support 2D
return res_values

# We did a view to M8[ns] above, now we go the other direction
Expand Down
7 changes: 5 additions & 2 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,9 @@ def _str_map(
elif dtype == np.dtype("bool"):
# GH#55736
na_value = bool(na_value)

dtype = pandas_dtype(dtype)
pass_dtype = dtype.numpy_dtype
result = lib.map_infer_mask(
arr,
f,
Expand All @@ -453,7 +456,7 @@ def _str_map(
# error: Argument 1 to "dtype" has incompatible type
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
# "Type[object]"
dtype=np.dtype(cast(type, dtype)),
dtype=np.dtype(cast(type, pass_dtype)),
)

if not na_value_is_na:
Expand Down Expand Up @@ -837,7 +840,7 @@ def astype(self, dtype, copy: bool = True):
arr_ea = self.copy()
mask = self.isna()
arr_ea[mask] = "0"
values = arr_ea.astype(dtype.numpy_dtype)
values = arr_ea.to_numpy(dtype=dtype.numpy_dtype)
return FloatingArray(values, mask, copy=False)
elif isinstance(dtype, ExtensionDtype):
# Skip the NumpyExtensionArray.astype method
Expand Down
18 changes: 15 additions & 3 deletions pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

import numpy as np

from pandas._config import get_option

from pandas._libs import (
lib,
tslibs,
Expand Down Expand Up @@ -59,6 +61,7 @@
from pandas.core.arrays import datetimelike as dtl
from pandas.core.arrays._ranges import generate_regular_range
import pandas.core.common as com
from pandas.core.construction import array as pd_array
from pandas.core.ops.common import unpack_zerodim_and_defer

if TYPE_CHECKING:
Expand Down Expand Up @@ -507,6 +510,10 @@ def __mul__(self, other) -> Self:
# numpy >= 2.1 may not raise a TypeError
# and seems to dispatch to others.__rmul__?
raise TypeError(f"Cannot multiply with {type(other).__name__}")
if isinstance(result, type(self)):
# e.g. if other is IntegerArray
assert result.dtype == self.dtype
return result
return type(self)._simple_new(result, dtype=result.dtype)

__rmul__ = __mul__
Expand All @@ -524,10 +531,13 @@ def _scalar_divlike_op(self, other, op):
# specifically timedelta64-NaT
res = np.empty(self.shape, dtype=np.float64)
res.fill(np.nan)
return res

# otherwise, dispatch to Timedelta implementation
return op(self._ndarray, other)
else:
# otherwise, dispatch to Timedelta implementation
res = op(self._ndarray, other)
if get_option("mode.pdep16_data_types"):
res = pd_array(res)
return res

else:
# caller is responsible for checking lib.is_scalar(other)
Expand Down Expand Up @@ -581,6 +591,8 @@ def _vector_divlike_op(self, other, op) -> np.ndarray | Self:
result = result.astype(np.float64)
np.putmask(result, mask, np.nan)

if get_option("mode.pdep16_data_types"):
result = pd_array(result)
return result

@unpack_zerodim_and_defer("__truediv__")
Expand Down
9 changes: 9 additions & 0 deletions pandas/core/config_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,15 @@ def is_terminal() -> bool:
validator=is_one_of_factory([True, False, "warn"]),
)

# Register the ``mode.pdep16_data_types`` option. Code elsewhere in this change
# reads it through ``get_option("mode.pdep16_data_types")`` to decide whether
# results are wrapped with ``pd.array``/``pd_array``.
# NOTE(review): the second positional argument (True) is presumably the
# option's default value -- confirm against ``register_option``'s signature.
with cf.config_prefix("mode"):
    cf.register_option(
        "pdep16_data_types",
        True,
        "Whether to default to numpy-nullable dtypes for integer, float, "
        "and bool dtypes",
        # Restrict the option to the two listed values.
        validator=is_one_of_factory([True, False]),
    )


# user warnings
chained_assignment = """
Expand Down
14 changes: 11 additions & 3 deletions pandas/core/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import numpy as np
from numpy import ma

from pandas._config import using_string_dtype
from pandas._config import (
get_option,
using_string_dtype,
)

from pandas._libs import lib
from pandas._libs.tslibs import (
Expand Down Expand Up @@ -612,7 +615,9 @@ def sanitize_array(
if dtype is None:
subarr = data
if data.dtype == object and infer_object:
subarr = maybe_infer_to_datetimelike(data)
subarr = maybe_infer_to_datetimelike(
data, convert_to_nullable_dtype=get_option("mode.pdep16_data_types")
)
elif data.dtype.kind == "U" and using_string_dtype():
from pandas.core.arrays.string_ import StringDtype

Expand Down Expand Up @@ -659,7 +664,10 @@ def sanitize_array(
subarr = maybe_convert_platform(data)
if subarr.dtype == object:
subarr = cast(np.ndarray, subarr)
subarr = maybe_infer_to_datetimelike(subarr)
subarr = maybe_infer_to_datetimelike(
subarr,
convert_to_nullable_dtype=get_option("mode.pdep16_data_types"),
)

subarr = _sanitize_ndim(subarr, data, dtype, index, allow_2d=allow_2d)

Expand Down
9 changes: 7 additions & 2 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@

import numpy as np

from pandas._config import using_string_dtype
from pandas._config import (
get_option,
using_string_dtype,
)

from pandas._libs import (
Interval,
Expand Down Expand Up @@ -135,7 +138,9 @@ def maybe_convert_platform(

if arr.dtype == _dtype_obj:
arr = cast(np.ndarray, arr)
arr = lib.maybe_convert_objects(arr)
arr = lib.maybe_convert_objects(
arr, convert_to_nullable_dtype=get_option("mode.pdep16_data_types")
)

return arr

Expand Down
Loading
Loading