Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 25 additions & 8 deletions datafusion/sql/src/cte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel};

use arrow::datatypes::{Schema, SchemaRef};
use datafusion_common::{
Result, not_impl_err, plan_err,
Result, TableReference, not_impl_err, plan_err,
tree_node::{TreeNode, TreeNodeRecursion},
};
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource};
use sqlparser::ast::{Query, SetExpr, SetOperator, With};
use sqlparser::ast::{Ident, Query, SetExpr, SetOperator, With};

impl<S: ContextProvider> SqlToRel<'_, S> {
pub(super) fn plan_with_clause(
Expand All @@ -46,14 +46,24 @@ impl<S: ContextProvider> SqlToRel<'_, S> {

// Create a logical plan for the CTE
let cte_plan = if is_recursive {
self.recursive_cte(&cte_name, *cte.query, planner_context)?
let columns = cte.alias.columns.iter().map(|c| c.name.clone()).collect();
self.recursive_cte(&cte_name, columns, *cte.query, planner_context)?
} else {
self.non_recursive_cte(*cte.query, planner_context)?
};

// Each `WITH` block can change the column names in the last
// projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2").
let final_plan = self.apply_table_alias(cte_plan, cte.alias)?;
// Each `WITH` block can change the column names in the last projection
// (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). Recursive CTEs apply those
// to the static term in recursive_cte(), so only the relation name here.
let final_plan = if is_recursive {
LogicalPlanBuilder::from(cte_plan)
.alias(TableReference::bare(
self.ident_normalizer.normalize(cte.alias.name),
))?
.build()?
} else {
self.apply_table_alias(cte_plan, cte.alias)?
};
// Export the CTE to the outer query
planner_context.insert_cte(cte_name, final_plan);
}
Expand All @@ -71,6 +81,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
fn recursive_cte(
&self,
cte_name: &str,
columns: Vec<Ident>,
mut cte_query: Query,
planner_context: &mut PlannerContext,
) -> Result<LogicalPlan> {
Expand All @@ -91,9 +102,11 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
set_quantifier,
} => (left, right, set_quantifier),
other => {
// If the query is not a UNION, then it is not a recursive CTE
// Not a UNION, so not actually a recursive CTE. The caller adds only
// the relation name for recursive CTEs, so apply the column aliases here.
*cte_query.body = other;
return self.non_recursive_cte(cte_query, planner_context);
let plan = self.non_recursive_cte(cte_query, planner_context)?;
return self.apply_expr_alias(plan, columns);
}
};

Expand All @@ -111,6 +124,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
// ---------- Step 1: Compile the static term ------------------
let static_plan = self.set_expr_to_plan(*left_expr, planner_context)?;

// Apply the declared column-list aliases (e.g. `t(n)`) to the static term, so
// the work table built from its schema below exposes the declared names.
let static_plan = self.apply_expr_alias(static_plan, columns)?;

// Since the recursive CTEs include a component that references a
// table with its name, like the example below:
//
Expand Down
89 changes: 89 additions & 0 deletions datafusion/sqllogictest/test_files/cte.slt
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,95 @@ physical_plan
07)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
08)----------WorkTableExec: name=nodes

# recursive CTE with a column-list alias (e.g. `t(n)`): the declared names must be
# applied to the static term so the recursive self-reference can resolve them
query I rowsort
WITH RECURSIVE t(n) AS (
SELECT 1
UNION ALL
SELECT n + 1 FROM t WHERE n < 10
)
SELECT n FROM t
----
1
10
2
3
4
5
6
7
8
9

# recursive CTE with a multi-column column-list alias
query II rowsort
WITH RECURSIVE t(a, b) AS (
SELECT 1, 2
UNION ALL
SELECT a + 1, b * 2 FROM t WHERE a < 5
)
SELECT a, b FROM t
----
1 2
2 4
3 8
4 16
5 32

# recursive CTE with a column-list alias and UNION (DISTINCT)
query I rowsort
WITH RECURSIVE t(n) AS (
SELECT 1
UNION
SELECT n + 1 FROM t WHERE n < 5
)
SELECT n FROM t
----
1
2
3
4
5

# recursive CTE column-list alias arity mismatch is rejected cleanly (raised at
# the static term, rather than the old confusing "No field named ...")
query error DataFusion error: Error during planning: Source table contains 1 columns but only 2 names given as column alias
WITH RECURSIVE t(a, b) AS (
SELECT 1
UNION ALL
SELECT a + 1 FROM t WHERE a < 3
)
SELECT * FROM t

# explain a column-list-aliased recursive CTE: the declared name is applied to
# the static term, so there is no extra projection on top of RecursiveQuery
query TT
EXPLAIN WITH RECURSIVE t(n) AS (
SELECT 1
UNION ALL
SELECT n + 1 FROM t WHERE n < 10
)
SELECT * FROM t
----
logical_plan
01)SubqueryAlias: t
02)--RecursiveQuery: is_distinct=false
03)----Projection: Int64(1) AS n
04)------EmptyRelation: rows=1
05)----Projection: t.n + Int64(1)
06)------Filter: t.n < Int64(10)
07)--------TableScan: t projection=[n]
physical_plan
01)RecursiveQueryExec: name=t, is_distinct=false
02)--ProjectionExec: expr=[CAST(1 AS Int64) as n]
03)----PlaceholderRowExec
04)--CoalescePartitionsExec
05)----ProjectionExec: expr=[n@0 + 1 as n]
06)------FilterExec: n@0 < 10
07)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
08)----------WorkTableExec: name=t

# simple deduplicating recursive CTE works
query I
WITH RECURSIVE nodes AS (
Expand Down