Skip to content

Commit 190f1c1

Browse files
committed
fix(weave_query): Check for offsets with nulls before using a mask
1 parent 60ec4b1 commit 190f1c1

File tree

6 files changed

+66
-32
lines changed

6 files changed

+66
-32
lines changed

weave_query/weave_query/arrow/arrow.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -340,14 +340,18 @@ def rewrite_weavelist_refs(arrow_data, object_type, source_artifact, target_arti
340340
data = arrow_data
341341
if isinstance(arrow_data, pa.ChunkedArray):
342342
data = arrow_data.combine_chunks()
343-
return pa.ListArray.from_arrays(
344-
offsets_starting_at_zero(data),
343+
344+
new_offsets = offsets_starting_at_zero(data)
345+
346+
return safe_list_array_from_arrays(
347+
new_offsets,
345348
rewrite_weavelist_refs(
346349
data.flatten(),
347350
object_type.object_type,
348351
source_artifact,
349352
target_artifact,
350-
)
353+
),
354+
mask=pa.compute.is_null(data)
351355
)
352356
else:
353357
# We have a column of refs
@@ -473,6 +477,33 @@ def safe_coalesce(*arrs: pa.Array):
473477
result = pa.compute.if_else(safe_is_null(result), arr, result)
474478
return result
475479

480+
def safe_list_array_from_arrays(offsets, values, mask=None):
481+
# In PyArrow 17.0.0, ListArray.from_arrays() was updated to check for ambiguity
482+
# between null offsets and masks via offsets.null_count(). When concatenating
483+
# offset arrays, PyArrow uses lazy evaluation which can cause this check to see
484+
# an intermediate state and incorrectly raise an "ambiguous to specify both
485+
# validity map and offsets with nulls" error.
486+
487+
# Convert offsets to PyArrow array if it isn't already
488+
if not isinstance(offsets, pa.Array):
489+
offsets = pa.array(offsets)
490+
491+
has_offsets_with_nulls = offsets.null_count > 0
492+
493+
if has_offsets_with_nulls:
494+
# Can't use both a mask and an offset with nulls
495+
result_array = pa.ListArray.from_arrays(
496+
offsets,
497+
values
498+
)
499+
return pc.if_else(mask, None, result_array)
500+
else:
501+
return pa.ListArray.from_arrays(
502+
offsets,
503+
values,
504+
mask=mask
505+
)
506+
476507

477508
def arrow_zip(*arrs: pa.Array) -> pa.Array:
478509
n_arrs = len(arrs)

weave_query/weave_query/arrow/arrow_tags.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
from weave_query import weave_types as types
77
from weave_query.arrow import convert
8-
from weave_query.arrow.arrow import offsets_starting_at_zero
8+
from weave_query.arrow.arrow import (
9+
offsets_starting_at_zero,
10+
safe_list_array_from_arrays
11+
)
912
from weave_query.language_features.tagging import (
1013
process_opdef_output_type,
1114
tag_store,
@@ -30,11 +33,12 @@ def recursively_encode_pyarrow_strings_as_dictionaries(array: pa.Array) -> pa.Ar
3033
mask=pa.compute.invert(array.is_valid()),
3134
)
3235
elif pa.types.is_list(array.type):
33-
result_array = pa.ListArray.from_arrays(
34-
offsets_starting_at_zero(array),
35-
recursively_encode_pyarrow_strings_as_dictionaries(array.flatten())
36+
new_offsets = offsets_starting_at_zero(array)
37+
return safe_list_array_from_arrays(
38+
new_offsets,
39+
recursively_encode_pyarrow_strings_as_dictionaries(array.flatten()),
40+
mask=pa.compute.invert(array.is_valid())
3641
)
37-
return pc.if_else(pa.compute.invert(array.is_valid()), None, result_array)
3842
elif array.type == pa.string():
3943
return pc.dictionary_encode(array)
4044
else:

weave_query/weave_query/arrow/concat.py

+6-13
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
make_vec_none,
2626
offsets_starting_at_zero,
2727
unsafe_awl_construction,
28+
safe_list_array_from_arrays
2829
)
2930
from weave_query.language_features.tagging import tagged_value_type
3031

@@ -209,21 +210,15 @@ def _concatenate_lists(
209210
]
210211
)
211212

