From 5175cb03f8c16672441b350415999e65da02c04c Mon Sep 17 00:00:00 2001 From: Mark Zhao Date: Wed, 7 Dec 2022 04:46:39 +0000 Subject: [PATCH 1/3] Allow `from_arrow` with list column --- torcharrow/_interop.py | 2 ++ torcharrow/velox_rt/list_column_cpu.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/torcharrow/_interop.py b/torcharrow/_interop.py index 81e86f6e6..53cfdb100 100644 --- a/torcharrow/_interop.py +++ b/torcharrow/_interop.py @@ -202,6 +202,8 @@ def _arrowtype_to_dtype(t: pa.DataType, nullable: bool) -> dt.DType: [dt.Field(f.name, _arrowtype_to_dtype(f.type, f.nullable)) for f in t], nullable, ) + if pa.types.is_list(t): + return dt.List(_arrowtype_to_dtype(t.value_type, nullable), nullable=nullable) raise NotImplementedError(f"Unsupported Arrow type: {str(t)}") diff --git a/torcharrow/velox_rt/list_column_cpu.py b/torcharrow/velox_rt/list_column_cpu.py index 3f7555614..2d4107472 100644 --- a/torcharrow/velox_rt/list_column_cpu.py +++ b/torcharrow/velox_rt/list_column_cpu.py @@ -81,6 +81,15 @@ def _from_pysequence(device: str, data: List[List], dtype: dt.List): for i in data: col._append(i) return col._finalize() + + @staticmethod + def _from_arrow(device: str, array, dtype: dt.List): + import pyarrow as pa + + assert isinstance(array, pa.Array) + + pydata = array.to_pylist() + return ListColumnCpu._from_pysequence(device, pydata, dtype) def _append_null(self): if self._finalized: @@ -293,3 +302,6 @@ def vmap(self, fun: Callable[[Column], Column]): Dispatcher.register( (dt.List.typecode + "_from_pysequence", "cpu"), ListColumnCpu._from_pysequence ) +Dispatcher.register( + (dt.List.typecode + "_from_arrow", "cpu"), ListColumnCpu._from_arrow +) From b1230246c6b9dfa0cfe50f5eb60a8595492206c7 Mon Sep 17 00:00:00 2001 From: Mark Zhao Date: Wed, 7 Dec 2022 05:59:49 +0000 Subject: [PATCH 2/3] add test --- torcharrow/interop_arrow.py | 2 ++ torcharrow/test/test_arrow_interop.py | 33 +++++++++++++++-------- torcharrow/test/test_arrow_interop_cpu.py | 4 +-- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/torcharrow/interop_arrow.py b/torcharrow/interop_arrow.py index c85c9fde1..7eff5678f 100644 --- a/torcharrow/interop_arrow.py +++ b/torcharrow/interop_arrow.py @@ -35,6 +35,8 @@ def _from_arrow_array( # increase the amount of places we can use the from_arrow result # pyre-fixme[16]: `Array` has no attribute `type`. dtype_from_arrowtype = _arrowtype_to_dtype(array.type, array.null_count > 0) + # TODO: following fails with nullable lists. + # handle comparison for nullable lists if dtype and ( dt.get_underlying_dtype(dtype) != dt.get_underlying_dtype(dtype_from_arrowtype) ): diff --git a/torcharrow/test/test_arrow_interop.py b/torcharrow/test/test_arrow_interop.py index 965fba6c3..94b24a034 100644 --- a/torcharrow/test/test_arrow_interop.py +++ b/torcharrow/test/test_arrow_interop.py @@ -26,6 +26,7 @@ class TestArrowInterop(unittest.TestCase): (pa.float64(), dt.Float64(True)), (pa.string(), dt.String(True)), (pa.large_string(), dt.String(True)), + (pa.list_(pa.int64()), dt.List(dt.Int64(True), True)) ) unsupported_types: Tuple[pa.DataType, ...] = ( @@ -44,7 +45,6 @@ class TestArrowInterop(unittest.TestCase): pa.binary(), pa.large_binary(), pa.decimal128(38), - pa.list_(pa.int64()), pa.large_list(pa.int64()), pa.map_(pa.int64(), pa.int64()), pa.dictionary(pa.int64(), pa.int64()), @@ -420,19 +420,30 @@ def base_test_table_memory_reclaimed(self): del df self.assertEqual(pa.total_allocated_bytes(), initial_memory) - def base_test_table_unsupported_types(self): + def base_test_from_arrow_table_with_list(self): pt = pa.table( { - "f1": pa.array([1, 2, 3], type=pa.int64()), - "f2": pa.array(["foo", "bar", None], type=pa.string()), - "f3": pa.array([[1, 2], [3, 4, 5], [6]], type=pa.list_(pa.int8())), - } - ) - with self.assertRaises(RuntimeError) as ex: - df = ta.from_arrow(pt, device=self.device) - self.assertTrue( - f"Unsupported Arrow type: {str(pt.field(2).type)}" in str(ex.exception) + "f1":[1, 2, 3], + "f2": ["foo", "bar", None], + "f3": [[1, 2], [3, 4, 5], [6]], + }, + schema=pa.schema( + [ + pa.field("f1", pa.int64(), nullable=True), + pa.field("f2", pa.string(), nullable=True), + pa.field("f3", pa.list_(pa.int8()), nullable=False), + ] + ) ) + df = ta.from_arrow(pt, device=self.device) + for (i, ta_field) in enumerate(df.dtype.fields): + pa_field = pt.schema.field(i) + self.assertEqual(ta_field.name, pa_field.name) + self.assertEqual( + ta_field.dtype, _arrowtype_to_dtype(pa_field.type, pa_field.nullable) + ) + self.assertEqual(list(df[ta_field.name]), pt[i].to_pylist()) + def base_test_nullability(self): pydata = [1, 2, 3] diff --git a/torcharrow/test/test_arrow_interop_cpu.py b/torcharrow/test/test_arrow_interop_cpu.py index e40a8314b..dc036b2ac 100644 --- a/torcharrow/test/test_arrow_interop_cpu.py +++ b/torcharrow/test/test_arrow_interop_cpu.py @@ -76,8 +76,8 @@ def test_table_ownership_transferred(self): def test_table_memory_reclaimed(self): return self.base_test_table_memory_reclaimed() - def test_table_unsupported_types(self): - return self.base_test_table_unsupported_types() + def test_from_arrow_table_with_list(self): + return self.base_test_from_arrow_table_with_list() def test_nullability(self): return self.base_test_nullability() From 5382fc0d87338bd5975463f163e22bb9093000c8 Mon Sep 17 00:00:00 2001 From: Mark Zhao Date: Wed, 7 Dec 2022 22:13:25 +0000 Subject: [PATCH 3/3] Handle nullable lists in from_arrow --- torcharrow/dtypes.py | 2 ++ torcharrow/interop_arrow.py | 2 -- torcharrow/test/test_arrow_interop.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torcharrow/dtypes.py b/torcharrow/dtypes.py index e49abcbb0..272d808e1 100644 --- a/torcharrow/dtypes.py +++ b/torcharrow/dtypes.py @@ -580,6 +580,8 @@ def cast_as(dtype): def get_underlying_dtype(dtype: DType) -> DType: + if is_list(dtype): + return replace(dtype, nullable=False, item_dtype=replace(dtype.item_dtype, nullable=False)) return replace(dtype, nullable=False) diff --git a/torcharrow/interop_arrow.py b/torcharrow/interop_arrow.py index 7eff5678f..c85c9fde1 100644 --- a/torcharrow/interop_arrow.py +++ b/torcharrow/interop_arrow.py @@ -35,8 +35,6 @@ def _from_arrow_array( # increase the amount of places we can use the from_arrow result # pyre-fixme[16]: `Array` has no attribute `type`. dtype_from_arrowtype = _arrowtype_to_dtype(array.type, array.null_count > 0) - # TODO: following fails with nullable lists. - # handle comparison for nullable lists if dtype and ( dt.get_underlying_dtype(dtype) != dt.get_underlying_dtype(dtype_from_arrowtype) ): diff --git a/torcharrow/test/test_arrow_interop.py b/torcharrow/test/test_arrow_interop.py index 94b24a034..cbed2134f 100644 --- a/torcharrow/test/test_arrow_interop.py +++ b/torcharrow/test/test_arrow_interop.py @@ -426,12 +426,14 @@ def base_test_from_arrow_table_with_list(self): "f1":[1, 2, 3], "f2": ["foo", "bar", None], "f3": [[1, 2], [3, 4, 5], [6]], + "f4": [[1, 2], [3, 4, 5], [6]], }, schema=pa.schema( [ pa.field("f1", pa.int64(), nullable=True), pa.field("f2", pa.string(), nullable=True), - pa.field("f3", pa.list_(pa.int8()), nullable=False), + pa.field("f3", pa.list_(pa.int8()), nullable=True), + pa.field("f4", pa.list_(pa.int8()), nullable=False), ] ) )