[SPARK-53029][PYTHON] Support return type coercion for Arrow Python UDTFs #52140
In python/pyspark/sql/pandas/serializers.py:
@@ -204,6 +204,25 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
     def __init__(self, table_arg_offsets=None):
         super().__init__()
         self.table_arg_offsets = table_arg_offsets if table_arg_offsets else []
+        self._arrow_cast = True
+
+    def _create_array(self, arr, arrow_type):
+        import pyarrow as pa
+
+        assert isinstance(arr, pa.Array)
+        assert isinstance(arrow_type, pa.DataType)
+
+        if arr.type == arrow_type:
+            return arr
+        else:
+            try:
+                # When safe=True, the cast fails on overflow or other unsafe conversions.
+                return arr.cast(target_type=arrow_type, safe=True)
+            except (pa.ArrowInvalid, pa.ArrowTypeError):
+                raise PySparkTypeError(
+                    "Arrow UDTFs require the return type to match the expected Arrow type. "
+                    f"Expected: {arrow_type}, but got: {arr.type}."
+                )
 
     def load_stream(self, stream):
         """
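For context, a minimal standalone sketch (not part of the diff) of the PyArrow cast behavior that _create_array relies on: with safe=True, lossy conversions such as integer overflow or non-numeric strings raise instead of silently truncating.

    import pyarrow as pa

    # In-range values cast cleanly from int64 to int32.
    arr = pa.array([1, 2, 3], type=pa.int64())
    print(arr.cast(pa.int32(), safe=True))  # int32 array [1, 2, 3]

    # Out-of-range values are rejected rather than wrapped.
    big = pa.array([2147483648], type=pa.int64())  # int32 max + 1
    try:
        big.cast(pa.int32(), safe=True)
    except pa.ArrowInvalid as e:
        print("overflow rejected:", e)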
@@ -227,6 +246,58 @@ def load_stream(self, stream):
             result_batches.append(batch.column(i))
             yield result_batches
 
+    def dump_stream(self, iterator, stream):
+        """
+        Override to handle type coercion for ArrowUDTF outputs.
+        ArrowUDTF returns an iterator of (pa.RecordBatch, arrow_return_type) tuples.
+        """
+        import pyarrow as pa
+        from pyspark.serializers import write_int, SpecialLengths
+
+        def wrap_and_init_stream():
+            should_write_start_length = True
+            for packed in iterator:
+                batch, arrow_return_type = packed
+                assert isinstance(
+                    arrow_return_type, pa.StructType
+                ), f"Expected pa.StructType, got {type(arrow_return_type)}"
+
+                # Handle the empty-struct case specially.
+                if batch.num_columns == 0:
+                    # When the batch has no columns, it should still create
+                    # an empty batch with the number of rows set.
+                    struct = pa.array([{}] * batch.num_rows)
+                    coerced_batch = pa.RecordBatch.from_arrays([struct], ["_0"])

Comment on lines +266 to +270:
Reviewer: I don't think we need to handle this case? cc @ueshin
Author: This is to ensure the test case "test_arrow_udtf_with_empty_column_result" works. Please refer to the #52140 (comment) thread for the unexpected behavior change.
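As an illustration of the branch under discussion (a standalone sketch, not from the PR), this is what the empty-struct path produces for a three-row, zero-column batch; the row count survives even though there are no fields:

    import pyarrow as pa

    num_rows = 3
    struct = pa.array([{}] * num_rows)  # StructArray of type struct<> with 3 rows
    batch = pa.RecordBatch.from_arrays([struct], ["_0"])
    print(batch.num_rows, batch.schema)  # 3  _0: struct<>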
+                else:
+                    # Apply type coercion to each column if needed.
+                    coerced_arrays = []
+                    for i, field in enumerate(arrow_return_type):
+                        if i < batch.num_columns:
+                            original_array = batch.column(i)
+                            coerced_array = self._create_array(original_array, field.type)
+                            coerced_arrays.append(coerced_array)
+                        else:
+                            raise PySparkRuntimeError(
+                                errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
+                                messageParameters={
+                                    "expected": str(len(arrow_return_type)),
+                                    "actual": str(batch.num_columns),
+                                    "func": "ArrowUDTF",
+                                },
+                            )

Comment on lines +280 to +287:
Reviewer: Ditto. I think we already check whether the returned columns mismatch the expected return schema in worker.py. Would you mind double-checking?
Author: Do you mean [...]?
Reviewer: The column length is checked before it. Please take a look at "if result.num_columns != return_type_size: ..." in worker.py.
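For reference, the pre-existing worker.py check this thread points at has roughly the following shape (a simplified sketch; everything beyond the quoted condition and the error class already shown in this diff is an assumption):

    # Simplified sketch of the existing column-count validation in worker.py:
    if result.num_columns != return_type_size:
        raise PySparkRuntimeError(
            errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
            messageParameters={
                "expected": str(return_type_size),  # declared number of columns
                "actual": str(result.num_columns),  # columns actually returned
                "func": "eval",                     # assumed label
            },
        )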
+
+                    struct = pa.StructArray.from_arrays(coerced_arrays, fields=arrow_return_type)
+                    coerced_batch = pa.RecordBatch.from_arrays([struct], ["_0"])
+
+                # Write the first record batch with initialization.
+                if should_write_start_length:
+                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
+                    should_write_start_length = False
+
+                yield coerced_batch

Comment on lines +293 to +297:
Reviewer: These are done in the [...]
Reviewer: I'm just wondering whether we can use [...]

+        return ArrowStreamSerializer.dump_stream(self, wrap_and_init_stream(), stream)
+
+
 class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
     """
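Taken together, dump_stream coerces each column and then packs the whole batch into a single struct column named "_0". A standalone sketch of that wrapping under assumed inputs (not from the diff):

    import pyarrow as pa

    return_type = pa.struct([("id", pa.int32()), ("price", pa.float32())])
    ids = pa.array([1, 2], type=pa.int64())           # wider than the declared int32
    prices = pa.array([1.5, 2.5], type=pa.float64())  # wider than the declared float32
    batch = pa.RecordBatch.from_arrays([ids, prices], ["id", "price"])

    # Coerce each column to the declared field type, then wrap in one struct column.
    coerced = [batch.column(i).cast(f.type, safe=True) for i, f in enumerate(return_type)]
    struct = pa.StructArray.from_arrays(coerced, fields=list(return_type))
    out = pa.RecordBatch.from_arrays([struct], ["_0"])
    print(out.schema)  # _0: struct<id: int32, price: float>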
In python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
@@ -178,18 +178,23 @@ def eval(self) -> Iterator["pa.Table"]:
             result_df.collect()
 
     def test_arrow_udtf_error_mismatched_schema(self):
-        @arrow_udtf(returnType="x int, y string")
+        @arrow_udtf(returnType="x int, y int")
         class MismatchedSchemaUDTF:
             def eval(self) -> Iterator["pa.Table"]:
                 result_table = pa.table(
                     {
-                        "wrong_col": pa.array([1], type=pa.int32()),
-                        "another_wrong_col": pa.array([2.5], type=pa.float64()),
+                        "col_with_arrow_cast": pa.array([1], type=pa.int32()),

Comment:
Reviewer: What if we have input to be [...]?
Author: Yes, it will. We have a test case "test_return_type_coercion_overflow".

+                        "wrong_col": pa.array(["wrong_col"], type=pa.string()),
                     }
                 )
                 yield result_table
 
-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
+        with self.assertRaisesRegex(
+            PythonException,
+            "Arrow UDTFs require the return type to match the expected Arrow type. "
+            "Expected: int32, but got: string.",
+        ):
             result_df = MismatchedSchemaUDTF()
             result_df.collect()
@@ -330,9 +335,10 @@ def eval(self) -> Iterator["pa.Table"]:
             )
             yield result_table
 
-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
-            result_df = LongToIntUDTF()
-            result_df.collect()
+        # Should succeed with automatic coercion
+        result_df = LongToIntUDTF()
+        expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        assertDataFrameEqual(result_df, expected_df)
 
     def test_arrow_udtf_type_coercion_string_to_int(self):
         @arrow_udtf(returnType="id int")
@@ -341,15 +347,87 @@ def eval(self) -> Iterator["pa.Table"]:
                 # Return string values that cannot be coerced to int
                 result_table = pa.table(
                     {
-                        "id": pa.array(["abc", "def", "xyz"], type=pa.string()),
+                        "id": pa.array(["1", "2", "xyz"], type=pa.string()),
                     }
                 )
                 yield result_table
 
-        with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
+        # Should fail with an Arrow cast exception since "xyz" cannot be cast to int
+        with self.assertRaisesRegex(
+            PythonException,
+            "Arrow UDTFs require the return type to match the expected Arrow type. "
+            "Expected: int32, but got: string.",
+        ):
             result_df = StringToIntUDTF()
             result_df.collect()
 
+    def test_arrow_udtf_type_coercion_int64_to_int32_safe(self):
+        @arrow_udtf(returnType="id int")
+        class Int64ToInt32UDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "id": pa.array([1, 2, 3], type=pa.int64()),  # long values
+                    }
+                )
+                yield result_table
+
+        result_df = Int64ToInt32UDTF()
+        expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_return_type_coercion_success(self):
+        @arrow_udtf(returnType="value int")
+        class CoercionSuccessUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "value": pa.array([10, 20, 30], type=pa.int64()),  # long -> int coercion
+                    }
+                )
+                yield result_table
+
+        result_df = CoercionSuccessUDTF()
+        expected_df = self.spark.createDataFrame([(10,), (20,), (30,)], "value int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_return_type_coercion_overflow(self):
+        @arrow_udtf(returnType="value int")
+        class CoercionOverflowUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                # Return values that will cause overflow when casting long to int
+                result_table = pa.table(
+                    {
+                        "value": pa.array([2147483647 + 1], type=pa.int64()),  # int32 max + 1
+                    }
+                )
+                yield result_table
+
+        # Should fail with a PyArrow overflow exception
+        with self.assertRaises(Exception):
+            result_df = CoercionOverflowUDTF()
+            result_df.collect()
+
+    def test_return_type_coercion_multiple_columns(self):
+        @arrow_udtf(returnType="id int, price float")
+        class MultipleColumnCoercionUDTF:
+            def eval(self) -> Iterator["pa.Table"]:
+                result_table = pa.table(
+                    {
+                        "id": pa.array([1, 2, 3], type=pa.int64()),  # long -> int coercion
+                        "price": pa.array(
+                            [10.5, 20.7, 30.9], type=pa.float64()
+                        ),  # double -> float coercion
+                    }
+                )
+                yield result_table
+
+        result_df = MultipleColumnCoercionUDTF()
+        expected_df = self.spark.createDataFrame(
+            [(1, 10.5), (2, 20.7), (3, 30.9)], "id int, price float"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
     def test_arrow_udtf_with_empty_column_result(self):
         @arrow_udtf(returnType=StructType())
         class EmptyResultUDTF:

@@ -612,6 +690,7 @@ class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
     pass
 
+
 
Comment: Ditto. Could you run ./dev/reformat-python to make the linter happy?

 if __name__ == "__main__":
     from pyspark.sql.tests.arrow.test_arrow_udtf import *  # noqa: F401
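The net user-facing effect these tests pin down: an Arrow UDTF may now yield batches whose columns are wider than the declared schema, and the values are narrowed safely on the way out. A sketch of the pattern, modeled on the tests above (the import path for arrow_udtf is an assumption; a SparkSession is assumed to be active):

    import pyarrow as pa
    from typing import Iterator
    from pyspark.sql.functions import arrow_udtf  # assumed import location

    @arrow_udtf(returnType="id int")
    class WideningUDTF:
        def eval(self) -> Iterator["pa.Table"]:
            # int64 output against a declared int32 column: coerced, not rejected.
            yield pa.table({"id": pa.array([1, 2, 3], type=pa.int64())})

    WideningUDTF().show()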
In python/pyspark/sql/worker.py:
@@ -1344,7 +1344,9 @@ def read_udtf(pickleSer, infile, eval_type):
         num_table_arg_offsets = read_int(infile)
         table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
         # Use PyArrow-native serializer for Arrow UDTFs with potential UDT support
-        ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
+        ser = ArrowStreamArrowUDTFSerializer(
+            table_arg_offsets=table_arg_offsets
+        )

Comment on lines +1347 to +1349:
Reviewer: Looks like an unnecessary change?

     else:
         # Each row is a group so do not batch but send one by one.
         ser = BatchedSerializer(CPickleSerializer(), 1)
@@ -1970,14 +1972,8 @@ def verify_result(result):
                 },
             )
 
-            # Verify the type and the schema of the result.
-            verify_arrow_result(
-                pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
-                assign_cols_by_name=False,
-                expected_cols_and_types=[
-                    (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
-                ],
-            )
+            # We verify the type of the result in the serializer,
+            # as we now support type coercion in return values.
         return result
 
     # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
In sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
@@ -4003,6 +4003,7 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+

Comment: nit: revert this?

   val PYTHON_UDF_LEGACY_PANDAS_CONVERSION_ENABLED =
     buildConf("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
       .internal()