[SPARK-53029][PYTHON] Support return type coercion for Arrow Python UDTFs #52140
Conversation
@@ -201,9 +201,26 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
    Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays.
    """

-    def __init__(self, table_arg_offsets=None):
+    def __init__(self, table_arg_offsets=None, arrow_cast=False):
the default value should be True?
yep! I changed it to True and added a SQLConf to gate it
def test_arrow_udtf_with_empty_column_result(self):
    @arrow_udtf(returnType=StructType())
    class EmptyResultUDTF:
        def eval(self) -> Iterator["pa.Table"]:
            yield pa.Table.from_struct_array(pa.array([{}] * 3))

-    assertDataFrameEqual(EmptyResultUDTF(), [Row(), Row(), Row()])
+    assertDataFrameEqual(EmptyResultUDTF(), [None, None, None])
I guess this change is unexpected?
Good catch! I have reverted it and created an empty batch with the number of rows set.
Thanks for supporting this!
if batch.num_columns == 0:
    # When batch has no column, it should still create
    # an empty batch with the number of rows set.
    struct = pa.array([{}] * batch.num_rows)
    coerced_batch = pa.RecordBatch.from_arrays([struct], ["_0"])
I don't think we need to handle this case? cc @ueshin
This is to ensure the test case "test_arrow_udtf_with_empty_column_result" works. Please refer to #52140 (comment) for the unexpected behavior change.
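To illustrate the workaround, here is a minimal standalone PyArrow sketch (not the PR code itself; the "_0" column name mirrors the snippet above):

import pyarrow as pa

# An empty-struct array still carries a length, so wrapping it in a
# RecordBatch preserves the row count even though there are no real columns.
struct = pa.array([{}] * 3)
batch = pa.RecordBatch.from_arrays([struct], ["_0"])
print(batch.num_rows)  # 3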
@@ -201,9 +201,26 @@ class ArrowStreamArrowUDTFSerializer(ArrowStreamUDTFSerializer):
    Serializer for PyArrow-native UDTFs that work directly with PyArrow RecordBatches and Arrays.
    """

-    def __init__(self, table_arg_offsets=None):
+    def __init__(self, table_arg_offsets=None, arrow_cast=True):
Let's enable arrow_cast by default for ArrowUDTFs (it's a new feature) so we don't need a flag here.
if arr.type == arrow_type:
    return arr
elif self._arrow_cast:
    return arr.cast(target_type=arrow_type, safe=True)
What's the difference between safe=True vs safe=False?
It will only allow casts that are guaranteed not to lose information. Truncation (floats to ints), narrowing (int64 → int8), or precision loss are not allowed. Will add a comment
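For illustration, a standalone PyArrow snippet (not from the PR) showing the difference:

import pyarrow as pa

arr = pa.array([1.5, 2.7], type=pa.float64())

# safe=False silently truncates the fractional part
print(arr.cast(pa.int64(), safe=False))  # [1, 2]

# safe=True refuses any lossy conversion
try:
    arr.cast(pa.int64(), safe=True)
except pa.ArrowInvalid as e:
    print(e)  # e.g. "Float value 1.5 was truncated converting to int64"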
cc @zhengruifeng is this the same behavior as Arrow UDFs?
assert isinstance(
    batch, pa.RecordBatch
), f"Expected pa.RecordBatch, got {type(batch)}"
I think we already check this in worker.py so no need to duplicate this check :)
raise PySparkRuntimeError(
    errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
    messageParameters={
        "expected": str(len(arrow_return_type)),
        "actual": str(batch.num_columns),
        "func": "ArrowUDTF",
    },
)
Ditto. I think we already check whether the returned columns mismatch the expected return schema in worker.py. Would you mind double-checking?
Do you mean verify_arrow_result in worker.py? I removed it since verify_arrow_result requires the return type to strictly match arrow_return_type in the conversion pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))).
verify_arrow_result(
    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
    assign_cols_by_name=False,
    expected_cols_and_types=[
        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
    ],
)
The column length is checked before it. Please take a look at:
if result.num_columns != return_type_size:
    ...
in verify_result.
if arr.type == arrow_type:
    return arr
elif self._arrow_cast:
    return arr.cast(target_type=arrow_type, safe=True)
Also, it would be great to list the type coercion rules here!
added a comment
with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
    result_df = MismatchedSchemaUDTF()
    result_df.collect()
if self.spark.conf.get("spark.sql.execution.pythonUDTF.typeCoercion.enabled").lower() == "false":
you can use with self.sql_conf("...")
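For example, a sketch of the suggested pattern (assuming the surrounding test class and the conf key used in this PR revision):

with self.sql_conf({"spark.sql.execution.pythonUDTF.typeCoercion.enabled": "false"}):
    with self.assertRaisesRegex(PythonException, "Schema at index 0 was different"):
        MismatchedSchemaUDTF().collect()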
val PYTHON_TABLE_UDF_TYPE_CORERION_ENABLED =
    buildConf("spark.sql.execution.pythonUDTF.typeCoercion.enabled")
Let's enable Arrow cast for Arrow Python UDTFs by default so we don't need this config :)
sure, on it
Update: done
class MismatchedSchemaUDTF:
    def eval(self) -> Iterator["pa.Table"]:
        result_table = pa.table(
            {
                "wrong_col": pa.array([1], type=pa.int32()),
                "another_wrong_col": pa.array([2.5], type=pa.float64()),
                "col_with_arrow_cast": pa.array([1], type=pa.int32()),
What if we have input to be int64 and output to be int32? Does arrow cast throw an exception in this case?
Yes, it will. We have a test case "test_return_type_coercion_overflow" covering this.
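A quick standalone illustration (not the test itself) of what the safe cast does on overflow:

import pyarrow as pa

arr = pa.array([2**40], type=pa.int64())
try:
    arr.cast(pa.int32(), safe=True)
except pa.ArrowInvalid as e:
    print(e)  # integer value out of range for int32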
result_df = MismatchedSchemaUDTF()
result_df.collect()
else:
    with self.assertRaisesRegex(PythonException, "Failed to parse string: 'wrong_col' as a scalar of type int32"):
Hmm looks like without arrow cast, the error message looks better.
I added a try-catch block to polish the error message with the arrow cast
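Something along these lines (a hypothetical sketch of the pattern, not the exact code that landed; the helper name and message wording are illustrative):

import pyarrow as pa
from pyspark.errors import PySparkRuntimeError

def _cast_to_return_type(arr: pa.Array, arrow_type: pa.DataType) -> pa.Array:
    try:
        return arr.cast(target_type=arrow_type, safe=True)
    except (pa.ArrowInvalid, pa.ArrowNotImplementedError) as e:
        # Surface a UDTF-specific message instead of the raw Arrow error.
        raise PySparkRuntimeError(
            f"ArrowUDTF failed to cast column of type {arr.type} "
            f"to the declared return type {arrow_type}: {e}"
        ) from e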
for packed in iterator:
    batch, arrow_return_type = packed
nit:
-    for packed in iterator:
-        batch, arrow_return_type = packed
+    for batch, arrow_return_type in iterator:
@@ -4003,6 +4003,7 @@ object SQLConf {
    .booleanConf
    .createWithDefault(false)
+
nit: revert this?
ser = ArrowStreamArrowUDTFSerializer(
    table_arg_offsets=table_arg_offsets
)
Looks like an unnecessary change?
@@ -612,6 +690,7 @@ class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
    pass
+
Ditto.
Could you run ./dev/reformat-python to make the linter happy?
if should_write_start_length:
    write_int(SpecialLengths.START_ARROW_STREAM, stream)
    should_write_start_length = False

yield coerced_batch
These are done in super().dump_stream(). What we should do here is just type-casting.
I'm just wondering whether we can use RecordBatch.cast for this instead of casting each column?
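For reference, RecordBatch.cast (available in newer PyArrow releases) casts all columns against a target schema in one call; a minimal standalone sketch:

import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2], type=pa.int32()), pa.array([0.5, 1.5], type=pa.float32())],
    ["a", "b"],
)
target = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.float64())])
coerced = batch.cast(target, safe=True)  # casts every column to the declared schema
print(coerced.schema)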
What changes were proposed in this pull request?
Support return type coercion for Arrow Python UDTFs by doing arrow_cast by default.

Why are the changes needed?
Consistent behavior across Arrow UDFs and Arrow UDTFs.

Does this PR introduce any user-facing change?
No, Arrow UDTF is not a public API yet.

How was this patch tested?
New and existing UTs.

Was this patch authored or co-authored using generative AI tooling?
No.