71 changes: 71 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
@@ -204,6 +204,25 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
def __init__(self, table_arg_offsets=None):
super().__init__()
self.table_arg_offsets = table_arg_offsets if table_arg_offsets else []
self._arrow_cast = True

def _create_array(self, arr, arrow_type):
import pyarrow as pa

assert isinstance(arr, pa.Array)
assert isinstance(arrow_type, pa.DataType)

if arr.type == arrow_type:
return arr
else:
try:
# When safe is True, the cast fails if there is an overflow or other unsafe conversion.
return arr.cast(target_type=arrow_type, safe=True)
except (pa.ArrowInvalid, pa.ArrowTypeError):
raise PySparkTypeError(
"Arrow UDTFs require the return type to match the expected Arrow type. "
f"Expected: {arrow_type}, but got: {arr.type}."
)

def load_stream(self, stream):
"""
@@ -227,6 +246,58 @@ def load_stream(self, stream):
result_batches.append(batch.column(i))
yield result_batches

def dump_stream(self, iterator, stream):
"""
Override to handle type coercion for ArrowUDTF outputs.
ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) tuples.
"""
import pyarrow as pa
from pyspark.serializers import write_int, SpecialLengths

def wrap_and_init_stream():
should_write_start_length = True
for packed in iterator:
batch, arrow_return_type = packed
Comment on lines +259 to +260 (Member):
nit:
Suggested change:
-    for packed in iterator:
-        batch, arrow_return_type = packed
+    for batch, arrow_return_type in iterator:

assert isinstance(
arrow_return_type, pa.StructType
), f"Expected pa.StructType, got {type(arrow_return_type)}"

# Handle empty struct case specially
if batch.num_columns == 0:
# When batch has no column, it should still create
# an empty batch with the number of rows set.
struct = pa.array([{}] * batch.num_rows)
coerced_batch = pa.RecordBatch.from_arrays([struct], ["_0"])
Comment on lines +266 to +270 (Contributor):
I don't think we need to handle this case? cc @ueshin

(Contributor Author):
This is to ensure the test case "test_arrow_udtf_with_empty_column_result" works. Please refer to the #52140 (comment) thread for the unexpected behavior change.
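A minimal sketch of why the empty-column branch builds a placeholder struct (assumes only that PyArrow is installed): a StructArray of empty dicts carries the row count even though there are no data columns.

import pyarrow as pa

# A batch with zero columns cannot carry its row count through the stream on
# its own, so the serializer wraps the rows in a field-less struct array.
struct = pa.array([{}] * 3)  # StructArray of type struct<>, length 3
batch = pa.RecordBatch.from_arrays([struct], ["_0"])
print(batch.num_rows)  # 3 -- the row count survives despite having no data columns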

else:
# Apply type coercion to each column if needed
coerced_arrays = []
for i, field in enumerate(arrow_return_type):
if i < batch.num_columns:
original_array = batch.column(i)
coerced_array = self._create_array(original_array, field.type)
coerced_arrays.append(coerced_array)
else:
raise PySparkRuntimeError(
errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
messageParameters={
"expected": str(len(arrow_return_type)),
"actual": str(batch.num_columns),
"func": "ArrowUDTF",
},
)
Comment on lines +280 to +287 (Contributor):
ditto. I think we already check whether the returned columns mismatch the expected return schema in worker.py. Would you mind double-checking?

(Contributor Author):
Do you mean verify_arrow_result in worker.py? I removed it since verify_arrow_result requires the return type to strictly match arrow_return_type in the conversion pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))).

verify_arrow_result(
    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
    assign_cols_by_name=False,
    expected_cols_and_types=[
        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
    ],
)

(Member):
The column length is checked before it. Please take a look at:

if result.num_columns != return_type_size:
    ...

in verify_result.


struct = pa.StructArray.from_arrays(coerced_arrays, fields=arrow_return_type)
coerced_batch = pa.RecordBatch.from_arrays([struct], ["_0"])

# Write the first record batch with initialization
if should_write_start_length:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
should_write_start_length = False

yield coerced_batch
Comment on lines +293 to +297 (Member):
These are done in the super().dump_stream(). What we should do here is just type-casting.

(Member):
I'm just wondering whether we can use RecordBatch.cast for this instead of casting each column?
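For reference, a minimal sketch of the RecordBatch.cast idea raised above (assumes a recent PyArrow version that provides RecordBatch.cast; the batch and schema here are illustrative, not taken from the PR):

import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3], type=pa.int64())], ["id"]
)
target_schema = pa.schema([pa.field("id", pa.int32())])
# Casts every column at once; like Array.cast with safe=True, it raises
# pa.ArrowInvalid on overflow or other unsafe conversions.
coerced = batch.cast(target_schema, safe=True)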


return ArrowStreamSerializer.dump_stream(self, wrap_and_init_stream(), stream)


class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer):
"""
99 changes: 89 additions & 10 deletions python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -178,18 +178,23 @@ def eval(self) -> Iterator["pa.Table"]:
result_df.collect()

def test_arrow_udtf_error_mismatched_schema(self):
@arrow_udtf(returnType="x int, y string")

@arrow_udtf(returnType="x int, y int")
class MismatchedSchemaUDTF:
def eval(self) -> Iterator["pa.Table"]:
result_table = pa.table(
{
"wrong_col": pa.array([1], type=pa.int32()),
"another_wrong_col": pa.array([2.5], type=pa.float64()),
"col_with_arrow_cast": pa.array([1], type=pa.int32()),
Comment (Contributor):
What if we have input of int64 and output of int32? Does the Arrow cast throw an exception in this case?

(Contributor Author, @shujingyang-db, Aug 29, 2025):
Yes, it will. We have a test case, "test_return_type_coercion_overflow".

"wrong_col": pa.array(["wrong_col"], type=pa.string()),
}
)
yield result_table

with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
with self.assertRaisesRegex(
PythonException,
"Arrow UDTFs require the return type to match the expected Arrow type." +
"Expected: int32, but got: string.",
):
result_df = MismatchedSchemaUDTF()
result_df.collect()
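
Following up on the int64-to-int32 question in the review thread above, a minimal standalone sketch of PyArrow's safe-cast behavior (assumes only that PyArrow is installed):

import pyarrow as pa

arr = pa.array([1, 2, 2**31], type=pa.int64())

# Values that fit into int32 are cast without loss.
print(arr[:2].cast(pa.int32(), safe=True))  # [1, 2]

# 2**31 does not fit into int32, so the safe cast raises pa.ArrowInvalid,
# which the serializer's _create_array turns into a PySparkTypeError.
try:
    arr.cast(pa.int32(), safe=True)
except pa.ArrowInvalid as e:
    print(e)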

@@ -330,9 +335,10 @@ def eval(self) -> Iterator["pa.Table"]:
)
yield result_table

with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
result_df = LongToIntUDTF()
result_df.collect()
# Should succeed with automatic coercion
result_df = LongToIntUDTF()
expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
assertDataFrameEqual(result_df, expected_df)

def test_arrow_udtf_type_coercion_string_to_int(self):
@arrow_udtf(returnType="id int")
@@ -341,15 +347,87 @@ def eval(self) -> Iterator["pa.Table"]:
# Return string values where "xyz" cannot be coerced to int
result_table = pa.table(
{
"id": pa.array(["abc", "def", "xyz"], type=pa.string()),
"id": pa.array(["1", "2", "xyz"], type=pa.string()),
}
)
yield result_table

with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):

# Should fail with Arrow cast exception since string cannot be cast to int
with self.assertRaisesRegex(
PythonException,
"Arrow UDTFs require the return type to match the expected Arrow type." +
"Expected: int32, but got: string.",
):
result_df = StringToIntUDTF()
result_df.collect()

def test_arrow_udtf_type_coercion_int64_to_int32_safe(self):
@arrow_udtf(returnType="id int")
class Int64ToInt32UDTF:
def eval(self) -> Iterator["pa.Table"]:
result_table = pa.table(
{
"id": pa.array([1, 2, 3], type=pa.int64()), # long values
}
)
yield result_table

result_df = Int64ToInt32UDTF()
expected_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
assertDataFrameEqual(result_df, expected_df)

def test_return_type_coercion_success(self):
@arrow_udtf(returnType="value int")
class CoercionSuccessUDTF:
def eval(self) -> Iterator["pa.Table"]:
result_table = pa.table(
{
"value": pa.array([10, 20, 30], type=pa.int64()), # long -> int coercion
}
)
yield result_table

result_df = CoercionSuccessUDTF()
expected_df = self.spark.createDataFrame([(10,), (20,), (30,)], "value int")
assertDataFrameEqual(result_df, expected_df)

def test_return_type_coercion_overflow(self):
@arrow_udtf(returnType="value int")
class CoercionOverflowUDTF:
def eval(self) -> Iterator["pa.Table"]:
# Return values that will cause overflow when casting long to int
result_table = pa.table(
{
"value": pa.array([2147483647 + 1], type=pa.int64()), # int32 max + 1
}
)
yield result_table

# Should fail with PyArrow overflow exception
with self.assertRaises(Exception):
result_df = CoercionOverflowUDTF()
result_df.collect()

def test_return_type_coercion_multiple_columns(self):
@arrow_udtf(returnType="id int, price float")
class MultipleColumnCoercionUDTF:
def eval(self) -> Iterator["pa.Table"]:
result_table = pa.table(
{
"id": pa.array([1, 2, 3], type=pa.int64()), # long -> int coercion
"price": pa.array(
[10.5, 20.7, 30.9], type=pa.float64()
), # double -> float coercion
}
)
yield result_table

result_df = MultipleColumnCoercionUDTF()
expected_df = self.spark.createDataFrame(
[(1, 10.5), (2, 20.7), (3, 30.9)], "id int, price float"
)
assertDataFrameEqual(result_df, expected_df)

def test_arrow_udtf_with_empty_column_result(self):
@arrow_udtf(returnType=StructType())
class EmptyResultUDTF:
@@ -612,6 +690,7 @@ class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
pass



Comment (Member):
ditto.

Could you run:

./dev/reformat-python

to make the linter happy?

if __name__ == "__main__":
from pyspark.sql.tests.arrow.test_arrow_udtf import * # noqa: F401

14 changes: 5 additions & 9 deletions python/pyspark/worker.py
@@ -1344,7 +1344,9 @@ def read_udtf(pickleSer, infile, eval_type):
num_table_arg_offsets = read_int(infile)
table_arg_offsets = [read_int(infile) for _ in range(num_table_arg_offsets)]
# Use PyArrow-native serializer for Arrow UDTFs with potential UDT support
ser = ArrowStreamArrowUDTFSerializer(table_arg_offsets=table_arg_offsets)
ser = ArrowStreamArrowUDTFSerializer(
table_arg_offsets=table_arg_offsets
)
Comment on lines +1347 to +1349 (Member):
Looks like an unnecessary change?

else:
# Each row is a group so do not batch but send one by one.
ser = BatchedSerializer(CPickleSerializer(), 1)
@@ -1970,14 +1972,8 @@ def verify_result(result):
},
)

# Verify the type and the schema of the result.
verify_arrow_result(
pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
assign_cols_by_name=False,
expected_cols_and_types=[
(col.name, to_arrow_type(col.dataType)) for col in return_type.fields
],
)
# We verify the type of the result in the serializer
# as we now support type coercion in return values.
return result

# Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
@@ -4003,6 +4003,7 @@ object SQLConf {
.booleanConf
.createWithDefault(false)


Comment (Member):
nit: revert this?

val PYTHON_UDF_LEGACY_PANDAS_CONVERSION_ENABLED =
buildConf("spark.sql.legacy.execution.pythonUDF.pandas.conversion.enabled")
.internal()