Skip to content

Commit fe23133

Browse files
feat: Support EXISTS subquery (#137)
Co-authored-by: Alexandr Romanenko <[email protected]>
1 parent 28a07c3 commit fe23133

File tree

10 files changed

+460
-121
lines changed

10 files changed

+460
-121
lines changed

datafusion/core/src/logical_plan/builder.rs

+6 -2
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,8 @@ use super::{dfschema::ToDFSchema, expr_rewriter::coerce_plan_expr_for_schema, Di
4747
use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
4848
use crate::logical_plan::{
4949
columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, Column,
50-
CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition, Values,
50+
CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition,
51+
SubqueryType, Values,
5152
};
5253
use crate::sql::utils::group_window_expr_by_sort_keys;
5354

@@ -528,12 +529,15 @@ impl LogicalPlanBuilder {
528529
pub fn subquery(
529530
&self,
530531
subqueries: impl IntoIterator<Item = impl Into<LogicalPlan>>,
532+
types: impl IntoIterator<Item = SubqueryType>,
531533
) -> Result<Self> {
532534
let subqueries = subqueries.into_iter().map(|l| l.into()).collect::<Vec<_>>();
533-
let schema = Arc::new(Subquery::merged_schema(&self.plan, &subqueries));
535+
let types = types.into_iter().collect::<Vec<_>>();
536+
let schema = Arc::new(Subquery::merged_schema(&self.plan, &subqueries, &types));
534537
Ok(Self::from(LogicalPlan::Subquery(Subquery {
535538
input: Arc::new(self.plan.clone()),
536539
subqueries,
540+
types,
537541
schema,
538542
})))
539543
}

datafusion/core/src/logical_plan/mod.rs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,6 @@ pub use plan::{
6868
CreateCatalogSchema, CreateExternalTable, CreateMemoryTable, CrossJoin, Distinct,
6969
DropTable, EmptyRelation, Filter, JoinConstraint, JoinType, Limit, LogicalPlan,
7070
Partitioning, PlanType, PlanVisitor, Repartition, StringifiedPlan, Subquery,
71-
TableScan, ToStringifiedPlan, Union, Values,
71+
SubqueryType, TableScan, ToStringifiedPlan, Union, Values,
7272
};
7373
pub use registry::FunctionRegistry;

datafusion/core/src/logical_plan/plan.rs

+84 -11
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ use crate::error::DataFusionError;
2626
use crate::logical_plan::dfschema::DFSchemaRef;
2727
use crate::sql::parser::FileType;
2828
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
29-
use datafusion_common::DFSchema;
29+
use datafusion_common::{DFField, DFSchema};
3030
use std::fmt::Formatter;
3131
use std::{
3232
collections::HashSet,
@@ -267,22 +267,83 @@ pub struct Limit {
267267
/// Evaluates correlated sub queries
268268
#[derive(Clone)]
269269
pub struct Subquery {
270-
/// The list of sub queries
271-
pub subqueries: Vec<LogicalPlan>,
272270
/// The incoming logical plan
273271
pub input: Arc<LogicalPlan>,
272+
/// The list of sub queries
273+
pub subqueries: Vec<LogicalPlan>,
274+
/// The list of subquery types
275+
pub types: Vec<SubqueryType>,
274276
/// The schema description of the output
275277
pub schema: DFSchemaRef,
276278
}
277279

280+
/// Subquery type
281+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
282+
pub enum SubqueryType {
283+
/// Scalar (SELECT, WHERE) evaluating to one value
284+
Scalar,
285+
/// EXISTS(...) evaluating to true if at least one row was produced
286+
Exists,
287+
// This will be extended with `AnyAll` type.
288+
}
289+
290+
impl Display for SubqueryType {
291+
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
292+
let subquery_type = match self {
293+
SubqueryType::Scalar => "Scalar",
294+
SubqueryType::Exists => "Exists",
295+
};
296+
write!(f, "{}", subquery_type)
297+
}
298+
}
299+
278300
impl Subquery {
279301
/// Merge schema of main input and correlated subquery columns
280-
pub fn merged_schema(input: &LogicalPlan, subqueries: &[LogicalPlan]) -> DFSchema {
281-
subqueries.iter().fold((**input.schema()).clone(), |a, b| {
282-
let mut res = a;
283-
res.merge(b.schema());
284-
res
285-
})
302+
pub fn merged_schema(
303+
input: &LogicalPlan,
304+
subqueries: &[LogicalPlan],
305+
types: &[SubqueryType],
306+
) -> DFSchema {
307+
subqueries.iter().zip(types.iter()).fold(
308+
(**input.schema()).clone(),
309+
|schema, (plan, typ)| {
310+
let mut schema = schema;
311+
schema.merge(&Self::transform_dfschema(plan.schema(), *typ));
312+
schema
313+
},
314+
)
315+
}
316+
317+
/// Transform DataFusion schema according to subquery type
318+
pub fn transform_dfschema(schema: &DFSchema, typ: SubqueryType) -> DFSchema {
319+
match typ {
320+
SubqueryType::Scalar => schema.clone(),
321+
SubqueryType::Exists => {
322+
let new_fields = schema
323+
.fields()
324+
.iter()
325+
.map(|field| {
326+
let new_field = Subquery::transform_field(field.field(), typ);
327+
if let Some(qualifier) = field.qualifier() {
328+
DFField::from_qualified(qualifier, new_field)
329+
} else {
330+
DFField::from(new_field)
331+
}
332+
})
333+
.collect();
334+
DFSchema::new_with_metadata(new_fields, schema.metadata().clone())
335+
.unwrap()
336+
} // Schema will be transformed for `AnyAll` as well
337+
}
338+
}
339+
340+
/// Transform Arrow field according to subquery type
341+
pub fn transform_field(field: &Field, typ: SubqueryType) -> Field {
342+
match typ {
343+
SubqueryType::Scalar => field.clone(),
344+
SubqueryType::Exists => Field::new(field.name(), DataType::Boolean, false),
345+
// Field will be transformed for `AnyAll` as well
346+
}
286347
}
287348
}
288349

@@ -475,13 +536,23 @@ impl LogicalPlan {
475536
LogicalPlan::Values(Values { schema, .. }) => vec![schema],
476537
LogicalPlan::Window(Window { input, schema, .. })
477538
| LogicalPlan::Projection(Projection { input, schema, .. })
478-
| LogicalPlan::Subquery(Subquery { input, schema, .. })
479539
| LogicalPlan::Aggregate(Aggregate { input, schema, .. })
480540
| LogicalPlan::TableUDFs(TableUDFs { input, schema, .. }) => {
481541
let mut schemas = input.all_schemas();
482542
schemas.insert(0, schema);
483543
schemas
484544
}
545+
LogicalPlan::Subquery(Subquery {
546+
input,
547+
subqueries,
548+
schema,
549+
..
550+
}) => {
551+
let mut schemas = input.all_schemas();
552+
schemas.extend(subqueries.iter().map(|s| s.schema()));
553+
schemas.insert(0, schema);
554+
schemas
555+
}
485556
LogicalPlan::Join(Join {
486557
left,
487558
right,
@@ -1063,7 +1134,9 @@ impl LogicalPlan {
10631134
}
10641135
Ok(())
10651136
}
1066-
LogicalPlan::Subquery(Subquery { .. }) => write!(f, "Subquery"),
1137+
LogicalPlan::Subquery(Subquery { types, .. }) => {
1138+
write!(f, "Subquery: types={:?}", types)
1139+
}
10671140
LogicalPlan::Filter(Filter {
10681141
predicate: ref expr,
10691142
..

datafusion/core/src/optimizer/projection_drop_out.rs

+2
Original file line number | Diff line number | Diff line change
@@ -254,6 +254,7 @@ fn optimize_plan(
254254
LogicalPlan::Subquery(Subquery {
255255
input,
256256
subqueries,
257+
types,
257258
schema,
258259
}) => {
259260
// TODO: subqueries are not optimized
@@ -269,6 +270,7 @@ fn optimize_plan(
269270
.map(|(p, _)| p)?,
270271
),
271272
subqueries: subqueries.clone(),
273+
types: types.clone(),
272274
schema: schema.clone(),
273275
}),
274276
None,

datafusion/core/src/optimizer/projection_push_down.rs

+6 -2
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,10 @@ fn optimize_plan(
453453
}))
454454
}
455455
LogicalPlan::Subquery(Subquery {
456-
input, subqueries, ..
456+
input,
457+
subqueries,
458+
types,
459+
..
457460
}) => {
458461
let mut subquery_required_columns = HashSet::new();
459462
for subquery in subqueries.iter() {
@@ -484,11 +487,12 @@ fn optimize_plan(
484487
has_projection,
485488
_optimizer_config,
486489
)?;
487-
let new_schema = Subquery::merged_schema(&input, subqueries);
490+
let new_schema = Subquery::merged_schema(&input, subqueries, types);
488491
Ok(LogicalPlan::Subquery(Subquery {
489492
input: Arc::new(input),
490493
schema: Arc::new(new_schema),
491494
subqueries: subqueries.clone(),
495+
types: types.clone(),
492496
}))
493497
}
494498
// all other nodes: Add any additional columns used by

datafusion/core/src/optimizer/utils.rs

+3 -2
Original file line number | Diff line number | Diff line change
@@ -161,10 +161,11 @@ pub fn from_plan(
161161
alias: alias.clone(),
162162
}))
163163
}
164-
LogicalPlan::Subquery(Subquery { schema, .. }) => {
164+
LogicalPlan::Subquery(Subquery { schema, types, .. }) => {
165165
Ok(LogicalPlan::Subquery(Subquery {
166-
subqueries: inputs[1..inputs.len()].to_vec(),
167166
input: Arc::new(inputs[0].clone()),
167+
subqueries: inputs[1..inputs.len()].to_vec(),
168+
types: types.clone(),
168169
schema: schema.clone(),
169170
}))
170171
}

datafusion/core/src/physical_plan/planner.rs

+3 -2
Original file line number | Diff line number | Diff line change
@@ -917,7 +917,7 @@ impl DefaultPhysicalPlanner {
917917

918918
Ok(Arc::new(GlobalLimitExec::new(input, *skip, *fetch)))
919919
}
920-
LogicalPlan::Subquery(Subquery { subqueries, input, schema }) => {
920+
LogicalPlan::Subquery(Subquery { input, subqueries, types, schema }) => {
921921
let cursor = Arc::new(OuterQueryCursor::new(schema.as_ref().to_owned().into()));
922922
let mut new_session_state = session_state.clone();
923923
new_session_state.execution_props = new_session_state.execution_props.with_outer_query_cursor(cursor.clone());
@@ -931,7 +931,7 @@ impl DefaultPhysicalPlanner {
931931
})
932932
.collect::<Vec<_>>();
933933
let input = self.create_initial_plan(input, &new_session_state).await?;
934-
Ok(Arc::new(SubqueryExec::try_new(subqueries, input, cursor)?))
934+
Ok(Arc::new(SubqueryExec::try_new(input, subqueries, types.clone(), cursor)?))
935935
}
936936
LogicalPlan::CreateExternalTable(_) => {
937937
// There is no default plan for "CREATE EXTERNAL
@@ -1033,6 +1033,7 @@ pub fn create_physical_expr(
10331033
let cursors = execution_props.outer_query_cursors.clone();
10341034
let cursor = cursors
10351035
.iter()
1036+
.rev()
10361037
.find(|cur| cur.schema().field_with_name(c.name.as_str()).is_ok())
10371038
.ok_or_else(|| {
10381039
DataFusionError::Execution(format!(

0 commit comments

Comments
 (0)