Remove support for pyarrow extension types on the Ray runner.
clarkzinzow committed May 18, 2023
1 parent 2728bcf commit b46c42a
Showing 8 changed files with 78 additions and 1 deletion.
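With this change, DataType.from_arrow_type raises as soon as it encounters a pyarrow extension type while the Ray runner is active, instead of failing later when Ray tries to pickle the type. A minimal sketch of the new user-visible behavior, assuming daft.context.set_runner_ray() is how the Ray runner gets selected; the UuidType below is a hypothetical stand-in modeled on the tests' conftest helper:

    import daft
    import pyarrow as pa
    from daft.context import set_runner_ray

    set_runner_ray()  # select the Ray runner before doing any Daft work

    class UuidType(pa.ExtensionType):  # hypothetical stand-in for tests.conftest.UuidType
        def __init__(self):
            super().__init__(pa.binary(), "test.uuid")

        def __arrow_ext_serialize__(self):
            return b""

        @classmethod
        def __arrow_ext_deserialize__(cls, storage_type, serialized):
            return cls()

    storage = pa.array([b"0", b"1", b"2"])
    arr = pa.ExtensionArray.from_storage(UuidType(), storage)
    t = pa.Table.from_pydict({"obj": arr})

    daft.from_arrow(t)  # ValueError: pyarrow extension types are not supported for the Ray runner: ...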
7 changes: 7 additions & 0 deletions daft/datatype.py
@@ -4,6 +4,7 @@

import pyarrow as pa

from daft.context import get_context
from daft.daft import PyDataType

_RAY_DATA_EXTENSIONS_AVAILABLE = True
@@ -178,6 +179,12 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType:
f"used in non-Python Arrow implementations and Daft uses the Rust Arrow2 implementation: {arrow_type}"
)
elif isinstance(arrow_type, pa.BaseExtensionType):
if get_context().runner_config.name == "ray":
raise ValueError(
f"pyarrow extension types are not supported for the Ray runner: {arrow_type}. If you need support "
"for this, please let us know on this issue: "
"https://github.com/Eventual-Inc/Daft/issues/933"
)
name = arrow_type.extension_name
try:
metadata = arrow_type.__arrow_ext_serialize__().decode()
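Until the linked issue is resolved, one possible workaround on the Ray runner is to strip the extension wrapper and hand Daft the underlying storage array instead. A sketch, reusing the hypothetical UuidType and table from the example above:

    # An extension array's .storage attribute is the plain Arrow array underneath.
    plain = pa.Table.from_pydict({"obj": arr.storage})
    df = daft.from_arrow(plain)  # accepted on any runner, as plain binary data

    # If needed, reattach the extension type after collecting back to Arrow.
    # Daft may hand the column back as large_binary, so cast to the storage type first.
    col = df.to_arrow()["obj"].combine_chunks().cast(pa.binary())
    restored = pa.ExtensionArray.from_storage(UuidType(), col)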
25 changes: 25 additions & 0 deletions tests/dataframe/test_creation.py
@@ -15,6 +15,7 @@

import daft
from daft.api_annotations import APITypeError
from daft.context import get_context
from daft.dataframe import DataFrame
from daft.datatype import DataType
from tests.conftest import UuidType
@@ -173,6 +174,10 @@ def test_create_dataframe_arrow_tensor_ray(valid_data: list[dict[str, float]]) -> None:
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
)
@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="Pickling canonical tensor extension type is not supported by pyarrow",
)
def test_create_dataframe_arrow_tensor_canonical(valid_data: list[dict[str, float]]) -> None:
pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()}
dtype = pa.fixed_shape_tensor(pa.int64(), (2, 2))
@@ -191,6 +196,10 @@ def test_create_dataframe_arrow_tensor_canonical(valid_data: list[dict[str, float]]) -> None:
assert df.to_arrow() == expected


@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="pyarrow extension types aren't supported on Ray clusters.",
)
def test_create_dataframe_arrow_extension_type(valid_data: list[dict[str, float]], uuid_ext_type: UuidType) -> None:
pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()}
storage = pa.array([f"{i}".encode() for i in range(len(valid_data))])
@@ -207,6 +216,22 @@ def test_create_dataframe_arrow_extension_type(valid_data: list[dict[str, float]], uuid_ext_type: UuidType) -> None:
assert df.to_arrow() == expected


# TODO(Clark): Remove this test once pyarrow extension types are supported for Ray clusters.
@pytest.mark.skipif(
get_context().runner_config.name != "ray",
reason="This test requires the Ray runner.",
)
def test_create_dataframe_arrow_extension_type_fails_for_ray(
valid_data: list[dict[str, float]], uuid_ext_type: UuidType
) -> None:
pydict = {k: [item[k] for item in valid_data] for k in valid_data[0].keys()}
storage = pa.array([f"{i}".encode() for i in range(len(valid_data))])
pydict["obj"] = pa.ExtensionArray.from_storage(uuid_ext_type, storage)
t = pa.Table.from_pydict(pydict)
with pytest.raises(ValueError):
daft.from_arrow(t).to_arrow()


class PyExtType(pa.PyExtensionType):
def __init__(self):
pa.PyExtensionType.__init__(self, pa.binary())
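Note that the get_context() calls inside these skipif markers run at collection time, when pytest imports the module, so the runner has to be chosen before the test session starts (for example by exporting DAFT_RUNNER=ray, assuming that environment variable is how the runner is selected in CI).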
9 changes: 9 additions & 0 deletions tests/series/test_concat.py
@@ -8,6 +8,7 @@
from ray.data.extensions import ArrowTensorArray

from daft import DataType, Series
from daft.context import get_context
from tests.conftest import *
from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES

@@ -115,6 +116,10 @@ def test_series_concat_tensor_array_ray(chunks) -> None:
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
)
@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="Pickling canonical tensor extension type is not supported by pyarrow",
)
@pytest.mark.parametrize("chunks", [1, 2, 3, 10])
def test_series_concat_tensor_array_canonical(chunks) -> None:
element_shape = (2, 2)
@@ -141,6 +146,10 @@ def test_series_concat_tensor_array_canonical(chunks) -> None:
np.testing.assert_equal(concated_arrow.to_numpy_ndarray(), expected)


@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="pyarrow extension types aren't supported on Ray clusters.",
)
@pytest.mark.parametrize("chunks", [1, 2, 3, 10])
def test_series_concat_extension_type(uuid_ext_type, chunks) -> None:
chunk_size = 3
5 changes: 5 additions & 0 deletions tests/series/test_filter.py
@@ -4,6 +4,7 @@
import pyarrow as pa
import pytest

from daft.context import get_context
from daft.datatype import DataType
from daft.series import Series
from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES
@@ -132,6 +133,10 @@ def test_series_filter_on_extension_array(uuid_ext_type) -> None:
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
)
@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="Pickling canonical tensor extension type is not supported by pyarrow",
)
def test_series_filter_on_canonical_tensor_extension_array() -> None:
arr = np.arange(20).reshape((5, 2, 2))
data = pa.FixedShapeTensorArray.from_numpy_ndarray(arr)
9 changes: 9 additions & 0 deletions tests/series/test_if_else.py
@@ -5,6 +5,7 @@
import pytest

from daft import Series
from daft.context import get_context
from daft.datatype import DataType

ARROW_VERSION = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric())
@@ -296,6 +297,10 @@ def test_series_if_else_struct(if_true, if_false, expected) -> None:
assert result.to_pylist() == expected


@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="pyarrow extension types aren't supported on Ray clusters.",
)
@pytest.mark.parametrize(
["if_true_storage", "if_false_storage", "expected_storage"],
[
@@ -346,6 +351,10 @@ def test_series_if_else_extension_type(uuid_ext_type, if_true_storage, if_false_storage, expected_storage) -> None:
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
)
@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="Pickling canonical tensor extension type is not supported by pyarrow",
)
@pytest.mark.parametrize(
["if_true", "if_false", "expected"],
[
9 changes: 9 additions & 0 deletions tests/series/test_size_bytes.py
@@ -7,6 +7,7 @@
import pyarrow as pa
import pytest

from daft.context import get_context
from daft.datatype import DataType
from daft.series import Series
from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES
@@ -181,6 +182,10 @@ def test_series_struct_size_bytes(size, with_nulls) -> None:
assert s.size_bytes() == get_total_buffer_size(data) + conversion_to_large_string_bytes


@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="pyarrow extension types aren't supported on Ray clusters.",
)
@pytest.mark.parametrize("size", [1, 2, 8, 9, 16])
@pytest.mark.parametrize("with_nulls", [True, False])
def test_series_extension_type_size_bytes(uuid_ext_type, size, with_nulls) -> None:
@@ -207,6 +212,10 @@ def test_series_extension_type_size_bytes(uuid_ext_type, size, with_nulls) -> None:
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
)
@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="Pickling canonical tensor extension type is not supported by pyarrow",
)
@pytest.mark.parametrize("dtype, size", itertools.product(ARROW_INT_TYPES + ARROW_FLOAT_TYPES, [0, 1, 2, 8, 9, 16]))
@pytest.mark.parametrize("with_nulls", [True, False])
def test_series_canonical_tensor_extension_type_size_bytes(dtype, size, with_nulls) -> None:
9 changes: 9 additions & 0 deletions tests/series/test_take.py
@@ -4,6 +4,7 @@
import pyarrow as pa
import pytest

from daft.context import get_context
from daft.datatype import DataType
from daft.series import Series
from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES
@@ -118,6 +119,10 @@ def test_series_struct_take() -> None:
assert result.to_pylist() == expected


@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="pyarrow extension types aren't supported on Ray clusters.",
)
def test_series_extension_type_take(uuid_ext_type) -> None:
pydata = [f"{i}".encode() for i in range(6)]
pydata[2] = None
@@ -143,6 +148,10 @@ def test_series_extension_type_take(uuid_ext_type) -> None:
ARROW_VERSION < (12, 0, 0),
reason=f"Arrow version {ARROW_VERSION} doesn't support the canonical tensor extension type.",
)
@pytest.mark.skipif(
get_context().runner_config.name == "ray",
reason="Pickling canonical tensor extension type is not supported by pyarrow",
)
def test_series_canonical_tensor_extension_type_take() -> None:
pydata = np.arange(24).reshape((6, 4)).tolist()
pydata[2] = None
6 changes: 5 additions & 1 deletion tests/table/test_from_py.py
@@ -11,6 +11,7 @@
from ray.data.extensions import ArrowTensorArray, ArrowTensorType

from daft import DataType
from daft.context import get_context
from daft.series import Series
from daft.table import Table

@@ -119,14 +120,17 @@
"timestamp": pa.timestamp("us"),
}

if ARROW_VERSION >= (12, 0, 0):
if ARROW_VERSION >= (12, 0, 0) and get_context().runner_config.name != "ray":
ARROW_ROUNDTRIP_TYPES["canonical_tensor"] = pa.fixed_shape_tensor(pa.int64(), (2, 2))
ARROW_TYPE_ARRAYS["canonical_tensor"] = pa.FixedShapeTensorArray.from_numpy_ndarray(
np.array(PYTHON_TYPE_ARRAYS["tensor"])
)


def _with_uuid_ext_type(uuid_ext_type) -> tuple[dict, dict]:
if get_context().runner_config.name == "ray":
# pyarrow extension types aren't supported in Ray clusters yet.
return ARROW_ROUNDTRIP_TYPES, ARROW_TYPE_ARRAYS
arrow_roundtrip_types = ARROW_ROUNDTRIP_TYPES.copy()
arrow_type_arrays = ARROW_TYPE_ARRAYS.copy()
arrow_roundtrip_types["ext_type"] = uuid_ext_type
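Note the inverted guard in this last file: the canonical tensor entries join the shared roundtrip tables only when the runner is not Ray, mirroring the skipif markers in the test files above, and _with_uuid_ext_type likewise returns the tables unmodified on Ray so that no extension-typed columns reach the Ray-backed roundtrip tests.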
