Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public abstract class RwAggregate extends Aggregate {
.put(SqlKind.COUNT, AggCall.Type.COUNT)
.put(SqlKind.MIN, AggCall.Type.MIN)
.put(SqlKind.MAX, AggCall.Type.MAX)
.put(SqlKind.SINGLE_VALUE, AggCall.Type.SINGLE_VALUE)
.build();

/** derive Distribution trait for agg from input side. */
Expand Down
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ message AggCall {
COUNT = 4;
AVG = 5;
STRING_AGG = 6;
SINGLE_VALUE = 7;
}
message Arg {
InputRefExpr input = 1;
Expand Down
2 changes: 2 additions & 0 deletions rust/common/src/expr/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub enum AggKind {
RowCount,
Avg,
StringAgg,
SingleValue,
}

impl TryFrom<Type> for AggKind {
Expand All @@ -27,6 +28,7 @@ impl TryFrom<Type> for AggKind {
Type::Avg => Ok(AggKind::Avg),
Type::Count => Ok(AggKind::Count),
Type::StringAgg => Ok(AggKind::StringAgg),
Type::SingleValue => Ok(AggKind::SingleValue),
_ => Err(ErrorCode::InternalError("Unrecognized agg.".into()).into()),
}
}
Expand Down
16 changes: 16 additions & 0 deletions rust/common/src/vector_op/agg/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,17 @@ pub fn create_agg_state_unary(
(Max, max_str, varchar, varchar),
// Global Agg
(Sum, sum, int64, int64),
// We remark that SingleValue does not produce a runtime error when it receives zero rows.
// Therefore, we do NOT need to change the logic in GeneralAgg::output_concrete.
(SingleValue, single_value, int16, int16),
(SingleValue, single_value, int32, int32),
(SingleValue, single_value, int64, int64),
(SingleValue, single_value, float32, float32),
(SingleValue, single_value, float64, float64),
(SingleValue, single_value, decimal, decimal),
(SingleValue, single_value, boolean, boolean),
(SingleValue, single_value_str, char, char),
(SingleValue, single_value_str, varchar, varchar),
];
Ok(state)
}
Expand Down Expand Up @@ -240,5 +251,10 @@ mod tests {
test_create! { decimal_type, Min, decimal_type, is_ok }
test_create! { bool_type, Min, bool_type, is_ok } // TODO(#359): revert to is_err
test_create! { char_type, Min, char_type, is_ok }

test_create! { int64_type, SingleValue, int64_type, is_ok }
test_create! { decimal_type, SingleValue, decimal_type, is_ok }
test_create! { bool_type, SingleValue, bool_type, is_ok }
test_create! { char_type, SingleValue, char_type, is_ok }
}
}
55 changes: 39 additions & 16 deletions rust/common/src/vector_op/agg/functions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::array::Array;
use crate::error::{ErrorCode, Result};

/// Essentially `RTFn` is an alias of the specific Fn. It was aliased not to
/// shorten the `where` clause of `GeneralAgg`, but to work around a compiler
Expand All @@ -9,7 +10,7 @@ pub trait RTFn<'a, T, R> = Send
+ Fn(
Option<<R as Array>::RefItem<'a>>,
Option<<T as Array>::RefItem<'a>>,
) -> Option<<R as Array>::RefItem<'a>>
) -> Result<Option<<R as Array>::RefItem<'a>>>
where
T: Array,
R: Array;
Expand All @@ -19,44 +20,47 @@ use std::ops::Add;

use crate::types::ScalarRef;

pub fn sum<R, T>(result: Option<R>, input: Option<T>) -> Option<R>
pub fn sum<R, T>(result: Option<R>, input: Option<T>) -> Result<Option<R>>
where
R: From<T> + Add<Output = R> + Copy,
{
match (result, input) {
let res = match (result, input) {
(_, None) => result,
(None, Some(i)) => Some(R::from(i)),
(Some(r), Some(i)) => Some(r + R::from(i)),
}
};
Ok(res)
}

