Skip to content

Commit

Permalink
feat: support 'col IN (a, b, c)' type expressions
Browse files Browse the repository at this point in the history
Signed-off-by: Robert Pack <[email protected]>
  • Loading branch information
roeap committed Jan 18, 2025
1 parent 8494126 commit 007b4e2
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 4 deletions.
4 changes: 2 additions & 2 deletions kernel/src/engine/arrow_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl TryFrom<&DataType> for ArrowDataType {
// TODO: https://github.com/delta-io/delta/issues/643
PrimitiveType::Timestamp => Ok(ArrowDataType::Timestamp(
TimeUnit::Microsecond,
Some("UTC".into()),
Some("+00:00".into()),
)),
PrimitiveType::TimestampNtz => {
Ok(ArrowDataType::Timestamp(TimeUnit::Microsecond, None))
Expand Down Expand Up @@ -208,7 +208,7 @@ impl TryFrom<&ArrowDataType> for DataType {
ArrowDataType::Date64 => Ok(DataType::DATE),
ArrowDataType::Timestamp(TimeUnit::Microsecond, None) => Ok(DataType::TIMESTAMP_NTZ),
ArrowDataType::Timestamp(TimeUnit::Microsecond, Some(tz))
if tz.eq_ignore_ascii_case("utc") =>
if tz.eq_ignore_ascii_case("utc") || tz.eq_ignore_ascii_case("+00:00") =>
{
Ok(DataType::TIMESTAMP)
}
Expand Down
158 changes: 156 additions & 2 deletions kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,81 @@ fn evaluate_expression(
(ArrowDataType::Decimal256(_, _), Decimal256Type)
}
}
(Column(name), Literal(Scalar::Array(ad))) => {
use crate::expressions::ArrayData;

let column = extract_column(batch, name)?;
let data_type = ad
.array_type()
.element_type()
.as_primitive_opt()
.ok_or_else(|| {
Error::invalid_expression(format!(
"IN only supports array literals with primitive elements, got: '{:?}'",
ad.array_type().element_type()
))
})?;

/// Tests each entry of `col` for membership in the elements of the array literal `ad`.
///
/// Every `Some` value is converted into a [`Scalar`] and looked up among the
/// literal's elements; `None` entries are passed through as `None`. The resulting
/// `Option<bool>` sequence is collected into a [`BooleanArray`].
fn op(
    col: impl Iterator<Item = Option<impl Into<Scalar>>>,
    ad: &ArrayData,
) -> BooleanArray {
    // `array_elements` is deprecated but is the accessor this code relies on;
    // look the elements up once rather than once per row.
    #[allow(deprecated)]
    let elements = ad.array_elements();
    let membership = col.map(|entry| entry.map(|value| elements.contains(&value.into())));
    BooleanArray::from_iter(membership)
}

// NOTE: the `as_*` downcasts on arrow arrays panic on a type mismatch, but every arm below matches on the column's data type before downcasting.
let arr: BooleanArray = match (column.data_type(), data_type) {
(ArrowDataType::Utf8, PrimitiveType::String) => op(column.as_string::<i32>().iter(), ad),
(ArrowDataType::LargeUtf8, PrimitiveType::String) => op(column.as_string::<i64>().iter(), ad),
(ArrowDataType::Utf8View, PrimitiveType::String) => op(column.as_string_view().iter(), ad),
(ArrowDataType::Int8, PrimitiveType::Byte) => op(column.as_primitive::<Int8Type>().iter(), ad),
(ArrowDataType::Int16, PrimitiveType::Short) => op(column.as_primitive::<Int16Type>().iter(), ad),
(ArrowDataType::Int32, PrimitiveType::Integer) => op(column.as_primitive::<Int32Type>().iter(), ad),
(ArrowDataType::Int64, PrimitiveType::Long) => op(column.as_primitive::<Int64Type>().iter(), ad),
(ArrowDataType::Float32, PrimitiveType::Float) => op(column.as_primitive::<Float32Type>().iter(), ad),
(ArrowDataType::Float64, PrimitiveType::Double) => op(column.as_primitive::<Float64Type>().iter(), ad),
(ArrowDataType::Date32, PrimitiveType::Date) => {
#[allow(deprecated)]
let res = column
.as_primitive::<Date32Type>()
.iter()
.map(|val| val.map(|v| ad.array_elements().contains(&Scalar::Date(v))));
BooleanArray::from_iter(res)
}
(
    ArrowDataType::Timestamp(TimeUnit::Microsecond, unit),
    kt @ PrimitiveType::Timestamp | kt @ PrimitiveType::TimestampNtz,
) => {
    // Select the scalar constructor that matches both the column's timezone
    // presence and the literal's element type. The outer pattern accepts any
    // combination of the two, so the mismatched pairs must be handled here.
    let to_scalar: fn(i64) -> Scalar = match (unit, kt) {
        // regardless of the time zone stored in the timestamp, the underlying value is always in UTC
        (Some(_), PrimitiveType::Timestamp) => Scalar::Timestamp,
        (None, PrimitiveType::TimestampNtz) => Scalar::TimestampNtz,
        // A timezone-aware column compared against NTZ literals (or the reverse)
        // can be produced by a caller-supplied expression, so report it as an
        // invalid expression instead of panicking.
        _ => {
            return Err(Error::invalid_expression(format!(
                "Cannot check if value of type '{}' is contained in array with values of type '{}'",
                column.data_type(),
                kt
            )))
        }
    };
    // `array_elements` is deprecated but is the accessor used throughout this arm.
    #[allow(deprecated)]
    let elements = ad.array_elements();
    // `None` entries are passed through as `None`; `Some` entries are tested for membership.
    BooleanArray::from_iter(
        column
            .as_primitive::<TimestampMicrosecondType>()
            .iter()
            .map(|val| val.map(|v| elements.contains(&to_scalar(v)))),
    )
}
(l, r) => {
return Err(Error::invalid_expression(format!(
"Cannot check if value of type '{l}' is contained in array with values of type '{r}'"
)))
}
};
Ok(Arc::new(arr))
}
(Literal(lit), Literal(Scalar::Array(ad))) => {
#[allow(deprecated)]
let exists = ad.array_elements().contains(lit);
Expand Down Expand Up @@ -382,8 +457,8 @@ fn new_field_with_metadata(

// A helper that is a wrapper over `transform_field_and_col`. This will take apart the passed struct
// and use that method to transform each column and then put the struct back together. Target types
// and names for each column should be passed in `target_types_and_names`. The number of elements in
// the `target_types_and_names` iterator _must_ be the same as the number of columns in
// and names for each column should be passed in `target_fields`. The number of elements in
// the `target_fields` iterator _must_ be the same as the number of columns in
// `struct_array`. The transformation is ordinal. That is, the order of fields in `target_fields`
// _must_ match the order of the columns in `struct_array`.
fn transform_struct(
Expand Down Expand Up @@ -692,6 +767,85 @@ mod tests {
assert_eq!(in_result.as_ref(), &in_expected);
}

/// Verifies `column IN (literal array)` (and `NOT IN` for integers) against
/// single-column batches of integer, date, timestamp and timestamp-ntz values.
#[test]
fn test_column_in_array() {
    // Evaluates `item <op> <list>` against `batch`, requesting a boolean result.
    let eval = |op: BinaryOperator, batch: &RecordBatch, list: &Expression| {
        let predicate = Expression::binary(op, column_expr!("item"), list.clone());
        let result =
            evaluate_expression(&predicate, batch, Some(&crate::schema::DataType::BOOLEAN))
                .unwrap();
        result
    };

    // Every batch below holds the values 0..=3 and every literal array holds 1 and 3.
    let expected_in = BooleanArray::from(vec![false, true, false, true]);
    let expected_not_in = BooleanArray::from(vec![true, false, true, false]);

    // Integer arrays
    let int_values = Int32Array::from(vec![0, 1, 2, 3]);
    let int_field = Arc::new(Field::new("item", DataType::Int32, true));
    let int_list = Expression::literal(Scalar::Array(ArrayData::new(
        ArrayType::new(PrimitiveType::Integer.into(), false),
        [Scalar::Integer(1), Scalar::Integer(3)],
    )));
    let int_batch = RecordBatch::try_new(
        Arc::new(Schema::new([int_field])),
        vec![Arc::new(int_values)],
    )
    .unwrap();

    let int_in = eval(BinaryOperator::In, &int_batch, &int_list);
    assert_eq!(int_in.as_ref(), &expected_in);

    let int_not_in = eval(BinaryOperator::NotIn, &int_batch, &int_list);
    assert_eq!(int_not_in.as_ref(), &expected_not_in);

    // Date arrays
    let date_values = Date32Array::from(vec![0, 1, 2, 3]);
    let date_field = Arc::new(Field::new("item", DataType::Date32, true));
    let date_list = Expression::literal(Scalar::Array(ArrayData::new(
        ArrayType::new(PrimitiveType::Date.into(), false),
        [Scalar::Date(1), Scalar::Date(3)],
    )));
    let date_batch = RecordBatch::try_new(
        Arc::new(Schema::new([date_field])),
        vec![Arc::new(date_values)],
    )
    .unwrap();

    let date_in = eval(BinaryOperator::In, &date_batch, &date_list);
    assert_eq!(date_in.as_ref(), &expected_in);

    // Timestamp arrays
    let ts_values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]).with_timezone_utc();
    let ts_field = Arc::new(Field::new(
        "item",
        (&crate::schema::DataType::TIMESTAMP).try_into().unwrap(),
        true,
    ));
    let ts_list = Expression::literal(Scalar::Array(ArrayData::new(
        ArrayType::new(PrimitiveType::Timestamp.into(), false),
        [Scalar::Timestamp(1), Scalar::Timestamp(3)],
    )));
    let ts_batch = RecordBatch::try_new(
        Arc::new(Schema::new([ts_field])),
        vec![Arc::new(ts_values)],
    )
    .unwrap();

    let ts_in = eval(BinaryOperator::In, &ts_batch, &ts_list);
    assert_eq!(ts_in.as_ref(), &expected_in);

    // Timestamp NTZ arrays
    let ntz_values = TimestampMicrosecondArray::from(vec![0, 1, 2, 3]);
    let ntz_field = Arc::new(Field::new(
        "item",
        (&crate::schema::DataType::TIMESTAMP_NTZ)
            .try_into()
            .unwrap(),
        true,
    ));
    let ntz_list = Expression::literal(Scalar::Array(ArrayData::new(
        ArrayType::new(PrimitiveType::TimestampNtz.into(), false),
        [Scalar::TimestampNtz(1), Scalar::TimestampNtz(3)],
    )));
    let ntz_batch = RecordBatch::try_new(
        Arc::new(Schema::new([ntz_field])),
        vec![Arc::new(ntz_values)],
    )
    .unwrap();

    let ntz_in = eval(BinaryOperator::In, &ntz_batch, &ntz_list);
    assert_eq!(ntz_in.as_ref(), &expected_in);
}

#[test]
fn test_extract_column() {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand Down

0 comments on commit 007b4e2

Please sign in to comment.