
shujingyang-db (Contributor)

What changes were proposed in this pull request?

Support return type coercion for Arrow Python UDTFs by applying arrow_cast by default.
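
In pure PyArrow terms, the default coercion boils down to a safe cast from the produced column type to the declared type (an illustrative snippet, not code from this PR):

import pyarrow as pa

declared = pa.int64()                             # type declared in the UDTF's returnType
produced = pa.array([1, 2, 3], type=pa.int32())   # type the UDTF actually yields

# With coercion enabled, the serializer safely widens the produced column to the
# declared type instead of failing on the schema mismatch.
coerced = produced.cast(declared, safe=True)
assert coerced.type == pa.int64()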

Why are the changes needed?

Consistent type-coercion behavior across Arrow UDFs and Arrow UDTFs.

Does this PR introduce any user-facing change?

No. Arrow UDTFs are not a public API yet.

How was this patch tested?

New and existing unit tests.

Was this patch authored or co-authored using generative AI tooling?

No

zhengruifeng previously approved these changes Aug 27, 2025
@@ -201,9 +201,26 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays.
"""

def __init__(self, table_arg_offsets=None):
def __init__(self, table_arg_offsets=None, arrow_cast=False):

Contributor:
the default value should be True?

Contributor Author:
Yep! I changed it to True and added a SQLConf to gate it.

def test_arrow_udtf_with_empty_column_result(self):
    @arrow_udtf(returnType=StructType())
    class EmptyResultUDTF:
        def eval(self) -> Iterator["pa.Table"]:
            yield pa.Table.from_struct_array(pa.array([{}] * 3))

    assertDataFrameEqual(EmptyResultUDTF(), [Row(), Row(), Row()])
    assertDataFrameEqual(EmptyResultUDTF(), [None, None, None])

Contributor:
I guess this change is unexpected?

Contributor Author:
Good catch! I have reverted it and now create an empty batch with the number of rows set.

github-actions bot added the CORE label on Aug 28, 2025

allisonwang-db (Contributor) left a comment:

Thanks for supporting this!

Comment on lines +267 to +271
if batch.num_columns == 0:
    # When batch has no column, it should still create
    # an empty batch with the number of rows set.
    struct = pa.array([{}] * batch.num_rows)
    coerced_batch = pa.RecordBatch.from_arrays([struct], ["_0"])

Contributor:
I don't think we need to handle this case? cc @ueshin

Contributor Author:
This is to ensure that the test case "test_arrow_udtf_with_empty_column_result" works. Please refer to the #52140 (comment) thread above for the unexpected behavior change.
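
For illustration, a small standalone PyArrow snippet showing why the struct trick preserves the row count (values are illustrative):

import pyarrow as pa

num_rows = 3
struct = pa.array([{}] * num_rows)                    # StructArray with zero fields but 3 rows
batch = pa.RecordBatch.from_arrays([struct], ["_0"])  # wraps it so the row count survives
assert batch.num_rows == 3 and batch.num_columns == 1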

@@ -201,9 +201,26 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays.
"""

def __init__(self, table_arg_offsets=None):
def __init__(self, table_arg_offsets=None, arrow_cast=True):

Contributor:
Let's enable arrow_cast by default for ArrowUDTFs (it's a new feature) so we don't need a flag here.

if arr.type == arrow_type:
    return arr
elif self._arrow_cast:
    return arr.cast(target_type=arrow_type, safe=True)

Contributor:
What's the difference between safe=True and safe=False?

Contributor Author:
It will only allow casts that are guaranteed not to lose information: truncation (floats to ints), narrowing (int64 → int8), and other precision loss are not allowed. Will add a comment.
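
For example (a standalone PyArrow illustration, not code from this PR):

import pyarrow as pa

arr = pa.array([1.5, 2.7])           # float64
arr.cast(pa.int64(), safe=False)     # unsafe: silently truncates to [1, 2]
try:
    arr.cast(pa.int64(), safe=True)  # safe: refuses the lossy cast
except pa.lib.ArrowInvalid as e:
    print(e)                         # e.g. "Float value 1.5 was truncated converting to int64"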

Contributor:
cc @zhengruifeng is this the same behavior as Arrow UDFs?

Comment on lines 259 to 261
assert isinstance(
    batch, pa.RecordBatch
), f"Expected pa.RecordBatch, got {type(batch)}"

Contributor:
I think we already check this in worker.py, so no need to duplicate this check :)

Comment on lines +281 to +288
raise PySparkRuntimeError(
    errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
    messageParameters={
        "expected": str(len(arrow_return_type)),
        "actual": str(batch.num_columns),
        "func": "ArrowUDTF",
    },
)

Contributor:
Ditto. I think we already check whether the returned columns mismatch the expected return schema in worker.py. Would you mind double-checking?

Contributor Author:
Do you mean verify_arrow_result in worker.py? I removed it since verify_arrow_result requires the return type to strictly match arrow_return_type in the conversion pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))).

verify_arrow_result(
    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
    assign_cols_by_name=False,
    expected_cols_and_types=[
        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
    ],
)

Member:
The column length is checked before it. Please take a look at:

if result.num_columns != return_type_size:
    ...

in verify_result.

if arr.type == arrow_type:
    return arr
elif self._arrow_cast:
    return arr.cast(target_type=arrow_type, safe=True)

Contributor:
Also, it would be great to list the type coercion rules here!

Contributor Author:
Added a comment.

with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
    result_df = MismatchedSchemaUDTF()
    result_df.collect()
if self.spark.conf.get("spark.sql.execution.pythonUDTF.typeCoercion.enabled").lower() == "false":

Contributor:
you can use with self.sql_conf("...")
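
Something along these lines inside the test (sql_conf is the helper from the shared SQL test utilities and takes a dict of conf key/value pairs; the config name is copied from the diff below):

with self.sql_conf({"spark.sql.execution.pythonUDTF.typeCoercion.enabled": "false"}):
    with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
        MismatchedSchemaUDTF().collect()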

Comment on lines 4006 to 4007
val PYTHON_TABLE_UDF_TYPE_CORERION_ENABLED =
  buildConf("spark.sql.execution.pythonUDTF.typeCoercion.enabled")

Contributor:
Let's enable Arrow cast for Arrow Python UDTFs by default so we don't need this config :)

shujingyang-db (Contributor Author), Aug 29, 2025:
sure, on it

Update: done

class MismatchedSchemaUDTF:
    def eval(self) -> Iterator["pa.Table"]:
        result_table = pa.table(
            {
                "wrong_col": pa.array([1], type=pa.int32()),
                "another_wrong_col": pa.array([2.5], type=pa.float64()),
                "col_with_arrow_cast": pa.array([1], type=pa.int32()),

Contributor:
What if the input is int64 and the output is int32? Does arrow cast throw an exception in this case?

shujingyang-db (Contributor Author), Aug 29, 2025:
Yes, it will. We have a test case "test_return_type_coercion_overflow" for this.
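
For reference, a standalone PyArrow illustration of that overflow behavior:

import pyarrow as pa

big = pa.array([2**40], type=pa.int64())
try:
    big.cast(pa.int32(), safe=True)   # value does not fit in int32
except pa.lib.ArrowInvalid as e:
    print(e)                          # integer value out of int32 bounds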

    result_df = MismatchedSchemaUDTF()
    result_df.collect()
else:
    with self.assertRaisesRegex(PythonException, "Failed to parse string: 'wrong_col' as a scalar of type int32"):

Contributor:
Hmm, looks like without arrow cast the error message looks better.

shujingyang-db (Contributor Author), Aug 29, 2025:
I added a try/except block to polish the error message when the arrow cast fails.
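
Roughly along these lines (a sketch only; the helper name and message wording are placeholders, not the exact code added in this PR):

import pyarrow as pa
from pyspark.errors import PySparkRuntimeError

def cast_with_clearer_error(arr: pa.Array, arrow_type: pa.DataType, name: str) -> pa.Array:
    # Hypothetical helper: surface a column-aware message instead of the raw Arrow error.
    try:
        return arr.cast(target_type=arrow_type, safe=True)
    except (pa.lib.ArrowInvalid, pa.lib.ArrowNotImplementedError) as e:
        raise PySparkRuntimeError(
            f"Failed to coerce column '{name}' from {arr.type} to {arrow_type}: {e}"
        ) from e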

Comment on lines +257 to +258
for packed in iterator:
    batch, arrow_return_type = packed

Member:
nit:

Suggested change:
- for packed in iterator:
-     batch, arrow_return_type = packed
+ for batch, arrow_return_type in iterator:

@@ -4003,6 +4003,7 @@ object SQLConf {
.booleanConf
.createWithDefault(false)


Member:
nit: revert this?

Comment on lines +1347 to +1349
ser = ArrowStreamArrowUDTFSerializer(
    table_arg_offsets=table_arg_offsets
)

Member:
Looks like an unnecessary change?

@@ -612,6 +690,7 @@ class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
pass



Member:
Ditto.

Could you run:

./dev/reformat-python

to make the linter happy?

Comment on lines +293 to +297
if should_write_start_length:
    write_int(SpecialLengths.START_ARROW_STREAM, stream)
    should_write_start_length = False

yield coerced_batch

Member:
These are done in the super().dump_stream(). What we should do here is just type-casting.
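
Roughly what that suggestion would look like (a sketch with assumed names; _coerce_record_batch is a hypothetical helper for the casting logic, and it presumes the parent serializer consumes the same (batch, arrow_return_type) tuples):

def dump_stream(self, iterator, stream):
    # Only do the type-casting here; stream framing stays in super().dump_stream().
    def coerced(it):
        for batch, arrow_return_type in it:
            yield self._coerce_record_batch(batch, arrow_return_type), arrow_return_type

    return super().dump_stream(coerced(iterator), stream)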

Member:
I'm just wondering whether we can use RecordBatch.cast for this instead of casting each column?
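
For reference, a standalone illustration of that idea (assumes a PyArrow version where RecordBatch.cast is available, roughly 16.0+):

import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2], type=pa.int32()), pa.array(["a", "b"])],
    names=["id", "name"],
)
target = pa.schema([("id", pa.int64()), ("name", pa.string())])

coerced = batch.cast(target)   # casts every column in one call; safe by default
assert coerced.schema == target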