pub fn min<'a, T>(result: Option<T>, input: Option<T>) -> Option<T>
pub fn min<'a, T>(result: Option<T>, input: Option<T>) -> Result<Option<T>>
where
T: ScalarRef<'a> + PartialOrd,
{
match (result, input) {
let res = match (result, input) {
(None, _) => input,
(_, None) => result,
(Some(r), Some(i)) => Some(if r < i { r } else { i }),
}
};
Ok(res)
}

pub fn min_str<'a>(r: Option<&'a str>, i: Option<&'a str>) -> Option<&'a str> {
pub fn min_str<'a>(r: Option<&'a str>, i: Option<&'a str>) -> Result<Option<&'a str>> {
min(r, i)
}

pub fn max<'a, T>(result: Option<T>, input: Option<T>) -> Option<T>
pub fn max<'a, T>(result: Option<T>, input: Option<T>) -> Result<Option<T>>
where
T: ScalarRef<'a> + PartialOrd,
{
match (result, input) {
let res = match (result, input) {
(None, _) => input,
(_, None) => result,
(Some(r), Some(i)) => Some(if r > i { r } else { i }),
}
};
Ok(res)
}

pub fn max_str<'a>(r: Option<&'a str>, i: Option<&'a str>) -> Option<&'a str> {
pub fn max_str<'a>(r: Option<&'a str>, i: Option<&'a str>) -> Result<Option<&'a str>> {
max(r, i)
}

Expand All @@ -65,15 +69,34 @@ pub fn max_str<'a>(r: Option<&'a str>, i: Option<&'a str>) -> Option<&'a str> {
/// select count(*) from t; gives 1.
/// select count(v1) from t; gives 0.
/// select sum(v1) from t; gives null
pub fn count<T>(result: Option<i64>, input: Option<T>) -> Option<i64> {
match (result, input) {
pub fn count<T>(result: Option<i64>, input: Option<T>) -> Result<Option<i64>> {
let res = match (result, input) {
(None, None) => Some(0),
(Some(r), None) => Some(r),
(None, Some(_)) => Some(1),
(Some(r), Some(_)) => Some(r + 1),
}
};
Ok(res)
}

pub fn count_str(r: Option<i64>, i: Option<&str>) -> Option<i64> {
pub fn count_str(r: Option<i64>, i: Option<&str>) -> Result<Option<i64>> {
count(r, i)
}

pub fn single_value<'a, T>(result: Option<T>, input: Option<T>) -> Result<Option<T>>
where
T: ScalarRef<'a> + PartialOrd,
{
match (result, input) {
(None, _) => Ok(input),
(Some(_), None) => Ok(result),
Comment on lines +91 to +92
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that a column with [None, Some(2), None] would return Ok rather than Err? That is an error in PostgreSQL.

btw, how does it work in 2-phase agg?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

I believe these are two mistakes that need to be corrected. I will track both points in a new issue.

(Some(_), Some(_)) => Err(ErrorCode::InternalError(
"SingleValue aggregation can only accept exactly one value. But there is more than one.".to_string(),
)
.into()),
}
}

/// `SINGLE_VALUE` aggregation specialised to string slices.
///
/// Thin wrapper that instantiates the generic [`single_value`] with `T = &'a str`;
/// the returned state (or error) is exactly what `single_value` produces for `(r, i)`.
pub fn single_value_str<'a>(r: Option<&'a str>, i: Option<&'a str>) -> Result<Option<&'a str>> {
    single_value::<&'a str>(r, i)
}
60 changes: 53 additions & 7 deletions rust/common/src/vector_op/agg/general_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,20 @@ where
}
}
pub(super) fn update_with_scalar_concrete(&mut self, input: &T, row_id: usize) -> Result<()> {
self.result = (self.f)(
let datum = (self.f)(
self.result.as_ref().map(|x| x.as_scalar_ref()),
input.value_at(row_id),
)
)?
.map(|x| x.to_owned_scalar());
self.result = datum;
Ok(())
}
pub(super) fn update_concrete(&mut self, input: &T) -> Result<()> {
let r = input
.iter()
.fold(self.result.as_ref().map(|x| x.as_scalar_ref()), &mut self.f)
.map(|x| x.to_owned_scalar());
let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref());
for datum in input.iter() {
cur = (self.f)(cur, datum)?;
}
let r = cur.map(|x| x.to_owned_scalar());
self.result = r;
Ok(())
}
Expand All @@ -67,7 +69,7 @@ where
builder.append(cur)?;
cur = None;
}
cur = (self.f)(cur, v);
cur = (self.f)(cur, v)?;
}
self.result = cur.map(|x| x.to_owned_scalar());
Ok(())
Expand Down Expand Up @@ -186,6 +188,50 @@ mod tests {
builder.finish()
}