212-
result_array = pa.ListArray.from_arrays(
213+
result_array = safe_list_array_from_arrays(
213214
new_offsets,
214215
concatted_values._arrow_data,
216+
mask=pa_concat_arrays([
217+
pa.compute.is_null(l1._arrow_data),
218+
pa.compute.is_null(l2._arrow_data)
219+
])
215220
)
216221

217-
# In pyarrow 17.0.0, we can only use either null offsets OR provide a mask,
218-
# so track the null values to convert empty lists (that were null before
219-
# conversion) and convert those empty lists to None before returning the AWL
220-
combined_nulls = pa_concat_arrays([
221-
pa.compute.is_null(l1._arrow_data),
222-
pa.compute.is_null(l2._arrow_data)
223-
])
224-
225-
result_array = pc.if_else(combined_nulls, None, result_array)
226-
227222
return ArrowWeaveList(
228223
result_array,
229224
types.List(
@@ -520,7 +515,6 @@ def _concatenate(
520515
pa.array(np.zeros(len(other), dtype=np.int8)),
521516
]
522517
),
523-
# offsets=pa.array(np.arange(len(self), dtype=np.int32)),
524518
offsets=pa_concat_arrays(
525519
[
526520
self_offsets,
@@ -547,7 +541,6 @@ def _concatenate(
547541
pa.compute.equal(other_type_codes, other_i).cast(pa.int8()),
548542
]
549543
),
550-
# offsets=pa.array(np.arange(len(other), dtype=np.int32)),
551544
offsets=pa_concat_arrays(
552545
[
553546
pa.array(np.zeros(len(self), dtype=np.int32)),

weave_query/weave_query/arrow/convert.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
ArrowWeaveList,
2222
PathType,
2323
unsafe_awl_construction,
24+
safe_list_array_from_arrays
2425
)
2526
from weave_query.language_features.tagging import tag_store, tagged_value_type
2627

@@ -307,9 +308,9 @@ def none_unboxer(iterator: typing.Iterable):
307308
mapper._object_type,
308309
py_objs_already_mapped,
309310
)
310-
result_array = pa.ListArray.from_arrays(offsets, new_objs)
311-
312-
return pc.if_else(pa.array(mask, type=pa.bool_()), None, result_array)
311+
return safe_list_array_from_arrays(
312+
offsets, new_objs, mask=pa.array(mask, type=pa.bool_())
313+
)
313314
elif pa.types.is_temporal(pyarrow_type):
314315
if py_objs_already_mapped:
315316
return pa.array(py_objs, type=pyarrow_type)

weave_query/weave_query/arrow/list_.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
offsets_starting_at_zero,
3333
pretty_print_arrow_type,
3434
safe_is_null,
35+
safe_list_array_from_arrays
3536
)
3637
from weave_query.language_features.tagging import (
3738
tag_store,
@@ -802,12 +803,14 @@ def _map_column(
802803
)._map_column(fn, pre_fn, path + (PathItemList(),))
803804
# print("SELF OBJECT TYPE", self.object_type)
804805
# print("SELF ARROW DATA TYPE", self._arrow_data.type)
805-
result_array = pa.ListArray.from_arrays(
806-
offsets_starting_at_zero(self._arrow_data),
807-
items._arrow_data
806+
807+
new_offsets = offsets_starting_at_zero(self._arrow_data)
808+
result_array = safe_list_array_from_arrays(
809+
new_offsets,
810+
items._arrow_data,
811+
pa.compute.is_null(arr)
808812
)
809813

810-
result_array = pc.if_else(pa.compute.is_null(arr), None, result_array)
811814
with_mapped_children = ArrowWeaveList(
812815
result_array,
813816
self.object_type.__class__(items.object_type),

weave_query/weave_query/ops_arrow/arraylist_ops.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
from pyarrow import compute as pc
55

66
from weave_query import weave_types as types
7-
from weave_query.arrow.arrow import ArrowWeaveListType, arrow_as_array
7+
from weave_query.arrow.arrow import (
8+
ArrowWeaveListType,
9+
arrow_as_array,
10+
safe_list_array_from_arrays
11+
)
812
from weave_query.arrow.list_ import ArrowWeaveList
913
from weave_query.decorator_arrow_op import arrow_op
1014
from weave_query.language_features.tagging import tagged_value_type
@@ -350,9 +354,7 @@ def dropna(self):
350354
cumulative_non_null_counts = pa.compute.cumulative_sum(non_null)
351355
new_offsets = cumulative_non_null_counts.take(pa.compute.subtract(end_indexes, 1))
352356
new_offsets = pa.concat_arrays([start_indexes[:1], new_offsets])
353-
unflattened = pa.ListArray.from_arrays(new_offsets, new_data)
354-
unflattened = pa.compute.if_else(pa.compute.is_null(a), None, unflattened)
355-
357+
unflattened = safe_list_array_from_arrays(new_offsets, new_data, mask=pa.compute.is_null(a))
356358

357359
return ArrowWeaveList(
358360
unflattened,

0 commit comments

Comments
 (0)