Commit 50cbaea

security(weave_query): Upgrade Pyarrow to 17.0.0 (#3617)
* chore: resolve ArrowInvalid error from ambiguous null specification
* chore: update all ListArray.from_arrays calls to not use mask
* fix(weave_query): Check for offsets with nulls before using a mask
* chore: remove debug to trigger ci
* chore: add dummy commit to trigger frontend-tests
* chore: remove dummy commit changes
1 parent b60c820 commit 50cbaea

7 files changed (+73, -33 lines)
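Before the per-file diffs, a minimal sketch of the failure mode this commit resolves, assuming PyArrow 17.0.0 behavior as described in the commit message and in the new helper's comments; the data below is made up for illustration:

import pyarrow as pa

values = pa.array([1, 2, 3, 4, 5])
# PyArrow already treats a null offset as marking that list slot null.
offsets_with_null = pa.array([0, 2, None, 5], type=pa.int32())
mask = pa.array([False, True, False])

try:
    # Supplying a validity mask on top of offsets that contain nulls is the
    # combination PyArrow 17.0.0 rejects as ambiguous.
    pa.ListArray.from_arrays(offsets_with_null, values, mask=mask)
except pa.ArrowInvalid as err:
    print(err)  # e.g. "ambiguous to specify both validity map and offsets with nulls"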

weave_query/requirements.legacy.txt (+1, -2)

@@ -6,8 +6,7 @@ typing_extensions>=4.0.0

 # Definitely need arrow
 # TODO: Colab has 9.0.0, can we support?
-# TODO: 17.0.0 breaks a bunch of tests - can we move this requirement to just the engine?
-pyarrow>=14.0.1,<17.0.0
+pyarrow==17.0.0

 # pydantic integration, and required by openai anyway
 openai>=1.0.0

weave_query/weave_query/arrow/arrow.py (+33, -3)

@@ -340,15 +340,18 @@ def rewrite_weavelist_refs(arrow_data, object_type, source_artifact, target_artifact):
         data = arrow_data
         if isinstance(arrow_data, pa.ChunkedArray):
             data = arrow_data.combine_chunks()
-        return pa.ListArray.from_arrays(
-            offsets_starting_at_zero(data),
+
+        new_offsets = offsets_starting_at_zero(data)
+
+        return safe_list_array_from_arrays(
+            new_offsets,
             rewrite_weavelist_refs(
                 data.flatten(),
                 object_type.object_type,
                 source_artifact,
                 target_artifact,
             ),
-            mask=pa.compute.is_null(data),
+            mask=pa.compute.is_null(data)
         )
     else:
         # We have a column of refs

@@ -474,6 +477,33 @@ def safe_coalesce(*arrs: pa.Array):
         result = pa.compute.if_else(safe_is_null(result), arr, result)
     return result

+
+def safe_list_array_from_arrays(offsets, values, mask=None):
+    # In PyArrow 17.0.0, ListArray.from_arrays() was updated to check for ambiguity
+    # between null offsets and masks via offsets.null_count(). When concatenating
+    # offset arrays, PyArrow uses lazy evaluation which can cause this check to see
+    # an intermediate state and incorrectly raise an "ambiguous to specify both
+    # validity map and offsets with nulls" error.
+
+    # Convert offsets to PyArrow array if it isn't already
+    if not isinstance(offsets, pa.Array):
+        offsets = pa.array(offsets)
+
+    has_offsets_with_nulls = offsets.null_count > 0
+
+    if has_offsets_with_nulls:
+        # Can't use both a mask and an offset with nulls
+        result_array = pa.ListArray.from_arrays(
+            offsets,
+            values
+        )
+        return pc.if_else(mask, None, result_array)
+    else:
+        return pa.ListArray.from_arrays(
+            offsets,
+            values,
+            mask=mask
+        )
+

 def arrow_zip(*arrs: pa.Array) -> pa.Array:
     n_arrs = len(arrs)
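A hedged usage sketch of the helper added above (the sample data is invented; the import path matches how the other files in this commit pull it in): call sites keep the same shape as pa.ListArray.from_arrays, and the helper decides whether the mask can be forwarded directly or has to be applied afterwards with pc.if_else.

import pyarrow as pa

from weave_query.arrow.arrow import safe_list_array_from_arrays

values = pa.array(["a", "b", "c", "d"])
offsets = pa.array([0, 2, 2, 4], type=pa.int32())
mask = pa.array([False, True, False])  # mark the middle (empty) list slot as null

# These offsets carry no nulls, so the mask is passed straight through to
# ListArray.from_arrays; with nulls in the offsets the helper would instead
# build the array without the mask and null out masked slots via pc.if_else.
lists = safe_list_array_from_arrays(offsets, values, mask=mask)
print(lists)  # [["a", "b"], null, ["c", "d"]]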

weave_query/weave_query/arrow/arrow_tags.py (+8, -4)

@@ -5,7 +5,10 @@

 from weave_query import weave_types as types
 from weave_query.arrow import convert
-from weave_query.arrow.arrow import offsets_starting_at_zero
+from weave_query.arrow.arrow import (
+    offsets_starting_at_zero,
+    safe_list_array_from_arrays
+)
 from weave_query.language_features.tagging import (
     process_opdef_output_type,
     tag_store,

@@ -30,10 +33,11 @@ def recursively_encode_pyarrow_strings_as_dictionaries(array: pa.Array) -> pa.Array:
             mask=pa.compute.invert(array.is_valid()),
         )
     elif pa.types.is_list(array.type):
-        return pa.ListArray.from_arrays(
-            offsets_starting_at_zero(array),
+        new_offsets = offsets_starting_at_zero(array)
+        return safe_list_array_from_arrays(
+            new_offsets,
             recursively_encode_pyarrow_strings_as_dictionaries(array.flatten()),
-            mask=pa.compute.invert(array.is_valid()),
+            mask=pa.compute.invert(array.is_valid())
         )
     elif array.type == pa.string():
         return pc.dictionary_encode(array)

weave_query/weave_query/arrow/concat.py (+13, -12)

@@ -25,6 +25,7 @@
     make_vec_none,
     offsets_starting_at_zero,
     unsafe_awl_construction,
+    safe_list_array_from_arrays
 )
 from weave_query.language_features.tagging import tagged_value_type

@@ -198,6 +199,7 @@ def _concatenate_lists(
         l2_arrow_data.flatten(), l2.object_type.object_type, l2._artifact
     )
     concatted_values = _concatenate(self_values, other_values, depth=depth + 1)
+
     new_offsets = pa_concat_arrays(
         [
             offsets_starting_at_zero(l1_arrow_data)[:-1],

@@ -207,17 +209,18 @@
             ).cast(pa.int32()),
         ]
     )
+
+    result_array = safe_list_array_from_arrays(
+        new_offsets,
+        concatted_values._arrow_data,
+        mask=pa_concat_arrays([
+            pa.compute.is_null(l1._arrow_data),
+            pa.compute.is_null(l2._arrow_data)
+        ])
+    )
+
     return ArrowWeaveList(
-        pa.ListArray.from_arrays(
-            new_offsets,
-            concatted_values._arrow_data,
-            mask=pa_concat_arrays(
-                [
-                    pa.compute.is_null(l1_arrow_data),
-                    pa.compute.is_null(l2_arrow_data),
-                ]
-            ),
-        ),
+        result_array,
         types.List(
             types.merge_types(l1.object_type.object_type, l2.object_type.object_type)
         ),

@@ -512,7 +515,6 @@ def _concatenate(
                 pa.array(np.zeros(len(other), dtype=np.int8)),
             ]
         ),
-        # offsets=pa.array(np.arange(len(self), dtype=np.int32)),
         offsets=pa_concat_arrays(
             [
                 self_offsets,

@@ -539,7 +541,6 @@
                 pa.compute.equal(other_type_codes, other_i).cast(pa.int8()),
             ]
         ),
-        # offsets=pa.array(np.arange(len(other), dtype=np.int32)),
         offsets=pa_concat_arrays(
             [
                 pa.array(np.zeros(len(self), dtype=np.int32)),
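A small sketch of the list-concatenation pattern used by _concatenate_lists above, on made-up data (the real code derives its offsets with offsets_starting_at_zero and builds the result through safe_list_array_from_arrays; plain pyarrow calls are used here only to keep the sketch self-contained): the child values are concatenated, the second array's offsets are rebased by the first array's value count, and the result's null mask is the concatenation of both inputs' null masks.

import pyarrow as pa
import pyarrow.compute as pc

l1 = pa.array([[1, 2], None, [3]])   # contributes 3 values after flattening
l2 = pa.array([[], [4, 5]])

values = pa.concat_arrays([l1.flatten(), l2.flatten()])

# l1's offsets without the trailing boundary, then l2's offsets shifted by the
# number of values l1 contributed, so the two offset runs join up.
new_offsets = pa.concat_arrays([
    pa.array([0, 2, 2], type=pa.int32()),
    pc.add(pa.array([0, 0, 2], type=pa.int32()), 3).cast(pa.int32()),
])

mask = pa.concat_arrays([pc.is_null(l1), pc.is_null(l2)])
result = pa.ListArray.from_arrays(new_offsets, values, mask=mask)
print(result)  # [[1, 2], null, [3], [], [4, 5]]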

weave_query/weave_query/arrow/convert.py (+2, -1)

@@ -21,6 +21,7 @@
     ArrowWeaveList,
     PathType,
     unsafe_awl_construction,
+    safe_list_array_from_arrays
 )
 from weave_query.language_features.tagging import tag_store, tagged_value_type

@@ -307,7 +308,7 @@ def none_unboxer(iterator: typing.Iterable):
             mapper._object_type,
             py_objs_already_mapped,
         )
-        return pa.ListArray.from_arrays(
+        return safe_list_array_from_arrays(
             offsets, new_objs, mask=pa.array(mask, type=pa.bool_())
         )
     elif pa.types.is_temporal(pyarrow_type):

weave_query/weave_query/arrow/list_.py (+10, -7)

@@ -32,6 +32,7 @@
     offsets_starting_at_zero,
     pretty_print_arrow_type,
     safe_is_null,
+    safe_list_array_from_arrays
 )
 from weave_query.language_features.tagging import (
     tag_store,

@@ -800,14 +801,16 @@ def _map_column(
         items: ArrowWeaveList = ArrowWeaveList(
             arr.flatten(), self.object_type.object_type, self._artifact
         )._map_column(fn, pre_fn, path + (PathItemList(),))
-        # print("SELF OBJECT TYPE", self.object_type)
-        # print("SELF ARROW DATA TYPE", self._arrow_data.type)
+
+        new_offsets = offsets_starting_at_zero(self._arrow_data)
+        result_array = safe_list_array_from_arrays(
+            new_offsets,
+            items._arrow_data,
+            pa.compute.is_null(arr)
+        )
+
         with_mapped_children = ArrowWeaveList(
-            pa.ListArray.from_arrays(
-                offsets_starting_at_zero(self._arrow_data),
-                items._arrow_data,
-                mask=pa.compute.is_null(arr),
-            ),
+            result_array,
             self.object_type.__class__(items.object_type),
             self._artifact,
             invalid_reason=items._invalid_reason,

weave_query/weave_query/ops_arrow/arraylist_ops.py (+6, -4)

@@ -4,7 +4,11 @@
 from pyarrow import compute as pc

 from weave_query import weave_types as types
-from weave_query.arrow.arrow import ArrowWeaveListType, arrow_as_array
+from weave_query.arrow.arrow import (
+    ArrowWeaveListType,
+    arrow_as_array,
+    safe_list_array_from_arrays
+)
 from weave_query.arrow.list_ import ArrowWeaveList
 from weave_query.decorator_arrow_op import arrow_op
 from weave_query.language_features.tagging import tagged_value_type

@@ -350,9 +354,7 @@ def dropna(self):
     cumulative_non_null_counts = pa.compute.cumulative_sum(non_null)
     new_offsets = cumulative_non_null_counts.take(pa.compute.subtract(end_indexes, 1))
     new_offsets = pa.concat_arrays([start_indexes[:1], new_offsets])
-    unflattened = pa.ListArray.from_arrays(
-        new_offsets, new_data, mask=pa.compute.is_null(a)
-    )
+    unflattened = safe_list_array_from_arrays(new_offsets, new_data, mask=pa.compute.is_null(a))

     return ArrowWeaveList(
         unflattened,
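A worked sketch of the offset arithmetic dropna relies on above, using made-up data and hand-built inputs (non_null, start_indexes, end_indexes, and new_data are defined earlier in the real op and are reconstructed here only for illustration, and the real op also applies a null mask via safe_list_array_from_arrays): the cumulative count of non-null child values, sampled just before each list's original end offset, gives that list's new end offset once the nulls are dropped.

import pyarrow as pa
import pyarrow.compute as pc

# [[1, None, 2], [], [None, 3]] as flattened values plus offsets.
offsets = pa.array([0, 3, 3, 5], type=pa.int32())
values = pa.array([1, None, 2, None, 3])
start_indexes, end_indexes = offsets[:-1], offsets[1:]

non_null = pc.cast(values.is_valid(), pa.int32())  # [1, 0, 1, 0, 1]
new_data = values.drop_null()                      # [1, 2, 3]

# Running count of non-null children; its value at index (end - 1) is the
# number of surviving children up to that list boundary.
cumulative_non_null_counts = pc.cumulative_sum(non_null)
new_ends = cumulative_non_null_counts.take(pc.subtract(end_indexes, 1))
new_offsets = pa.concat_arrays([start_indexes[:1], new_ends.cast(pa.int32())])

print(pa.ListArray.from_arrays(new_offsets, new_data))  # [[1, 2], [], [3]]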