#[test]
fn single_value_int32() -> Result<()> {
    let agg_type = AggKind::SingleValue;
    let input_type = DataType::Int32;
    let return_type = DataType::Int32;

    // Fresh, empty Int32 output builder for each evaluation.
    let new_builder = || I32ArrayBuilder::new(0).map(ArrayBuilderImpl::Int32);

    // Runs the SingleValue aggregation over an Int32 column holding `values`,
    // appending the aggregate output into `builder`.
    let eval_single_value = |values: &[Option<i32>], builder: ArrayBuilderImpl| {
        let column = I32Array::from_slice(values).unwrap();
        eval_agg(
            input_type.clone(),
            Arc::new(column.into()),
            &agg_type,
            return_type.clone(),
            builder,
        )
    };

    // One NULL row (a single row, not zero rows): the aggregate yields NULL.
    let out = eval_single_value(&[None], new_builder()?)?;
    let got: Vec<_> = out.as_int32().iter().collect();
    assert_eq!(got, &[None]);

    // Exactly one non-NULL row: that value is passed through unchanged.
    let out = eval_single_value(&[Some(1)], new_builder()?)?;
    let got: Vec<_> = out.as_int32().iter().collect();
    assert_eq!(got, &[Some(1)]);

    // Two non-NULL rows must be rejected with an error.
    // NOTE(review): inputs mixing NULLs with a single value (e.g. [None, Some(2), None])
    // are not covered here; per the PR review they currently return Ok, while
    // PostgreSQL reports an error for them.
    let out = eval_single_value(&[Some(1), Some(2)], new_builder()?);
    assert!(out.is_err());
    Ok(())
}

#[test]
fn vec_sum_int32() -> Result<()> {
let input = I32Array::from_slice(&[Some(1), Some(2), Some(3)]).unwrap();
Expand Down
25 changes: 13 additions & 12 deletions rust/common/src/vector_op/agg/general_distinct_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,26 @@ where
.value_at(row_id)
.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value());
if self.exists.insert(value) {
self.result = (self.f)(
let datum = (self.f)(
self.result.as_ref().map(|x| x.as_scalar_ref()),
input.value_at(row_id),
)
)?
.map(|x| x.to_owned_scalar());
self.result = datum;
}
Ok(())
}

fn update_concrete(&mut self, input: &T) -> Result<()> {
let r = input
.iter()
.filter(|scalar_ref| {
self.exists.insert(
scalar_ref.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value()),
)
})
.fold(self.result.as_ref().map(|x| x.as_scalar_ref()), &mut self.f)
.map(|x| x.to_owned_scalar());
let input = input.iter().filter(|scalar_ref| {
self.exists
.insert(scalar_ref.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value()))
});
let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref());
for datum in input {
cur = (self.f)(cur, datum)?;
}
let r = cur.map(|x| x.to_owned_scalar());
self.result = r;
Ok(())
}
Expand All @@ -90,7 +91,7 @@ where
}
let scalar_impl = v.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value());
if self.exists.insert(scalar_impl) {
cur = (self.f)(cur, v);
cur = (self.f)(cur, v)?;
}
}
self.result = cur.map(|x| x.to_owned_scalar());
Expand Down
1 change: 1 addition & 0 deletions rust/stream/src/executor/managed_state/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ impl<S: StateStore> ManagedStateImpl<S> {
ManagedValueState::new(agg_call, keyspace, row_count).await?,
))
}
AggKind::SingleValue => todo!(),
}
}
}