Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 3 additions & 4 deletions rust/common/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ pub enum ErrorCode {
TaskNotFound,
#[error("Item not found: {0}")]
ItemNotFound(String),

#[error(r#"invalid input syntax for {0} type: "{1}""#)]
InvalidInputSyntax(String, String),
#[error("Invalid input syntax: {0}")]
InvalidInputSyntax(String),
#[error("Can not compare in memory: {0}")]
MemComparableError(MemComparableError),

Expand Down Expand Up @@ -198,7 +197,7 @@ impl ErrorCode {
ErrorCode::TaskNotFound => 10,
ErrorCode::ProstError(_) => 11,
ErrorCode::ItemNotFound(_) => 13,
ErrorCode::InvalidInputSyntax(_, _) => 14,
ErrorCode::InvalidInputSyntax(_) => 14,
ErrorCode::MemComparableError(_) => 15,
ErrorCode::MetaError(_) => 18,
ErrorCode::CatalogError(..) => 21,
Expand Down
2 changes: 1 addition & 1 deletion rust/common/src/vector_op/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ pub fn str_to_bool(input: &str) -> Result<bool> {
{
Ok(false)
} else {
Err(InvalidInputSyntax("boolean".to_string(), input.to_string()).into())
Err(InvalidInputSyntax(format!("'{}' is not a valid bool", input)).into())
}
}

Expand Down
8 changes: 4 additions & 4 deletions rust/common/src/vector_op/substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ pub fn substr_start_for(
writer: BytesWriter,
) -> Result<BytesGuard> {
if count < 0 {
return Err(ErrorCode::InvalidInputSyntax(
String::from("non-negative substring length"),
count.to_string(),
)
return Err(ErrorCode::InvalidInputSyntax(format!(
"length in substr should be non-negative: {}",
count
))
.into());
}
let begin = max(start - 1, 0) as usize;
Expand Down
2 changes: 2 additions & 0 deletions rust/frontend/src/binder/bind_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl ColumnBinding {
pub struct BindContext {
    /// Column name -> the `ColumnBinding`s visible under that name. The value is a `Vec`,
    /// presumably so that an ambiguous name can map to several bindings — TODO confirm.
    pub columns: HashMap<String, Vec<ColumnBinding>>,
    /// True while the expressions of a `VALUES` clause are being bound; function binding
    /// rejects aggregate calls while it is set.
    /// NOTE(review): `bind_values` resets this only after all rows bind successfully, so a
    /// binding error leaves it `true` — confirm the context is not reused after an error.
    pub in_values_clause: bool,
}

impl BindContext {
Expand All @@ -29,6 +30,7 @@ impl BindContext {
BindContext {
// tables: HashMap::new(),
columns: HashMap::new(),
in_values_clause: false,
}
}
}
69 changes: 69 additions & 0 deletions rust/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::expr::AggKind;
use risingwave_sqlparser::ast::{Function, FunctionArg, FunctionArgExpr};

use crate::binder::Binder;
use crate::expr::{AggCall, ExprImpl, ExprType, FunctionCall};

impl Binder {
pub(super) fn bind_function(&mut self, f: Function) -> Result<ExprImpl> {
let inputs = f
.args
.into_iter()
.map(|arg| self.bind_function_arg(arg))
.collect::<Result<Vec<ExprImpl>>>()?;

if f.name.0.len() == 1 {
let function_name = f.name.0.get(0).unwrap();
let agg_kind = match function_name.value.as_str() {
"count" => Some(AggKind::Count),
"sum" => Some(AggKind::Sum),
"min" => Some(AggKind::Min),
"max" => Some(AggKind::Max),
_ => None,
};
if let Some(kind) = agg_kind {
if self.context.in_values_clause {
return Err(ErrorCode::InvalidInputSyntax(
"aggregate functions are not allowed in VALUES".to_string(),
)
.into());
}
return Ok(ExprImpl::AggCall(Box::new(AggCall::new(kind, inputs)?)));
}
let function_type = match function_name.value.as_str() {
"substr" => ExprType::Substr,
_ => {
return Err(ErrorCode::NotImplementedError(format!(
"unsupported function: {:?}",
function_name
))
.into())
}
};
Ok(ExprImpl::FunctionCall(Box::new(
FunctionCall::new(function_type, inputs).unwrap(),
)))
} else {
Err(
ErrorCode::NotImplementedError(format!("unsupported function: {:?}", f.name))
.into(),
)
}
}

fn bind_function_expr_arg(&mut self, arg_expr: FunctionArgExpr) -> Result<ExprImpl> {
match arg_expr {
FunctionArgExpr::Expr(expr) => self.bind_expr(expr),
FunctionArgExpr::QualifiedWildcard(_) => todo!(),
FunctionArgExpr::Wildcard => todo!(),
}
}

fn bind_function_arg(&mut self, arg: FunctionArg) -> Result<ExprImpl> {
match arg {
FunctionArg::Unnamed(expr) => self.bind_function_expr_arg(expr),
FunctionArg::Named { .. } => todo!(),
}
}
}
72 changes: 70 additions & 2 deletions rust/frontend/src/binder/expr/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use risingwave_common::error::{ErrorCode, Result};
use risingwave_sqlparser::ast::Expr;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{DataType as AstDataType, Expr, UnaryOperator};

use crate::binder::Binder;
use crate::expr::ExprImpl;
use crate::expr::{Expr as _, ExprImpl, ExprType, FunctionCall};

mod binary_op;
mod column;
mod function;
mod value;

impl Binder {
Expand All @@ -17,8 +19,74 @@ impl Binder {
Expr::BinaryOp { left, op, right } => Ok(ExprImpl::FunctionCall(Box::new(
self.bind_binary_op(*left, op, *right)?,
))),
Expr::UnaryOp { op, expr } => Ok(ExprImpl::FunctionCall(Box::new(
self.bind_unary_expr(op, *expr)?,
))),
Expr::Nested(expr) => self.bind_expr(*expr),
Expr::Cast { expr, data_type } => Ok(ExprImpl::FunctionCall(Box::new(
self.bind_cast(*expr, data_type)?,
))),
Expr::Function(f) => Ok(self.bind_function(f)?),
_ => Err(ErrorCode::NotImplementedError(format!("{:?}", expr)).into()),
}
}

/// Binds a prefix unary operation (`-x` or `NOT x`) into a `FunctionCall`.
///
/// The operator is checked before the operand is bound, so an unsupported operator
/// is reported without touching the operand.
///
/// # Errors
/// Returns `NotImplementedError` for any operator other than minus/not, or when
/// `FunctionCall::new` cannot build a call for the operand's type. Errors from
/// binding the operand are propagated.
pub(super) fn bind_unary_expr(
    &mut self,
    op: UnaryOperator,
    expr: Expr,
) -> Result<FunctionCall> {
    let supported = match op {
        UnaryOperator::Minus => Some(ExprType::Neg),
        UnaryOperator::Not => Some(ExprType::Not),
        _ => None,
    };
    let func_type = match supported {
        Some(kind) => kind,
        None => {
            return Err(ErrorCode::NotImplementedError(format!(
                "unsupported expression: {:?}",
                op
            ))
            .into())
        }
    };

    let operand = self.bind_expr(expr)?;
    let operand_type = operand.return_type();
    match FunctionCall::new(func_type, vec![operand]) {
        Some(call) => Ok(call),
        None => Err(
            ErrorCode::NotImplementedError(format!("{:?} {:?}", op, operand_type)).into(),
        ),
    }
}

/// Binds `CAST(expr AS data_type)` into a `Cast` function call whose return type
/// is the bound target type.
///
/// # Errors
/// Propagates errors from binding `expr` first, then from `bind_data_type` for an
/// unsupported target type.
pub(super) fn bind_cast(&mut self, expr: Expr, data_type: AstDataType) -> Result<FunctionCall> {
    let operand = self.bind_expr(expr)?;
    let target_type = bind_data_type(&data_type)?;
    let args = vec![operand];
    Ok(FunctionCall::new_with_return_type(
        ExprType::Cast,
        args,
        target_type,
    ))
}
}

pub fn bind_data_type(data_type: &AstDataType) -> Result<DataType> {
let data_type = match data_type {
AstDataType::SmallInt(_) => DataType::Int16,
AstDataType::Int(_) => DataType::Int32,
AstDataType::BigInt(_) => DataType::Int64,
AstDataType::Float(_) => DataType::Float64,
AstDataType::Double => DataType::Float64,
AstDataType::String => DataType::Varchar,
AstDataType::Boolean => DataType::Boolean,
AstDataType::Char(_) => DataType::Char,
AstDataType::Varchar(_) => DataType::Varchar,
AstDataType::Decimal(_, _) => DataType::Decimal,
AstDataType::Date => DataType::Date,
AstDataType::Time => DataType::Time,
AstDataType::Timestamp => DataType::Timestamp,
AstDataType::Interval => DataType::Interval,
AstDataType::Real => DataType::Float32,
_ => {
return Err(ErrorCode::NotImplementedError(format!(
"unsupported data type: {:?}",
data_type
))
.into())
}
};
Ok(data_type)
}
5 changes: 5 additions & 0 deletions rust/frontend/src/binder/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,15 @@ impl Binder {
Value::SingleQuotedString(s) => {
Ok(Literal::new(Some(ScalarImpl::Utf8(s)), DataType::Varchar))
}
Value::Boolean(b) => self.bind_bool(b),
_ => Err(ErrorCode::NotImplementedError(format!("{:?}", value)).into()),
}
}

/// Binds a boolean literal (`TRUE`/`FALSE`) to a non-null `Boolean` constant.
fn bind_bool(&mut self, b: bool) -> Result<Literal> {
    let value = ScalarImpl::Bool(b);
    let literal = Literal::new(Some(value), DataType::Boolean);
    Ok(literal)
}

fn bind_number(&mut self, s: String, _b: bool) -> Result<Literal> {
let (data, data_type) = if let Ok(int_32) = s.parse::<i32>() {
(Some(ScalarImpl::Int32(int_32)), DataType::Int32)
Expand Down
2 changes: 1 addition & 1 deletion rust/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use risingwave_common::error::Result;
use risingwave_sqlparser::ast::Statement;

mod bind_context;
mod expr;
pub(crate) mod expr;
mod insert;
mod projection;
mod query;
Expand Down
3 changes: 3 additions & 0 deletions rust/frontend/src/binder/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ pub struct BoundValues {

impl Binder {
pub(super) fn bind_values(&mut self, values: Values) -> Result<BoundValues> {
self.context.in_values_clause = true;
let vec2d = values.0;
let bound = vec2d
.into_iter()
.map(|vec| vec.into_iter().map(|expr| self.bind_expr(expr)).collect())
.collect::<Result<Vec<Vec<_>>>>()?;
self.context.in_values_clause = false;

// calc column type and insert casts here
let mut types = bound[0]
.iter()
Expand Down
19 changes: 14 additions & 5 deletions rust/frontend/src/expr/agg_call.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use risingwave_common::error::Result;
use risingwave_common::expr::AggKind;
use risingwave_common::types::DataType;

Expand All @@ -10,12 +11,20 @@ pub struct AggCall {
inputs: Vec<ExprImpl>,
}
impl AggCall {
#![allow(unreachable_code)]
#![allow(unused_variables)]
#![allow(clippy::diverging_sub_expression)]
pub fn new(agg_kind: AggKind, inputs: Vec<ExprImpl>) -> Option<Self> {
let return_type = todo!(); // should be derived from inputs
Some(AggCall {
/// Returns error if the function name matches with an existing function
/// but with illegal arguments. `Ok(None)` is returned when there's no matching
/// function.
pub fn new(agg_kind: AggKind, inputs: Vec<ExprImpl>) -> Result<Self> {
// TODO(TaoWu): Add arguments validator.
let return_type = match agg_kind {
AggKind::Min => inputs.get(0).unwrap().return_type(),
AggKind::Max => inputs.get(0).unwrap().return_type(),
AggKind::Sum => DataType::Int64,
AggKind::Count => DataType::Int64,
_ => todo!(),
}; // should be derived from inputs
Ok(AggCall {
agg_kind,
return_type,
inputs,
Expand Down
13 changes: 8 additions & 5 deletions rust/frontend/src/expr/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,15 @@ fn build_type_derive_map() -> HashMap<FuncSign, DataTypeName> {
ExprType::GreaterThan,
ExprType::GreaterThanOrEqual,
];
let logical_exprs = vec![ExprType::And, ExprType::Or, ExprType::Not];
let bool_check_exprs = vec![
// Exprs that accept two bools and return a bool.
let bool_binary_exprs = vec![ExprType::And, ExprType::Or];
// Exprs that accept one bool and return a bool.
let bool_unary_exprs = vec![
ExprType::IsTrue,
ExprType::IsNotTrue,
ExprType::IsFalse,
ExprType::IsNotFalse,
ExprType::Not,
];
let null_check_exprs = vec![
ExprType::IsNull,
Expand All @@ -205,15 +208,15 @@ fn build_type_derive_map() -> HashMap<FuncSign, DataTypeName> {
DataTypeName::Boolean,
);
}
for expr in logical_exprs {
for expr in bool_binary_exprs {
map.insert(
FuncSign::new_binary(expr, DataTypeName::Boolean, DataTypeName::Boolean),
DataTypeName::Boolean,
);
}
for expr in bool_check_exprs {
for expr in bool_unary_exprs {
map.insert(
FuncSign::new_binary(expr, DataTypeName::Boolean, DataTypeName::Boolean),
FuncSign::new_unary(expr, DataTypeName::Boolean),
DataTypeName::Boolean,
);
}
Expand Down
19 changes: 4 additions & 15 deletions rust/frontend/src/handler/create_table.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use pgwire::pg_response::{PgResponse, StatementType};
use risingwave_common::error::Result;
use risingwave_common::types::DataType;
use risingwave_pb::meta::table::Info;
use risingwave_pb::meta::Table;
use risingwave_pb::plan::{ColumnDesc, TableSourceInfo};
use risingwave_sqlparser::ast::{ColumnDef, DataType as AstDataType, ObjectName};
use risingwave_sqlparser::ast::{ColumnDef, ObjectName};

use crate::binder::expr::bind_data_type;
use crate::catalog::catalog_service::DEFAULT_SCHEMA_NAME;
use crate::session::SessionImpl;

Expand All @@ -17,24 +17,13 @@ fn columns_to_prost(columns: &[ColumnDef]) -> Result<Vec<ColumnDesc>> {
Ok(ColumnDesc {
column_id: idx as i32,
name: col.name.to_string(),
column_type: Some(convert_data_type(&col.data_type).to_protobuf()?),
column_type: Some(bind_data_type(&col.data_type)?.to_protobuf()?),
..Default::default()
})
})
.collect::<Result<_>>()
}

/// Maps a parsed column type from `CREATE TABLE` to the engine's `DataType`.
///
/// Only SMALLINT, INT, BIGINT, FLOAT (as `Float32`) and DOUBLE are recognised.
///
/// # Panics
/// Panics via `unimplemented!` for any other column type, i.e. on user input;
/// callers that need a recoverable error should use a `Result`-returning mapper.
fn convert_data_type(data_type: &AstDataType) -> DataType {
    match data_type {
        AstDataType::SmallInt(_) => DataType::Int16,
        AstDataType::Int(_) => DataType::Int32,
        AstDataType::BigInt(_) => DataType::Int64,
        AstDataType::Float(_) => DataType::Float32,
        AstDataType::Double => DataType::Float64,
        unknown => unimplemented!("Unsupported data type {:?} in create table", unknown),
    }
}

pub async fn handle_create_table(
session: &SessionImpl,
table_name: ObjectName,
Expand Down Expand Up @@ -92,7 +81,7 @@ mod tests {
expected_map.insert("v1".to_string(), DataType::Int16);
expected_map.insert("v2".to_string(), DataType::Int32);
expected_map.insert("v3".to_string(), DataType::Int64);
expected_map.insert("v4".to_string(), DataType::Float32);
expected_map.insert("v4".to_string(), DataType::Float64);
expected_map.insert("v5".to_string(), DataType::Float64);
assert_eq!(columns, expected_map);
}
Expand Down
Loading