Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 74 additions & 51 deletions datafusion/optimizer/src/scalar_subquery_to_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,10 @@ impl OptimizerRule for ScalarSubqueryToJoin {
build_join(&subquery, &cur_input, &alias)?
{
if !compensation_exprs.is_empty() {
rewrite_expr = rewrite_expr
.transform_up(|expr| {
if let Some(compensation_expr) = expr
.try_as_col()
.and_then(|col| compensation_exprs.get(col))
{
Ok(Transformed::yes(compensation_expr.clone()))
} else {
Ok(Transformed::no(expr))
}
})
.data()?;
rewrite_expr = apply_compensation_exprs(
rewrite_expr,
&compensation_exprs,
)?;
}
cur_input = optimized_subquery;
} else {
Expand Down Expand Up @@ -168,67 +160,76 @@ impl OptimizerRule for ScalarSubqueryToJoin {
return Ok(Transformed::no(LogicalPlan::Projection(projection)));
}

let mut all_subqueries = vec![];
let mut alias_to_index: HashMap<String, usize> = HashMap::new();
let mut rewrite_exprs: Vec<Expr> =
Vec::with_capacity(projection.expr.len());
for (idx, expr) in projection.expr.iter().enumerate() {
let (subqueries, rewrite_expr) = self.extract_subquery_exprs(
let mut subqueries_to_join = vec![];
let mut alias_to_state_index = HashMap::new();
let mut rewrite_states = Vec::with_capacity(projection.expr.len());
for expr in &projection.expr {
let state_idx = rewrite_states.len();
let (subqueries, rewritten_expr) = self.extract_subquery_exprs(
expr,
config.alias_generator(),
physical_uncorrelated,
)?;
for (_, alias) in &subqueries {
alias_to_index.insert(alias.clone(), idx);
}
all_subqueries.extend(subqueries);
rewrite_exprs.push(rewrite_expr);
let subquery_aliases = subqueries
.iter()
.map(|(_, alias)| alias.clone())
.collect::<Vec<_>>();
subqueries_to_join.extend(subqueries);
let state = ProjectionRewriteState {
rewritten_expr,
subquery_aliases,
};
alias_to_state_index.extend(
state
.subquery_aliases
.iter()
.cloned()
.map(|alias| (alias, state_idx)),
);
rewrite_states.push(state);
}
assert_or_internal_err!(
!all_subqueries.is_empty(),
!subqueries_to_join.is_empty(),
"Expected subqueries not found in projection"
);
// iterate through all subqueries in predicate, turning each into a left join

// Iterate through projection subqueries, turning each into a left join.
let mut cur_input = projection.input.as_ref().clone();
for (subquery, alias) in all_subqueries {
for (subquery, alias) in subqueries_to_join {
if let Some((optimized_subquery, compensation_exprs)) =
build_join(&subquery, &cur_input, &alias)?
{
cur_input = optimized_subquery;
if !compensation_exprs.is_empty()
&& let Some(&idx) = alias_to_index.get(&alias)
&& let Some(&idx) = alias_to_state_index.get(&alias)
{
let new_expr = rewrite_exprs[idx]
.clone()
.transform_up(|expr| {
if let Some(compensation_expr) = expr
.try_as_col()
.and_then(|col| compensation_exprs.get(col))
{
Ok(Transformed::yes(compensation_expr.clone()))
} else {
Ok(Transformed::no(expr))
}
})
.data()?;
rewrite_exprs[idx] = new_expr;
let new_expr = apply_compensation_exprs(
rewrite_states[idx].rewritten_expr.clone(),
&compensation_exprs,
)?;
rewrite_states[idx].rewritten_expr = new_expr;
}
} else {
// if we can't handle all of the subqueries then bail for now
return Ok(Transformed::no(LogicalPlan::Projection(projection)));
}
}

let mut proj_exprs = vec![];
for (expr, new_expr) in projection.expr.iter().zip(rewrite_exprs) {
let old_expr_name = expr.schema_name().to_string();
let new_expr_name = new_expr.schema_name().to_string();
if new_expr_name != old_expr_name {
proj_exprs.push(new_expr.alias(old_expr_name))
} else {
proj_exprs.push(new_expr);
}
}
let proj_exprs = projection
.expr
.iter()
.zip(rewrite_states)
.map(|(expr, state)| {
let old_expr_name = expr.schema_name().to_string();
let new_expr = state.rewritten_expr;
let new_expr_name = new_expr.schema_name().to_string();
if new_expr_name != old_expr_name {
new_expr.alias(old_expr_name)
} else {
new_expr
}
})
.collect::<Vec<_>>();
let new_plan = LogicalPlanBuilder::from(cur_input)
.project(proj_exprs)?
.build()?;
Expand Down Expand Up @@ -266,6 +267,28 @@ fn contains_scalar_subquery_to_rewrite(expr: &Expr, physical_uncorrelated: bool)
.expect("Inner is always Ok")
}

/// Per-projection-expression bookkeeping used while scalar subqueries in a
/// projection are turned into joins.
///
/// One state is created for each projection expression, in projection order,
/// so a state's position doubles as the index of the expression it rewrites.
struct ProjectionRewriteState {
/// The projection expression after its scalar subqueries were extracted;
/// it is replaced again whenever compensation expressions are applied.
rewritten_expr: Expr,
/// Aliases of the subqueries that were extracted from this expression.
subquery_aliases: Vec<String>,
}

/// Substitutes compensation expressions into `expr`.
///
/// The expression tree is walked with `transform_up`; every node that is a
/// column with an entry in `compensation_exprs` is replaced by a clone of the
/// mapped expression. All other nodes are returned untouched.
///
/// Returns the rewritten expression, or the error produced by the traversal.
fn apply_compensation_exprs(
    expr: Expr,
    compensation_exprs: &HashMap<Column, Expr>,
) -> Result<Expr> {
    expr.transform_up(|node| {
        // Look up the replacement first so the borrow of `node` ends before
        // the node is moved into the "unchanged" result below.
        let replacement = node
            .try_as_col()
            .and_then(|column| compensation_exprs.get(column))
            .cloned();
        Ok(match replacement {
            Some(compensation) => Transformed::yes(compensation),
            None => Transformed::no(node),
        })
    })
    .data()
}

struct ExtractScalarSubQuery<'a> {
sub_query_info: Vec<(Subquery, String)>,
alias_gen: &'a Arc<AliasGenerator>,
Expand Down
46 changes: 46 additions & 0 deletions datafusion/sqllogictest/test_files/subquery.slt
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,52 @@ FROM t1
33 4 5
44 1 2

#correlated_scalar_subquery_multiple_projection_slots
# Distinct projection slots must each own their scalar subquery rewrite.
# COUNT gets empty-input compensation; SUM preserves NULL on no match.
query III rowsort
SELECT
t1_id,
(SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) AS cnt,
(SELECT sum(t2_int) FROM t2 WHERE t2.t2_id = t1.t1_id) AS total
FROM t1
----
11 1 3
22 0 1
33 3 NULL
44 0 3

#correlated_scalar_subquery_multiple_subqueries_one_projection_slot
# Multiple COUNT subqueries in a single projection expression must all be
# compensated before the expression is evaluated.
query II rowsort
SELECT
t1_id,
(SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int)
+ (SELECT count(*) FROM t2 WHERE t2.t2_id = t1.t1_id) AS combined
FROM t1
----
11 2
22 1
33 3
44 1

#correlated_scalar_subquery_mixed_repeated_and_non_count_projection_slots
# Repeated COUNT slots must each be compensated while the SUM slot keeps NULL
# semantics for unmatched outer rows.
query IIII rowsort
SELECT
t1_id,
(SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) + 1 AS a,
(SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) + 2 AS b,
(SELECT sum(t2_int) FROM t2 WHERE t2.t2_int = t1.t1_int) AS c
FROM t1
----
11 2 3 1
22 1 2 NULL
33 4 5 9
44 1 2 NULL

#correlated_scalar_subquery_count_agg2
query TT
explain SELECT t1_id, (SELECT count(*) FROM t2 WHERE t2.t2_int = t1.t1_int) as cnt from t1
Expand Down
Loading