Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions rust/frontend/src/expr/function_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ impl FunctionCall {
pub fn decompose(self) -> (ExprType, Vec<ExprImpl>) {
(self.func_type, self.inputs)
}
/// Consumes a two-argument call and returns `(func_type, left, right)`.
///
/// # Panics
/// Panics if the call does not have exactly two inputs.
pub fn decompose_as_binary(self) -> (ExprType, ExprImpl, ExprImpl) {
    assert_eq!(self.inputs.len(), 2);
    let (func_type, mut args) = self.decompose();
    // Pop from the back: the second argument comes out first.
    let right = args.pop().unwrap();
    let left = args.pop().unwrap();
    (func_type, left, right)
}
/// Consumes a one-argument call and returns `(func_type, input)`.
///
/// # Panics
/// Panics if the call does not have exactly one input.
pub fn decompose_as_unary(self) -> (ExprType, ExprImpl) {
    assert_eq!(self.inputs.len(), 1);
    let (func_type, mut args) = self.decompose();
    let only = args.pop().unwrap();
    (func_type, only)
}

pub fn get_expr_type(&self) -> ExprType {
self.func_type
}
Expand Down
25 changes: 17 additions & 8 deletions rust/frontend/src/optimizer/plan_node/batch_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,26 @@ use std::fmt;
use risingwave_common::catalog::Schema;

use super::{
IntoPlanRef, JoinPredicate, LogicalJoin, PlanRef, PlanTreeNodeBinary, ToDistributedBatch,
EqJoinPredicate, IntoPlanRef, LogicalJoin, PlanRef, PlanTreeNodeBinary, ToDistributedBatch,
};
use crate::optimizer::property::{Distribution, WithDistribution, WithOrder, WithSchema};

/// Batch hash-join plan node: a `LogicalJoin` paired with the equi-join predicate whose
/// left/right key columns are used as the hash-shard keys required from the two inputs.
#[derive(Debug, Clone)]
pub struct BatchHashJoin {
/// The logical join this physical node implements; its inputs are this node's children.
logical: LogicalJoin,
/// Column-equality key pairs plus the remaining conjuncts of the join condition.
eq_join_predicate: EqJoinPredicate,
}
impl BatchHashJoin {
pub fn new(logical: LogicalJoin) -> Self {
Self { logical }
/// Creates a batch hash join over `logical`, keyed by the pairs in `eq_join_predicate`.
pub fn new(logical: LogicalJoin, eq_join_predicate: EqJoinPredicate) -> Self {
    Self { logical, eq_join_predicate }
}
pub fn predicate(&self) -> &JoinPredicate {
self.logical.predicate()

/// Get a reference to the batch hash join's eq join predicate.
pub fn eq_join_predicate(&self) -> &EqJoinPredicate {
&self.eq_join_predicate
}
}

Expand All @@ -34,7 +40,10 @@ impl PlanTreeNodeBinary for BatchHashJoin {
self.logical.right()
}
/// Rebuilds this node over new children: the logical join is re-parented onto `left` and
/// `right`, while the equi-join predicate is copied over unchanged.
fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
Self::new(self.logical.clone_with_left_right(left, right))
Self::new(
self.logical.clone_with_left_right(left, right),
self.eq_join_predicate.clone(),
)
}
fn left_dist_required(&self) -> &Distribution {
todo!()
Expand All @@ -56,11 +65,11 @@ impl ToDistributedBatch for BatchHashJoin {
/// Converts both children to distributed plans, hash-sharded on their equi-join key columns,
/// and rebuilds this join on top of them.
fn to_distributed(&self) -> PlanRef {
// Each side is required to be hash-sharded on its own join-key columns, presumably so that
// rows with equal keys land on the same shard.
// NOTE(review): `EqJoinPredicate::create` stores right keys as indexes into the concatenated
// (left ++ right) input schema; confirm `right_eq_indexes()` yields positions valid in the
// right child's own schema before using them as its shard keys.
let left = self.left().to_distributed_with_required(
self.left_order_required(),
&Distribution::HashShard(self.predicate().left_keys()),
&Distribution::HashShard(self.eq_join_predicate().left_eq_indexes()),
);
let right = self.right().to_distributed_with_required(
self.right_order_required(),
&Distribution::HashShard(self.predicate().right_keys()),
&Distribution::HashShard(self.eq_join_predicate().right_eq_indexes()),
);

self.clone_with_left_right(left, right).into_plan_ref()
Expand Down
24 changes: 16 additions & 8 deletions rust/frontend/src/optimizer/plan_node/batch_sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,36 @@ use std::fmt;
use risingwave_common::catalog::Schema;

use super::{
IntoPlanRef, JoinPredicate, LogicalJoin, PlanRef, PlanTreeNodeBinary, ToDistributedBatch,
EqJoinPredicate, IntoPlanRef, LogicalJoin, PlanRef, PlanTreeNodeBinary, ToDistributedBatch,
};
use crate::optimizer::property::{
Direction, Distribution, FieldOrder, Order, WithDistribution, WithOrder, WithSchema,
};
/// Batch sort-merge join plan node: a `LogicalJoin`, its equi-join predicate, and the output
/// order derived from the two inputs' orders.
#[derive(Debug, Clone)]
pub struct BatchSortMergeJoin {
/// The logical join this physical node implements; its inputs are this node's children.
logical: LogicalJoin,
/// Column-equality key pairs plus the remaining conjuncts of the join condition.
eq_join_predicate: EqJoinPredicate,
/// Output order, computed in `new` by `derive_order` from the left and right input orders.
order: Order,
}
impl BatchSortMergeJoin {
pub fn new(logical: LogicalJoin) -> Self {
/// Creates a sort-merge join node. The output order is derived from the orders of the
/// logical join's left and right inputs; because `derive_order` is currently `todo!()`,
/// this constructor panics for now.
pub fn new(logical: LogicalJoin, eq_join_predicate: EqJoinPredicate) -> Self {
let order = Self::derive_order(logical.left().order(), logical.right().order());
Self { logical, order }
Self {
logical,
order,
eq_join_predicate,
}
}

/// Derives the join's output order from the left and right input orders.
/// Intended to panic if the input orders can't satisfy a sort-merge join.
/// Not implemented yet: the body is `todo!()`, so every call panics (including from `new`).
fn derive_order(_left: &Order, _right: &Order) -> Order {
todo!()
}

pub fn left_required_order(join_predicate: &JoinPredicate) -> Order {
pub fn left_required_order(eq_join_predicate: &EqJoinPredicate) -> Order {
Order {
field_order: join_predicate
.left_keys()
field_order: eq_join_predicate
.left_eq_indexes()
.into_iter()
.map(|index| FieldOrder {
index,
Expand All @@ -38,7 +43,7 @@ impl BatchSortMergeJoin {
}
/// Not implemented yet: always panics via `todo!()`.
/// NOTE(review): judging by its name and parameters, this is meant to derive the order the
/// right input must provide from the left input's order and the join keys — confirm.
pub fn right_required_order_from_left_order(
_left_order: &Order,
_join_predicate: &JoinPredicate,
_eq_join_predicate: &EqJoinPredicate,
) -> Order {
todo!()
}
Expand All @@ -58,7 +63,10 @@ impl PlanTreeNodeBinary for BatchSortMergeJoin {
self.logical.right()
}
/// Rebuilds this node over new children: the logical join is re-parented onto `left` and
/// `right`, and the equi-join predicate is copied over. The output order is recomputed by
/// `new` from the new inputs.
fn clone_with_left_right(&self, left: PlanRef, right: PlanRef) -> Self {
Self::new(self.logical.clone_with_left_right(left, right))
Self::new(
self.logical.clone_with_left_right(left, right),
self.eq_join_predicate.clone(),
)
}
fn right_order_required(&self) -> &Order {
todo!()
Expand Down
149 changes: 149 additions & 0 deletions rust/frontend/src/optimizer/plan_node/eq_join_predicate.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use fixedbitset::FixedBitSet;

use crate::expr::{get_inputs_col_index, Expr, ExprImpl, ExprType, FunctionCall, InputRef};
use crate::utils::Condition;
#[derive(Debug, Clone)]
/// A join condition split into column-equality pairs (`eq_keys`) and the remaining
/// conjuncts (`other_cond`), used by equi-join plan nodes such as hash join and
/// sort-merge join.
pub struct EqJoinPredicate {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this is called EqJoinPredicate? So at least it contains one equal preidcate?

Do we have NeqJoinPredicate? etc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is just for equal join which has equal join and used for hashjoin and sortMergeJoin to easily get the eq and noneq condition. for the non eq join, the condition is enough.

/// Conjuncts that are not `left_col = right_col` equalities, combined with AND.
other_cond: Condition,

/// Column-equality pairs `(left, right)`, one per `left = right` conjunct.
/// The first element refers to a left-input column and the second to a right-input column.
/// When built by `create`, both `InputRef`s index the join's concatenated input schema
/// (left columns first), so the right element's index is `>= left_cols_num`.
/// All pairs are plain equality (not null-safe equality).
eq_keys: Vec<(InputRef, InputRef)>,
}
#[allow(dead_code)]
impl EqJoinPredicate {
/// Builds an `EqJoinPredicate` directly from its parts, without analyzing, checking,
/// or rewriting them.
pub fn new(other_cond: Condition, eq_keys: Vec<(InputRef, InputRef)>) -> Self {
    Self { other_cond, eq_keys }
}

/// `create` analyzes the ON-clause condition and constructs an `EqJoinPredicate`.
/// For example, the query
/// ```sql
/// select a.v1, a.v2, b.v1, b.v2 from a join b on a.v1 = a.v2 and a.v1 = b.v1 and a.v2 > b.v2
/// ```
/// calls `create` with `left_cols_num = 2` and an `on_clause` of (assuming `input_ref`
/// indexes start from 0)
/// ```sql
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sql?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(It is just to pass the doc test.)

/// input_ref(0) = input_ref(1) and input_ref(0) = input_ref(2) and input_ref(1) > input_ref(3)
/// ```
/// and `create` should return an `EqJoinPredicate` with
/// ```sql
/// other_cond = [input_ref(0) = input_ref(1), input_ref(1) > input_ref(3)],
/// eq_keys = [(input_ref(0), input_ref(2))]
/// ```
#[allow(unused_variables)]
pub fn create(left_cols_num: usize, right_cols_num: usize, on_clause: Condition) -> Self {
let mut other_cond = Condition::true_cond();
let mut eq_keys = vec![];

for cond_expr in on_clause.conjunctions {
let mut cols = FixedBitSet::with_capacity(left_cols_num + right_cols_num);
get_inputs_col_index(&cond_expr, &mut cols);
let from_left = cols
.ones()
.min()
.map(|mx| mx < left_cols_num)
.unwrap_or(false);
let from_right = cols
.ones()
.max()
.map(|mx| mx >= left_cols_num)
.unwrap_or(false);
match (from_left, from_right) {
(true, true) => {
// TODO: refactor with if_chain
let mut is_eq_cond = false;
if let ExprImpl::FunctionCall(function_call) = cond_expr.clone() {
if function_call.get_expr_type() == ExprType::Equal {
if let (_, ExprImpl::InputRef(x), ExprImpl::InputRef(y)) =
function_call.decompose_as_binary()
{
is_eq_cond = true;
if x.index() < y.index() {
eq_keys.push((*x, *y));
} else {
eq_keys.push((*y, *x));
}
}
}
}
if !is_eq_cond {
other_cond.conjunctions.push(cond_expr)
}
}
(true, false) => other_cond.conjunctions.push(cond_expr),
(false, true) => other_cond.conjunctions.push(cond_expr),
(false, false) => other_cond.conjunctions.push(cond_expr),
}
}
Self::new(other_cond, eq_keys)
}

/// Rebuilds the equality keys as a condition holding one `left = right` conjunct per pair,
/// in key order.
pub fn eq_cond(&self) -> Condition {
    let mut conjunctions = Vec::with_capacity(self.eq_keys.len());
    for (left, right) in &self.eq_keys {
        let args = vec![left.clone().bound_expr(), right.clone().bound_expr()];
        let equality = FunctionCall::new(ExprType::Equal, args).unwrap();
        conjunctions.push(equality.bound_expr());
    }
    Condition { conjunctions }
}

pub fn non_eq_cond(&self) -> Condition {
self.other_cond.clone()
}

/// Returns the full join condition: the equality conjuncts combined via `Condition::and`
/// with the remaining conjuncts.
pub fn all_cond(&self) -> Condition {
    let mut combined = self.eq_cond();
    let residual = self.non_eq_cond();
    combined.and(residual);
    combined
}

/// Returns `true` if at least one column-equality pair is present.
pub fn has_eq(&self) -> bool {
    let keys = self.eq_keys();
    !keys.is_empty()
}

/// Returns `true` unless `Condition::always_true` reports the non-equality part as always
/// true.
pub fn has_non_eq(&self) -> bool {
    let residual = self.other_cond();
    !residual.always_true()
}

/// Get a reference to the join predicate's other cond.
pub fn other_cond(&self) -> &Condition {
&self.other_cond
}

/// Borrows the column-equality pairs as a `(left, right)` slice.
pub fn eq_keys(&self) -> &[(InputRef, InputRef)] {
    &self.eq_keys
}

/// Returns the `(left, right)` column indexes of every equality pair, in key order.
/// For pairs built by `create`, both indexes refer to the concatenated input schema.
pub fn eq_indexes(&self) -> Vec<(usize, usize)> {
    let mut pairs = Vec::with_capacity(self.eq_keys.len());
    for (left, right) in &self.eq_keys {
        pairs.push((left.index(), right.index()));
    }
    pairs
}

/// Returns the left-side column index of each equality pair, in key order.
pub fn left_eq_indexes(&self) -> Vec<usize> {
    self.eq_indexes()
        .into_iter()
        .map(|(left_idx, _)| left_idx)
        .collect()
}
/// Returns the right-side column index of each equality pair, in key order.
///
/// For pairs built by `create`, these are positions in the concatenated (left ++ right)
/// input schema, i.e. each is `>= left_cols_num`.
/// NOTE(review): callers that need positions within the right input alone (e.g. a
/// distribution requirement on the right child) should confirm whether the left column
/// count must be subtracted.
pub fn right_eq_indexes(&self) -> Vec<usize> {
    self.eq_indexes()
        .into_iter()
        .map(|(_, right_idx)| right_idx)
        .collect()
}
}
73 changes: 0 additions & 73 deletions rust/frontend/src/optimizer/plan_node/join_predicate.rs

This file was deleted.

Loading