diff --git a/crates/integrations/datafusion/tests/pk_tables.rs b/crates/integrations/datafusion/tests/pk_tables.rs index fa0bd6f1..c3cc01f2 100644 --- a/crates/integrations/datafusion/tests/pk_tables.rs +++ b/crates/integrations/datafusion/tests/pk_tables.rs @@ -1973,3 +1973,523 @@ async fn test_pk_dv_deduplicate_read_no_error() { result.err() ); } + +// ======================= Aggregation Engine ======================= + +/// Basic: aggregation engine sums numeric column and concatenates string +/// column across overlapping primary keys. +#[tokio::test] +async fn test_pk_aggregation_sum_and_listagg_fixed_bucket_e2e() { + let (_tmp, sql_context) = setup_sql_context().await; + + sql_context + .sql( + "CREATE TABLE paimon.test_db.t_agg_sum ( + id INT NOT NULL, amount INT, tag STRING, + PRIMARY KEY (id) + ) WITH ( + 'bucket' = '1', + 'merge-engine' = 'aggregation', + 'fields.amount.aggregate-function' = 'sum', + 'fields.tag.aggregate-function' = 'listagg', + 'fields.tag.list-agg-delimiter' = '|' + )", + ) + .await + .unwrap(); + + sql_context + .sql( + "INSERT INTO paimon.test_db.t_agg_sum VALUES \ + (1, 10, 'a'), (2, 20, 'x')", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + sql_context + .sql( + "INSERT INTO paimon.test_db.t_agg_sum VALUES \ + (1, 5, 'b'), (2, 7, CAST(NULL AS STRING)), (3, 99, 'solo')", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + let batches = sql_context + .sql("SELECT id, amount, tag FROM paimon.test_db.t_agg_sum ORDER BY id") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let mut rows: Vec<(i32, Option, Option)> = Vec::new(); + for batch in &batches { + let ids = batch + .column_by_name("id") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + let amounts = batch + .column_by_name("amount") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + let tags = batch + .column_by_name("tag") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + for i in 0..batch.num_rows() { + rows.push(( + ids.value(i), + if amounts.is_null(i) { + None + } else { + Some(amounts.value(i)) + }, + if tags.is_null(i) { + None + } else { + Some(tags.value(i).to_string()) + }, + )); + } + } + + assert_eq!( + rows, + vec![ + (1, Some(15), Some("a|b".to_string())), + (2, Some(27), Some("x".to_string())), + (3, Some(99), Some("solo".to_string())), + ] + ); +} + +/// `fields.default-aggregate-function` applies to any column without an +/// explicit per-field aggregator. +#[tokio::test] +async fn test_pk_aggregation_default_function() { + let (_tmp, sql_context) = setup_sql_context().await; + + sql_context + .sql( + "CREATE TABLE paimon.test_db.t_agg_default ( + id INT NOT NULL, a INT, b STRING, + PRIMARY KEY (id) + ) WITH ( + 'bucket' = '1', + 'merge-engine' = 'aggregation', + 'fields.default-aggregate-function' = 'last_non_null_value' + )", + ) + .await + .unwrap(); + + sql_context + .sql("INSERT INTO paimon.test_db.t_agg_default VALUES (1, 10, 'old')") + .await + .unwrap() + .collect() + .await + .unwrap(); + sql_context + .sql( + "INSERT INTO paimon.test_db.t_agg_default VALUES \ + (1, CAST(NULL AS INT), 'new')", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + sql_context + .sql("INSERT INTO paimon.test_db.t_agg_default VALUES (1, 99, CAST(NULL AS STRING))") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let batches = sql_context + .sql("SELECT id, a, b FROM paimon.test_db.t_agg_default") + .await + .unwrap() + .collect() + .await + .unwrap(); + + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 1); + let batch = &batches[0]; + let id = batch + .column_by_name("id") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + let a = batch + .column_by_name("a") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + let b = batch + .column_by_name("b") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + assert_eq!(id.value(0), 1); + assert_eq!(a.value(0), 99); // latest non-null int across the three commits + assert_eq!(b.value(0), "new"); // latest non-null string +} + +/// Mixed aggregators in a single table: sum / max / bool_or / count. +#[tokio::test] +async fn test_pk_aggregation_mixed_aggregators() { + let (_tmp, sql_context) = setup_sql_context().await; + + sql_context + .sql( + "CREATE TABLE paimon.test_db.t_agg_mixed ( + id INT NOT NULL, total INT, peak INT, ok BOOLEAN, cnt BIGINT, + PRIMARY KEY (id) + ) WITH ( + 'bucket' = '1', + 'merge-engine' = 'aggregation', + 'fields.total.aggregate-function' = 'sum', + 'fields.peak.aggregate-function' = 'max', + 'fields.ok.aggregate-function' = 'bool_or', + 'fields.cnt.aggregate-function' = 'count' + )", + ) + .await + .unwrap(); + + sql_context + .sql( + "INSERT INTO paimon.test_db.t_agg_mixed VALUES \ + (1, 10, 5, false, CAST(1 AS BIGINT)), \ + (1, 5, 8, true, CAST(1 AS BIGINT)), \ + (1, 3, 7, false, CAST(1 AS BIGINT))", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + let batches = sql_context + .sql("SELECT id, total, peak, ok, cnt FROM paimon.test_db.t_agg_mixed") + .await + .unwrap() + .collect() + .await + .unwrap(); + + use datafusion::arrow::array::{BooleanArray, Int64Array}; + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 1); + let batch = &batches[0]; + let total = batch + .column_by_name("total") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + let peak = batch + .column_by_name("peak") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + let ok = batch + .column_by_name("ok") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + let cnt = batch + .column_by_name("cnt") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + assert_eq!(total.value(0), 18); // 10 + 5 + 3 + assert_eq!(peak.value(0), 8); // max(5, 8, 7) + assert!(ok.value(0)); // bool_or = true if any is true + assert_eq!(cnt.value(0), 3); // three non-null rows +} + +/// `sequence.field` forces the named column to `last_value`, even when the +/// user explicitly configures another aggregator for it. +#[tokio::test] +async fn test_pk_aggregation_sequence_field_forced_last_value() { + let (_tmp, sql_context) = setup_sql_context().await; + + sql_context + .sql( + "CREATE TABLE paimon.test_db.t_agg_seq ( + id INT NOT NULL, amount INT, ts INT, + PRIMARY KEY (id) + ) WITH ( + 'bucket' = '1', + 'merge-engine' = 'aggregation', + 'sequence.field' = 'ts', + 'fields.amount.aggregate-function' = 'sum', + 'fields.ts.aggregate-function' = 'sum' + )", + ) + .await + .unwrap(); + + sql_context + .sql( + "INSERT INTO paimon.test_db.t_agg_seq VALUES \ + (1, 10, 100), (1, 20, 250)", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + let batches = sql_context + .sql("SELECT id, amount, ts FROM paimon.test_db.t_agg_seq") + .await + .unwrap() + .collect() + .await + .unwrap(); + + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 1); + let batch = &batches[0]; + let amount = batch + .column_by_name("amount") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + let ts = batch + .column_by_name("ts") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + assert_eq!(amount.value(0), 30); // sum still applies + assert_eq!(ts.value(0), 250); // forced last_value over sum +} + +/// Aggregation engine reads must surface Unsupported when a DELETE/UPDATE +/// row appears. +#[tokio::test] +async fn test_pk_aggregation_rejects_delete() { + let (_tmp, sql_context) = setup_sql_context().await; + + sql_context + .sql( + "CREATE TABLE paimon.test_db.t_agg_del ( + id INT NOT NULL, amount INT, + PRIMARY KEY (id) + ) WITH ( + 'bucket' = '1', + 'merge-engine' = 'aggregation', + 'fields.amount.aggregate-function' = 'sum' + )", + ) + .await + .unwrap(); + + sql_context + .sql("INSERT INTO paimon.test_db.t_agg_del VALUES (1, 10), (2, 20)") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // DELETE may either fail at planning or surface Unsupported at execution. + // Either way the error must mention the aggregation engine refusing the + // retract row; we explicitly assert both branches so a future parser + // change cannot silently turn this into a no-op pass. + let plan_result = sql_context + .sql("DELETE FROM paimon.test_db.t_agg_del WHERE id = 1") + .await; + match plan_result { + Ok(df) => { + let exec = df.collect().await; + assert!(exec.is_err(), "DELETE on aggregation table should fail"); + let msg = format!("{:?}", exec.err().unwrap()); + assert!( + msg.contains("aggregation") + || msg.contains("DELETE") + || msg.contains("UPDATE_BEFORE"), + "expected aggregation engine to reject DELETE at execution, got {msg}" + ); + } + Err(e) => { + let msg = format!("{e:?}"); + assert!( + msg.contains("aggregation") + || msg.contains("DELETE") + || msg.contains("Unsupported"), + "expected aggregation engine to reject DELETE at planning, got {msg}" + ); + } + } +} + +/// CREATE TABLE with `merge-engine=aggregation` but no `aggregate-function` +/// configured for any value column should fail at runtime with a clear +/// message instructing the user which option to set. +#[tokio::test] +async fn test_pk_aggregation_requires_agg_function_per_field() { + let (_tmp, sql_context) = setup_sql_context().await; + + sql_context + .sql( + "CREATE TABLE paimon.test_db.t_agg_missing ( + id INT NOT NULL, amount INT, + PRIMARY KEY (id) + ) WITH ( + 'bucket' = '1', + 'merge-engine' = 'aggregation' + )", + ) + .await + .unwrap(); + + sql_context + .sql("INSERT INTO paimon.test_db.t_agg_missing VALUES (1, 10), (1, 20)") + .await + .unwrap() + .collect() + .await + .unwrap(); + + // First read should fail with the per-field config error. + let err = sql_context + .sql("SELECT * FROM paimon.test_db.t_agg_missing") + .await + .unwrap() + .collect() + .await + .unwrap_err(); + let msg = format!("{err:?}"); + assert!( + msg.contains("aggregate-function") && msg.contains("amount"), + "expected missing aggregate-function error to name the field, got {msg}" + ); +} + +/// CREATE TABLE should reject unsupported aggregation knobs in basic mode. +#[tokio::test] +async fn test_pk_aggregation_rejects_unsupported_options_at_create() { + let (_tmp, sql_context) = setup_sql_context().await; + + let err = sql_context + .sql( + "CREATE TABLE paimon.test_db.t_agg_bad ( + id INT NOT NULL, amount INT, + PRIMARY KEY (id) + ) WITH ( + 'bucket' = '1', + 'merge-engine' = 'aggregation', + 'fields.amount.aggregate-function' = 'sum', + 'fields.amount.ignore-retract' = 'true' + )", + ) + .await + .expect_err("CREATE TABLE with ignore-retract should fail in basic mode"); + let msg = format!("{err:?}"); + assert!( + msg.contains("ignore-retract"), + "expected create-time rejection to mention ignore-retract, got {msg}" + ); +} + +/// All-NULL aggregation group on a nullable `sum` column should emit NULL +/// rather than 0 or an error: nothing was observed, so there is no +/// arithmetic result to surface. +#[tokio::test] +async fn test_pk_aggregation_sum_all_null_emits_null_for_nullable_column() { + let (_tmp, sql_context) = setup_sql_context().await; + + sql_context + .sql( + "CREATE TABLE paimon.test_db.t_agg_null ( + id INT NOT NULL, amount INT, + PRIMARY KEY (id) + ) WITH ( + 'bucket' = '1', + 'merge-engine' = 'aggregation', + 'fields.amount.aggregate-function' = 'sum' + )", + ) + .await + .unwrap(); + + sql_context + .sql( + "INSERT INTO paimon.test_db.t_agg_null VALUES \ + (1, CAST(NULL AS INT)), (1, CAST(NULL AS INT))", + ) + .await + .unwrap() + .collect() + .await + .unwrap(); + + let batches = sql_context + .sql("SELECT id, amount FROM paimon.test_db.t_agg_null") + .await + .unwrap() + .collect() + .await + .unwrap(); + + assert_eq!(batches.iter().map(|b| b.num_rows()).sum::(), 1); + let amount = batches[0] + .column_by_name("amount") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + assert!(amount.is_null(0), "sum over all-NULL group should be NULL"); +} + +/// Regression guard: end-to-end SELECT on an aggregation table must traverse +/// the KeyValueFileReader path (TableRead::to_arrow → read_pk → read_kv), +/// not silently fall through to read_raw. The basic correctness assertion +/// (sum aggregation) implies this routing — a fallthrough to read_raw would +/// return the raw rows unmerged, breaking the sum. +#[tokio::test] +async fn test_pk_aggregation_routing_uses_kv_path() { + let (_tmp, sql_context) = setup_sql_context().await; + + sql_context + .sql( + "CREATE TABLE paimon.test_db.t_agg_route ( + id INT NOT NULL, amount INT, + PRIMARY KEY (id) + ) WITH ( + 'bucket' = '1', + 'merge-engine' = 'aggregation', + 'fields.amount.aggregate-function' = 'sum' + )", + ) + .await + .unwrap(); + + // Two rows with the same key in a single INSERT — read_raw would return 2 + // rows; read_kv (with AggregateMergeFunction) collapses them into 1 with + // amount=30. + sql_context + .sql("INSERT INTO paimon.test_db.t_agg_route VALUES (1, 10), (1, 20)") + .await + .unwrap() + .collect() + .await + .unwrap(); + + let n = row_count(&sql_context, "SELECT * FROM paimon.test_db.t_agg_route").await; + assert_eq!( + n, 1, + "aggregation table must collapse same-PK rows; got {n} rows which suggests \ + to_arrow fell through to read_raw" + ); + let batches = sql_context + .sql("SELECT amount FROM paimon.test_db.t_agg_route") + .await + .unwrap() + .collect() + .await + .unwrap(); + let amount = batches[0] + .column_by_name("amount") + .and_then(|c| c.as_any().downcast_ref::()) + .unwrap(); + assert_eq!(amount.value(0), 30); +} diff --git a/crates/paimon/src/spec/aggregation.rs b/crates/paimon/src/spec/aggregation.rs new file mode 100644 index 00000000..024d2bf5 --- /dev/null +++ b/crates/paimon/src/spec/aggregation.rs @@ -0,0 +1,238 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; + +const MERGE_ENGINE_OPTION: &str = "merge-engine"; +const AGGREGATION_ENGINE: &str = "aggregation"; +const IGNORE_DELETE_OPTION: &str = "ignore-delete"; +const IGNORE_DELETE_SUFFIX: &str = ".ignore-delete"; +const AGGREGATION_REMOVE_RECORD_ON_DELETE_OPTION: &str = "aggregation.remove-record-on-delete"; +const FIELDS_DEFAULT_AGG_FUNCTION_OPTION: &str = "fields.default-aggregate-function"; +const FIELDS_PREFIX: &str = "fields."; +const AGG_FUNCTION_SUFFIX: &str = ".aggregate-function"; +const IGNORE_RETRACT_SUFFIX: &str = ".ignore-retract"; +const DISTINCT_SUFFIX: &str = ".distinct"; +const SEQUENCE_GROUP_SUFFIX: &str = ".sequence-group"; +const NESTED_KEY_SUFFIX: &str = ".nested-key"; +const COUNT_LIMIT_SUFFIX: &str = ".count-limit"; + +/// Minimal aggregation mode recognized by the current Rust implementation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum AggregationMode { + Basic, +} + +/// Aggregation-merge-engine option inspection and validation. +/// +/// The basic mode accepts only `merge-engine=aggregation` on a PK table with +/// the following option keys: +/// - `fields.default-aggregate-function` +/// - `fields..aggregate-function` +/// - `fields..list-agg-delimiter` +/// +/// All other aggregation-specific knobs (`ignore-retract`, `distinct`, +/// `nested-key`, `count-limit`, `aggregation.remove-record-on-delete`, +/// `sequence-group`, `ignore-delete`) are rejected. Retract rows +/// (DELETE / UPDATE_BEFORE) are rejected at runtime by the merge function. +#[derive(Debug, Clone, Copy)] +pub(crate) struct AggregationConfig<'a> { + options: &'a HashMap, +} + +impl<'a> AggregationConfig<'a> { + pub(crate) fn new(options: &'a HashMap) -> Self { + Self { options } + } + + pub(crate) fn is_enabled(&self) -> bool { + self.options + .get(MERGE_ENGINE_OPTION) + .is_some_and(|value| value.eq_ignore_ascii_case(AGGREGATION_ENGINE)) + } + + /// Validate options at CREATE TABLE time. + pub(crate) fn validate_create_mode( + &self, + has_primary_keys: bool, + ) -> crate::Result> { + match self.validated_mode(has_primary_keys) { + Ok(mode) => Ok(mode), + Err(unsupported_options) => Err(crate::Error::ConfigInvalid { + message: format!( + "merge-engine=aggregation only supports the basic mode in this build; unsupported options: {}", + unsupported_options.join(", ") + ), + }), + } + } + + /// Validate options at read/write runtime. + pub(crate) fn validate_runtime_mode( + &self, + has_primary_keys: bool, + table_name: &str, + ) -> crate::Result> { + match self.validated_mode(has_primary_keys) { + Ok(mode) => Ok(mode), + Err(unsupported_options) => Err(crate::Error::Unsupported { + message: format!( + "Table '{table_name}' uses merge-engine=aggregation options not supported by this build: {}", + unsupported_options.join(", ") + ), + }), + } + } + + fn validated_mode( + &self, + has_primary_keys: bool, + ) -> std::result::Result, Vec> { + if !has_primary_keys || !self.is_enabled() { + return Ok(None); + } + + let unsupported_options = self.unsupported_option_keys(); + if !unsupported_options.is_empty() { + return Err(unsupported_options); + } + + Ok(Some(AggregationMode::Basic)) + } + + fn unsupported_option_keys(&self) -> Vec { + let mut keys: Vec = self + .options + .keys() + .filter(|key| is_unsupported_aggregation_option(key)) + .cloned() + .collect(); + keys.sort(); + keys + } + + /// Per-field aggregate function configured via `fields..aggregate-function`. + pub(crate) fn agg_function_for_field(&self, field_name: &str) -> Option<&str> { + let key = format!("{FIELDS_PREFIX}{field_name}{AGG_FUNCTION_SUFFIX}"); + self.options.get(&key).map(String::as_str) + } + + /// Default aggregate function from `fields.default-aggregate-function`. + pub(crate) fn default_agg_function(&self) -> Option<&str> { + self.options + .get(FIELDS_DEFAULT_AGG_FUNCTION_OPTION) + .map(String::as_str) + } +} + +fn is_unsupported_aggregation_option(key: &str) -> bool { + key == IGNORE_DELETE_OPTION + || key.ends_with(IGNORE_DELETE_SUFFIX) + || key == AGGREGATION_REMOVE_RECORD_ON_DELETE_OPTION + || is_fields_option_with_suffix(key, IGNORE_RETRACT_SUFFIX) + || is_fields_option_with_suffix(key, DISTINCT_SUFFIX) + || is_fields_option_with_suffix(key, SEQUENCE_GROUP_SUFFIX) + || is_fields_option_with_suffix(key, NESTED_KEY_SUFFIX) + || is_fields_option_with_suffix(key, COUNT_LIMIT_SUFFIX) +} + +fn is_fields_option_with_suffix(key: &str, suffix: &str) -> bool { + key.starts_with(FIELDS_PREFIX) && key.ends_with(suffix) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn aggregation_options(extra: &[(&str, &str)]) -> HashMap { + let mut options = HashMap::from([( + MERGE_ENGINE_OPTION.to_string(), + AGGREGATION_ENGINE.to_string(), + )]); + options.extend( + extra + .iter() + .map(|(key, value)| ((*key).to_string(), (*value).to_string())), + ); + options + } + + #[test] + fn test_validate_create_mode_accepts_basic_pk_aggregation() { + let options = aggregation_options(&[ + ("fields.price.aggregate-function", "sum"), + ("fields.default-aggregate-function", "last_non_null_value"), + ("fields.tags.list-agg-delimiter", ";"), + ]); + let config = AggregationConfig::new(&options); + + assert_eq!( + config.validate_create_mode(true).unwrap(), + Some(AggregationMode::Basic) + ); + } + + #[test] + fn test_validate_create_mode_ignores_non_pk_tables() { + let options = aggregation_options(&[("fields.x.ignore-retract", "true")]); + let config = AggregationConfig::new(&options); + + assert_eq!(config.validate_create_mode(false).unwrap(), None); + } + + #[test] + fn test_is_enabled_disabled_for_other_engines() { + let options = HashMap::from([(MERGE_ENGINE_OPTION.to_string(), "partial-update".into())]); + let config = AggregationConfig::new(&options); + assert!(!config.is_enabled()); + assert_eq!(config.validate_create_mode(true).unwrap(), None); + } + + #[test] + fn test_validate_create_mode_rejects_unsupported_options() { + for key in [ + IGNORE_DELETE_OPTION, + "fields.price.ignore-delete", + AGGREGATION_REMOVE_RECORD_ON_DELETE_OPTION, + "fields.price.ignore-retract", + "fields.tags.distinct", + "fields.price.sequence-group", + "fields.payload.nested-key", + "fields.payload.count-limit", + ] { + let options = aggregation_options(&[(key, "value")]); + let config = AggregationConfig::new(&options); + let err = config.validate_create_mode(true).unwrap_err(); + assert!( + matches!(err, crate::Error::ConfigInvalid { ref message } if message.contains(key)), + "expected create-time rejection to mention '{key}', got {err:?}" + ); + } + } + + #[test] + fn test_validate_runtime_mode_rejects_unsupported_options() { + let options = aggregation_options(&[("fields.price.ignore-retract", "true")]); + let config = AggregationConfig::new(&options); + let err = config.validate_runtime_mode(true, "default.t").unwrap_err(); + + assert!( + matches!(err, crate::Error::Unsupported { ref message } if message.contains("fields.price.ignore-retract")), + "expected runtime rejection to mention the unsupported option, got {err:?}" + ); + } +} diff --git a/crates/paimon/src/spec/core_options.rs b/crates/paimon/src/spec/core_options.rs index bafad0a1..b99df730 100644 --- a/crates/paimon/src/spec/core_options.rs +++ b/crates/paimon/src/spec/core_options.rs @@ -73,6 +73,8 @@ pub enum MergeEngine { PartialUpdate, /// Keep the first row for each key (ignore later updates). FirstRow, + /// Apply per-field aggregate functions across rows sharing the same key. + Aggregation, } /// Format the bucket directory name for a given bucket number. @@ -131,6 +133,7 @@ impl<'a> CoreOptions<'a> { "deduplicate" => Ok(MergeEngine::Deduplicate), "partial-update" => Ok(MergeEngine::PartialUpdate), "first-row" => Ok(MergeEngine::FirstRow), + "aggregation" => Ok(MergeEngine::Aggregation), other => Err(crate::Error::Unsupported { message: format!("Unsupported merge-engine: '{other}'"), }), @@ -546,6 +549,14 @@ mod tests { assert_eq!(core.merge_engine().unwrap(), MergeEngine::PartialUpdate); } + #[test] + fn test_merge_engine_accepts_aggregation() { + let options = HashMap::from([(MERGE_ENGINE_OPTION.to_string(), "aggregation".into())]); + let core = CoreOptions::new(&options); + + assert_eq!(core.merge_engine().unwrap(), MergeEngine::Aggregation); + } + #[test] fn test_commit_options_defaults() { let options = HashMap::new(); diff --git a/crates/paimon/src/spec/mod.rs b/crates/paimon/src/spec/mod.rs index 89de289f..8760d1b3 100644 --- a/crates/paimon/src/spec/mod.rs +++ b/crates/paimon/src/spec/mod.rs @@ -35,6 +35,9 @@ pub use core_options::*; mod partial_update; pub(crate) use partial_update::PartialUpdateConfig; +mod aggregation; +pub(crate) use aggregation::AggregationConfig; + mod schema; pub use schema::*; diff --git a/crates/paimon/src/spec/schema.rs b/crates/paimon/src/spec/schema.rs index e6d5af99..d06a5fd2 100644 --- a/crates/paimon/src/spec/schema.rs +++ b/crates/paimon/src/spec/schema.rs @@ -17,6 +17,7 @@ use crate::spec::core_options::CoreOptions; use crate::spec::types::{ArrayType, DataType, MapType, MultisetType, RowType}; +use crate::spec::AggregationConfig; use crate::spec::PartialUpdateConfig; use serde::{Deserialize, Serialize}; use serde_with::serde_as; @@ -291,6 +292,7 @@ impl Schema { let fields = Self::normalize_fields(&fields, &partition_keys, &primary_keys)?; Self::validate_blob_fields(&fields, &partition_keys, &options)?; PartialUpdateConfig::new(&options).validate_create_mode(!primary_keys.is_empty())?; + AggregationConfig::new(&options).validate_create_mode(!primary_keys.is_empty())?; Ok(Self { fields, @@ -714,7 +716,7 @@ impl Default for SchemaBuilder { #[cfg(test)] mod tests { - use crate::spec::{BlobType, IntType}; + use crate::spec::{BlobType, IntType, VarCharType}; use super::*; @@ -978,6 +980,51 @@ mod tests { } } + #[test] + fn test_aggregation_schema_validation_accepts_basic_options() { + let schema = Schema::builder() + .column("id", DataType::Int(IntType::new())) + .column("value", DataType::Int(IntType::new())) + .column("tags", DataType::VarChar(VarCharType::new(255).unwrap())) + .primary_key(["id"]) + .option("merge-engine", "aggregation") + .option("fields.value.aggregate-function", "sum") + .option("fields.tags.aggregate-function", "listagg") + .option("fields.tags.list-agg-delimiter", ";") + .option("fields.default-aggregate-function", "last_non_null_value") + .build() + .unwrap(); + + assert_eq!(schema.fields().len(), 3); + } + + #[test] + fn test_aggregation_schema_validation_rejects_unsupported_options() { + for (key, value) in [ + ("ignore-delete", "true"), + ("aggregation.remove-record-on-delete", "true"), + ("fields.value.ignore-retract", "true"), + ("fields.value.distinct", "true"), + ("fields.value.sequence-group", "g1"), + ("fields.value.nested-key", "id"), + ("fields.value.count-limit", "10"), + ] { + let err = Schema::builder() + .column("id", DataType::Int(IntType::new())) + .column("value", DataType::Int(IntType::new())) + .primary_key(["id"]) + .option("merge-engine", "aggregation") + .option(key, value) + .build() + .unwrap_err(); + + assert!( + matches!(err, crate::Error::ConfigInvalid { ref message } if message.contains(key)), + "aggregation create-time validation should reject '{key}', got {err:?}" + ); + } + } + #[test] fn test_schema_builder_column_row_type() { let row_type = RowType::new(vec![DataField::new( diff --git a/crates/paimon/src/table/aggregator/bool_agg.rs b/crates/paimon/src/table/aggregator/bool_agg.rs new file mode 100644 index 00000000..31f4f3a8 --- /dev/null +++ b/crates/paimon/src/table/aggregator/bool_agg.rs @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Boolean reducers: `bool_and` and `bool_or`. +//! +//! Both accumulate over BOOLEAN columns, skipping NULL inputs. The output +//! cell is NULL when no non-NULL input was observed for the key. +//! +//! Reference: Java `FieldBoolAndAgg`, `FieldBoolOrAgg` under +//! `org.apache.paimon.mergetree.compact.aggregate`. + +use std::sync::Arc; + +use arrow_array::{Array, ArrayRef, BooleanArray}; + +use super::{unsupported_type_error, FieldAggregator}; +use crate::spec::DataType; + +fn ensure_boolean(field_name: &str, data_type: &DataType, op: &str) -> crate::Result<()> { + match data_type { + DataType::Boolean(_) => Ok(()), + other => Err(unsupported_type_error(op, field_name, other)), + } +} + +#[derive(Debug)] +pub(crate) struct BoolAndAgg { + field_name: String, + acc: Option, +} + +impl BoolAndAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result { + ensure_boolean(field_name, data_type, "bool_and")?; + Ok(Self { + field_name: field_name.to_string(), + acc: None, + }) + } +} + +impl FieldAggregator for BoolAndAgg { + fn name(&self) -> &'static str { + "bool_and" + } + + fn reset(&mut self) { + self.acc = None; + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + if array.is_null(row_idx) { + return Ok(()); + } + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| crate::Error::DataInvalid { + message: format!( + "bool_and column '{}' received non-Boolean Arrow array {:?}", + self.field_name, + array.data_type() + ), + source: None, + })?; + let v = arr.value(row_idx); + self.acc = Some(self.acc.map_or(v, |prev| prev && v)); + Ok(()) + } + + fn result(&self) -> crate::Result { + Ok(Arc::new(BooleanArray::from(vec![self.acc]))) + } +} + +#[derive(Debug)] +pub(crate) struct BoolOrAgg { + field_name: String, + acc: Option, +} + +impl BoolOrAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result { + ensure_boolean(field_name, data_type, "bool_or")?; + Ok(Self { + field_name: field_name.to_string(), + acc: None, + }) + } +} + +impl FieldAggregator for BoolOrAgg { + fn name(&self) -> &'static str { + "bool_or" + } + + fn reset(&mut self) { + self.acc = None; + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + if array.is_null(row_idx) { + return Ok(()); + } + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| crate::Error::DataInvalid { + message: format!( + "bool_or column '{}' received non-Boolean Arrow array {:?}", + self.field_name, + array.data_type() + ), + source: None, + })?; + let v = arr.value(row_idx); + self.acc = Some(self.acc.map_or(v, |prev| prev || v)); + Ok(()) + } + + fn result(&self) -> crate::Result { + Ok(Arc::new(BooleanArray::from(vec![self.acc]))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::spec::{BooleanType, IntType}; + + fn collect(arr: ArrayRef) -> Option { + let a = arr.as_any().downcast_ref::().unwrap(); + if a.is_null(0) { + None + } else { + Some(a.value(0)) + } + } + + #[test] + fn test_bool_and_all_true_returns_true() { + let mut agg = BoolAndAgg::new("b", &DataType::Boolean(BooleanType::new())).unwrap(); + let arr = BooleanArray::from(vec![Some(true), Some(true), None, Some(true)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect(agg.result().unwrap()), Some(true)); + } + + #[test] + fn test_bool_and_short_circuits_false() { + let mut agg = BoolAndAgg::new("b", &DataType::Boolean(BooleanType::new())).unwrap(); + let arr = BooleanArray::from(vec![Some(true), Some(false), Some(true)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect(agg.result().unwrap()), Some(false)); + } + + #[test] + fn test_bool_or_any_true() { + let mut agg = BoolOrAgg::new("b", &DataType::Boolean(BooleanType::new())).unwrap(); + let arr = BooleanArray::from(vec![Some(false), None, Some(true), Some(false)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect(agg.result().unwrap()), Some(true)); + } + + #[test] + fn test_bool_and_or_all_null_returns_null() { + let mut and_agg = BoolAndAgg::new("b", &DataType::Boolean(BooleanType::new())).unwrap(); + let arr = BooleanArray::from(vec![None::, None]); + for i in 0..arr.len() { + and_agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect(and_agg.result().unwrap()), None); + + let mut or_agg = BoolOrAgg::new("b", &DataType::Boolean(BooleanType::new())).unwrap(); + for i in 0..arr.len() { + or_agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect(or_agg.result().unwrap()), None); + } + + #[test] + fn test_bool_and_rejects_non_boolean_type() { + let err = BoolAndAgg::new("b", &DataType::Int(IntType::new())).unwrap_err(); + assert!( + matches!(err, crate::Error::ConfigInvalid { message } if message.contains("bool_and")) + ); + } + + #[test] + fn test_reset_clears_state() { + let mut agg = BoolOrAgg::new("b", &DataType::Boolean(BooleanType::new())).unwrap(); + let arr = BooleanArray::from(vec![Some(true)]); + agg.agg(&arr, 0).unwrap(); + agg.reset(); + assert_eq!(collect(agg.result().unwrap()), None); + } +} diff --git a/crates/paimon/src/table/aggregator/listagg.rs b/crates/paimon/src/table/aggregator/listagg.rs new file mode 100644 index 00000000..f302ced2 --- /dev/null +++ b/crates/paimon/src/table/aggregator/listagg.rs @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `listagg`: concatenate non-NULL string values for a key, separated by a +//! per-field delimiter (`fields..list-agg-delimiter`, defaulting to +//! `","`). +//! +//! Reference: Java `FieldListaggAgg` under +//! `org.apache.paimon.mergetree.compact.aggregate`. + +use std::collections::HashMap; +use std::sync::Arc; + +use arrow_array::{Array, ArrayRef, StringArray}; + +use super::{unsupported_type_error, FieldAggregator}; +use crate::spec::DataType; + +const FIELDS_PREFIX: &str = "fields."; +const LIST_AGG_DELIMITER_SUFFIX: &str = ".list-agg-delimiter"; +const DEFAULT_DELIMITER: &str = ","; + +fn list_agg_delimiter<'a>(field_name: &str, options: &'a HashMap) -> &'a str { + options + .get(&format!( + "{FIELDS_PREFIX}{field_name}{LIST_AGG_DELIMITER_SUFFIX}" + )) + .map(String::as_str) + .unwrap_or(DEFAULT_DELIMITER) +} + +#[derive(Debug)] +pub(crate) struct ListaggAgg { + field_name: String, + delimiter: String, + acc: Option, +} + +impl ListaggAgg { + pub(crate) fn new( + field_name: &str, + data_type: &DataType, + table_options: &HashMap, + ) -> crate::Result { + match data_type { + DataType::Char(_) | DataType::VarChar(_) => {} + other => return Err(unsupported_type_error("listagg", field_name, other)), + } + Ok(Self { + field_name: field_name.to_string(), + delimiter: list_agg_delimiter(field_name, table_options).to_string(), + acc: None, + }) + } +} + +impl FieldAggregator for ListaggAgg { + fn name(&self) -> &'static str { + "listagg" + } + + fn reset(&mut self) { + self.acc = None; + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + if array.is_null(row_idx) { + return Ok(()); + } + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| crate::Error::DataInvalid { + message: format!( + "listagg column '{}' received non-Utf8 Arrow array {:?}", + self.field_name, + array.data_type() + ), + source: None, + })?; + let v = arr.value(row_idx); + match &mut self.acc { + None => self.acc = Some(v.to_string()), + Some(prev) => { + prev.push_str(&self.delimiter); + prev.push_str(v); + } + } + Ok(()) + } + + fn result(&self) -> crate::Result { + Ok(Arc::new(StringArray::from(vec![self.acc.clone()]))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::spec::{IntType, VarCharType}; + + fn collect(arr: ArrayRef) -> Option { + let a = arr.as_any().downcast_ref::().unwrap(); + if a.is_null(0) { + None + } else { + Some(a.value(0).to_string()) + } + } + + fn varchar_type() -> DataType { + DataType::VarChar(VarCharType::new(255).unwrap()) + } + + #[test] + fn test_listagg_default_delimiter_skips_null() { + let mut agg = ListaggAgg::new("v", &varchar_type(), &HashMap::new()).unwrap(); + let arr = StringArray::from(vec![Some("a"), None, Some("b"), Some("c")]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect(agg.result().unwrap()), Some("a,b,c".to_string())); + } + + #[test] + fn test_listagg_custom_delimiter() { + let opts = HashMap::from([("fields.v.list-agg-delimiter".to_string(), "|".to_string())]); + let mut agg = ListaggAgg::new("v", &varchar_type(), &opts).unwrap(); + let arr = StringArray::from(vec![Some("x"), Some("y")]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect(agg.result().unwrap()), Some("x|y".to_string())); + } + + #[test] + fn test_listagg_all_null_returns_null() { + let mut agg = ListaggAgg::new("v", &varchar_type(), &HashMap::new()).unwrap(); + let arr = StringArray::from(vec![None::<&str>, None]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect(agg.result().unwrap()), None); + } + + #[test] + fn test_listagg_single_value_does_not_prepend_delimiter() { + let mut agg = ListaggAgg::new("v", &varchar_type(), &HashMap::new()).unwrap(); + let arr = StringArray::from(vec![Some("only")]); + agg.agg(&arr, 0).unwrap(); + assert_eq!(collect(agg.result().unwrap()), Some("only".to_string())); + } + + #[test] + fn test_listagg_rejects_non_string_type() { + let err = + ListaggAgg::new("v", &DataType::Int(IntType::new()), &HashMap::new()).unwrap_err(); + assert!( + matches!(err, crate::Error::ConfigInvalid { message } if message.contains("listagg")) + ); + } + + #[test] + fn test_reset_clears_state() { + let mut agg = ListaggAgg::new("v", &varchar_type(), &HashMap::new()).unwrap(); + let arr = StringArray::from(vec![Some("keep_me")]); + agg.agg(&arr, 0).unwrap(); + agg.reset(); + assert_eq!(collect(agg.result().unwrap()), None); + } +} diff --git a/crates/paimon/src/table/aggregator/mod.rs b/crates/paimon/src/table/aggregator/mod.rs new file mode 100644 index 00000000..3abc7b36 --- /dev/null +++ b/crates/paimon/src/table/aggregator/mod.rs @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Per-field aggregators used by `merge-engine=aggregation`. +//! +//! Each [`FieldAggregator`] accumulates values for one output column across +//! the rows that share a primary key and produces a single-row [`ArrayRef`]. +//! The accumulator is reused across PK groups by calling [`reset`] between +//! groups. +//! +//! Reference: Java `org.apache.paimon.mergetree.compact.aggregate.FieldAggregator` +//! and the per-function factories under +//! `org.apache.paimon.mergetree.compact.aggregate.factory`. +//! +//! [`reset`]: FieldAggregator::reset + +use std::collections::HashMap; + +use arrow_array::{Array, ArrayRef}; + +use crate::spec::DataType; + +mod bool_agg; +mod listagg; +mod numeric; +mod value; + +pub(crate) use bool_agg::{BoolAndAgg, BoolOrAgg}; +pub(crate) use listagg::ListaggAgg; +pub(crate) use numeric::{CountAgg, MaxAgg, MinAgg, ProductAgg, SumAgg}; +pub(crate) use value::{FirstNonNullValueAgg, FirstValueAgg, LastNonNullValueAgg, LastValueAgg}; + +/// Per-field aggregator. +/// +/// The merge function calls [`reset`] once at the start of each primary-key +/// group, then [`agg`] once per row in the group (in user-sequence order), +/// and finally [`result`] to materialize the single-row output column. +/// +/// `agg` receives the source Arrow array plus the row index to read; the +/// implementation is expected to downcast to the appropriate typed array. +/// +/// [`reset`]: FieldAggregator::reset +/// [`agg`]: FieldAggregator::agg +/// [`result`]: FieldAggregator::result +pub(crate) trait FieldAggregator: Send + Sync + std::fmt::Debug { + /// Aggregator identifier, e.g. `"sum"`. Matches the + /// `fields..aggregate-function` option value. + fn name(&self) -> &'static str; + + /// Reset internal state at the start of a new primary-key group. + fn reset(&mut self); + + /// Accumulate one input cell. + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()>; + + /// Materialize the current accumulator as a 1-row Arrow array. + fn result(&self) -> crate::Result; +} + +/// Construct an aggregator by `name` for a column of type `data_type`. +/// +/// `field_name` and `table_options` are forwarded for per-field configuration +/// (e.g. `fields..list-agg-delimiter` for `listagg`). +/// +/// Returns [`Error::ConfigInvalid`] when the name is unknown or the column +/// type is incompatible with the requested aggregator — both indicate a user +/// configuration error and should fail at table creation rather than at read +/// time. +/// +/// [`Error::ConfigInvalid`]: crate::Error::ConfigInvalid +pub(crate) fn new_aggregator( + name: &str, + field_name: &str, + data_type: &DataType, + table_options: &HashMap, +) -> crate::Result> { + match name { + "sum" => Ok(Box::new(SumAgg::new(field_name, data_type)?)), + "product" => Ok(Box::new(ProductAgg::new(field_name, data_type)?)), + "min" => Ok(Box::new(MinAgg::new(field_name, data_type)?)), + "max" => Ok(Box::new(MaxAgg::new(field_name, data_type)?)), + "count" => Ok(Box::new(CountAgg::new(field_name, data_type)?)), + "last_value" => Ok(Box::new(LastValueAgg::new(field_name, data_type)?)), + "first_value" => Ok(Box::new(FirstValueAgg::new(field_name, data_type)?)), + "last_non_null_value" => Ok(Box::new(LastNonNullValueAgg::new(field_name, data_type)?)), + "first_non_null_value" => Ok(Box::new(FirstNonNullValueAgg::new(field_name, data_type)?)), + "bool_and" => Ok(Box::new(BoolAndAgg::new(field_name, data_type)?)), + "bool_or" => Ok(Box::new(BoolOrAgg::new(field_name, data_type)?)), + "listagg" => Ok(Box::new(ListaggAgg::new( + field_name, + data_type, + table_options, + )?)), + other => Err(crate::Error::ConfigInvalid { + message: format!( + "Unknown aggregate function '{other}' for field '{field_name}'; \ + supported: sum, product, min, max, count, last_value, first_value, \ + last_non_null_value, first_non_null_value, bool_and, bool_or, listagg" + ), + }), + } +} + +/// Helper: build a `ConfigInvalid` error for an unsupported (aggregator, type) +/// pair so every concrete aggregator emits the same phrasing. +pub(crate) fn unsupported_type_error( + agg_name: &str, + field_name: &str, + data_type: &DataType, +) -> crate::Error { + crate::Error::ConfigInvalid { + message: format!( + "Aggregate function '{agg_name}' does not support data type {data_type:?} \ + for field '{field_name}'" + ), + } +} diff --git a/crates/paimon/src/table/aggregator/numeric.rs b/crates/paimon/src/table/aggregator/numeric.rs new file mode 100644 index 00000000..981daf02 --- /dev/null +++ b/crates/paimon/src/table/aggregator/numeric.rs @@ -0,0 +1,1052 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Numeric aggregators: sum, product, min, max, count. +//! +//! `sum` operates on every integer / floating / Decimal numeric type. +//! `product` accepts the same numeric family except DECIMAL — basic mode does +//! not yet implement BigDecimal-style scale rebasing for Decimal product, so +//! Decimal columns are rejected at construction. Integer overflow on either +//! aggregator is reported as [`Error::DataInvalid`] so silent wrap cannot +//! produce misleading aggregated values. +//! +//! `min` / `max` extend to every ordered Paimon type: numerics, Decimal, +//! Date, Time, Timestamp, and Char/VarChar. Comparison is by native value +//! order (numeric for numbers, lexicographic for strings). +//! +//! `count` requires the column to be declared as BIGINT and accumulates the +//! number of non-NULL inputs encountered for the key. Non-BIGINT columns are +//! rejected at construction. +//! +//! Reference: Java `FieldSumAgg`, `FieldProductAgg`, `FieldMinAgg`, +//! `FieldMaxAgg`, `FieldCountAgg` under +//! `org.apache.paimon.mergetree.compact.aggregate`. +//! +//! [`Error::DataInvalid`]: crate::Error::DataInvalid + +use std::sync::Arc; + +use arrow_array::builder::Decimal128Builder; +use arrow_array::{ + Array, ArrayRef, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, StringArray, Time32MillisecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, +}; +use arrow_schema::TimeUnit; + +use super::{unsupported_type_error, FieldAggregator}; +use crate::spec::DataType; + +// --------------------------------------------------------------------------- +// Sum +// --------------------------------------------------------------------------- + +/// `sum` accumulator state, parameterized by the column's numeric kind. +#[derive(Debug)] +enum SumState { + I8(Option), + I16(Option), + I32(Option), + I64(Option), + F32(Option), + F64(Option), + Decimal128 { + precision: u8, + scale: i8, + acc: Option, + }, +} + +#[derive(Debug)] +pub(crate) struct SumAgg { + field_name: String, + state: SumState, +} + +impl SumAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result { + let state = match data_type { + DataType::TinyInt(_) => SumState::I8(None), + DataType::SmallInt(_) => SumState::I16(None), + DataType::Int(_) => SumState::I32(None), + DataType::BigInt(_) => SumState::I64(None), + DataType::Float(_) => SumState::F32(None), + DataType::Double(_) => SumState::F64(None), + DataType::Decimal(d) => SumState::Decimal128 { + precision: decimal_precision(d.precision(), field_name)?, + scale: decimal_scale(d.scale(), field_name)?, + acc: None, + }, + other => return Err(unsupported_type_error("sum", field_name, other)), + }; + Ok(Self { + field_name: field_name.to_string(), + state, + }) + } +} + +impl FieldAggregator for SumAgg { + fn name(&self) -> &'static str { + "sum" + } + + fn reset(&mut self) { + match &mut self.state { + SumState::I8(acc) => *acc = None, + SumState::I16(acc) => *acc = None, + SumState::I32(acc) => *acc = None, + SumState::I64(acc) => *acc = None, + SumState::F32(acc) => *acc = None, + SumState::F64(acc) => *acc = None, + SumState::Decimal128 { acc, .. } => *acc = None, + } + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + if array.is_null(row_idx) { + return Ok(()); + } + match &mut self.state { + SumState::I8(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + SumState::I16(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + SumState::I32(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + SumState::I64(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + SumState::F32(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(acc.map_or(v, |prev| prev + v)); + } + SumState::F64(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(acc.map_or(v, |prev| prev + v)); + } + SumState::Decimal128 { acc, .. } => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_add(v) + .ok_or_else(|| overflow_error("sum", &self.field_name))?, + }); + } + } + Ok(()) + } + + fn result(&self) -> crate::Result { + Ok(match &self.state { + SumState::I8(acc) => Arc::new(Int8Array::from(vec![*acc])), + SumState::I16(acc) => Arc::new(Int16Array::from(vec![*acc])), + SumState::I32(acc) => Arc::new(Int32Array::from(vec![*acc])), + SumState::I64(acc) => Arc::new(Int64Array::from(vec![*acc])), + SumState::F32(acc) => Arc::new(Float32Array::from(vec![*acc])), + SumState::F64(acc) => Arc::new(Float64Array::from(vec![*acc])), + SumState::Decimal128 { + precision, + scale, + acc, + } => decimal_array(*precision, *scale, *acc, "sum", &self.field_name)?, + }) + } +} + +// --------------------------------------------------------------------------- +// Product +// --------------------------------------------------------------------------- + +#[derive(Debug)] +enum ProductState { + I8(Option), + I16(Option), + I32(Option), + I64(Option), + F32(Option), + F64(Option), + // DECIMAL `product` is intentionally rejected at construction (see + // `ProductAgg::new`); add a variant here when the BigDecimal-style + // scale handling lands. +} + +#[derive(Debug)] +pub(crate) struct ProductAgg { + field_name: String, + state: ProductState, +} + +impl ProductAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result { + let state = match data_type { + DataType::TinyInt(_) => ProductState::I8(None), + DataType::SmallInt(_) => ProductState::I16(None), + DataType::Int(_) => ProductState::I32(None), + DataType::BigInt(_) => ProductState::I64(None), + DataType::Float(_) => ProductState::F32(None), + DataType::Double(_) => ProductState::F64(None), + // Decimal `product` would need BigDecimal-style scale rebasing + // (multiply raw i128, then divide by 10^scale, with precision + // checks). The basic mode does not implement that yet, so we + // reject DECIMAL columns explicitly rather than silently produce + // a scale-shifted result. + DataType::Decimal(_) => { + return Err(crate::Error::ConfigInvalid { + message: format!( + "Aggregate function 'product' on DECIMAL field '{field_name}' is not \ + supported in the basic mode; use a BIGINT/DOUBLE column or wait for a \ + follow-up commit that adds Decimal product semantics aligned with Java \ + BigDecimal" + ), + }); + } + other => return Err(unsupported_type_error("product", field_name, other)), + }; + Ok(Self { + field_name: field_name.to_string(), + state, + }) + } +} + +impl FieldAggregator for ProductAgg { + fn name(&self) -> &'static str { + "product" + } + + fn reset(&mut self) { + match &mut self.state { + ProductState::I8(acc) => *acc = None, + ProductState::I16(acc) => *acc = None, + ProductState::I32(acc) => *acc = None, + ProductState::I64(acc) => *acc = None, + ProductState::F32(acc) => *acc = None, + ProductState::F64(acc) => *acc = None, + } + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + if array.is_null(row_idx) { + return Ok(()); + } + match &mut self.state { + ProductState::I8(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_mul(v) + .ok_or_else(|| overflow_error("product", &self.field_name))?, + }); + } + ProductState::I16(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_mul(v) + .ok_or_else(|| overflow_error("product", &self.field_name))?, + }); + } + ProductState::I32(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_mul(v) + .ok_or_else(|| overflow_error("product", &self.field_name))?, + }); + } + ProductState::I64(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(match *acc { + None => v, + Some(prev) => prev + .checked_mul(v) + .ok_or_else(|| overflow_error("product", &self.field_name))?, + }); + } + ProductState::F32(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(acc.map_or(v, |prev| prev * v)); + } + ProductState::F64(acc) => { + let v = downcast::(array, &self.field_name)?.value(row_idx); + *acc = Some(acc.map_or(v, |prev| prev * v)); + } + } + Ok(()) + } + + fn result(&self) -> crate::Result { + Ok(match &self.state { + ProductState::I8(acc) => Arc::new(Int8Array::from(vec![*acc])), + ProductState::I16(acc) => Arc::new(Int16Array::from(vec![*acc])), + ProductState::I32(acc) => Arc::new(Int32Array::from(vec![*acc])), + ProductState::I64(acc) => Arc::new(Int64Array::from(vec![*acc])), + ProductState::F32(acc) => Arc::new(Float32Array::from(vec![*acc])), + ProductState::F64(acc) => Arc::new(Float64Array::from(vec![*acc])), + }) + } +} + +// --------------------------------------------------------------------------- +// Min / Max — generic comparator-driven implementation +// --------------------------------------------------------------------------- + +/// `min` / `max` accumulator state. Each variant stores `Option` where +/// `None` means "no non-null value seen yet for the current group". +#[derive(Debug)] +enum MinMaxState { + I8(Option), + I16(Option), + I32(Option), + I64(Option), + F32(Option), + F64(Option), + Decimal128 { + precision: u8, + scale: i8, + acc: Option, + }, + Date32(Option), + /// Paimon `TIME` is encoded as Arrow `Time32(Millisecond)` regardless of + /// declared precision, so a single accumulator variant suffices. + Time32Ms(Option), + Timestamp { + unit: TimeUnit, + acc: Option, + }, + Utf8(Option), +} + +fn make_minmax_state( + field_name: &str, + data_type: &DataType, + op: &str, +) -> crate::Result { + Ok(match data_type { + DataType::TinyInt(_) => MinMaxState::I8(None), + DataType::SmallInt(_) => MinMaxState::I16(None), + DataType::Int(_) => MinMaxState::I32(None), + DataType::BigInt(_) => MinMaxState::I64(None), + DataType::Float(_) => MinMaxState::F32(None), + DataType::Double(_) => MinMaxState::F64(None), + DataType::Decimal(d) => MinMaxState::Decimal128 { + precision: decimal_precision(d.precision(), field_name)?, + scale: decimal_scale(d.scale(), field_name)?, + acc: None, + }, + DataType::Date(_) => MinMaxState::Date32(None), + DataType::Time(_) => MinMaxState::Time32Ms(None), + DataType::Timestamp(t) => MinMaxState::Timestamp { + unit: timestamp_time_unit(t.precision())?, + acc: None, + }, + DataType::Char(_) | DataType::VarChar(_) => MinMaxState::Utf8(None), + other => return Err(unsupported_type_error(op, field_name, other)), + }) +} + +fn timestamp_time_unit(precision: u32) -> crate::Result { + match precision { + 0..=3 => Ok(TimeUnit::Millisecond), + 4..=6 => Ok(TimeUnit::Microsecond), + 7..=9 => Ok(TimeUnit::Nanosecond), + other => Err(crate::Error::Unsupported { + message: format!("Unsupported TIMESTAMP precision {other} for min/max aggregator"), + }), + } +} + +fn agg_minmax( + state: &mut MinMaxState, + array: &dyn Array, + row_idx: usize, + field_name: &str, + keep_smaller: bool, +) -> crate::Result<()> { + if array.is_null(row_idx) { + return Ok(()); + } + macro_rules! update_primitive { + ($acc:expr, $ty:ty) => {{ + let v = downcast::<$ty>(array, field_name)?.value(row_idx); + *$acc = Some(match *$acc { + None => v, + Some(prev) => { + if (keep_smaller && v < prev) || (!keep_smaller && v > prev) { + v + } else { + prev + } + } + }); + }}; + } + macro_rules! update_float { + ($acc:expr, $ty:ty) => {{ + let v = downcast::<$ty>(array, field_name)?.value(row_idx); + if v.is_nan() { + return Ok(()); // mirror Java: NaN values are ignored + } + *$acc = Some(match *$acc { + None => v, + Some(prev) => { + let take_new = if keep_smaller { v < prev } else { v > prev }; + if take_new { + v + } else { + prev + } + } + }); + }}; + } + match state { + MinMaxState::I8(acc) => update_primitive!(acc, Int8Array), + MinMaxState::I16(acc) => update_primitive!(acc, Int16Array), + MinMaxState::I32(acc) => update_primitive!(acc, Int32Array), + MinMaxState::I64(acc) => update_primitive!(acc, Int64Array), + MinMaxState::F32(acc) => update_float!(acc, Float32Array), + MinMaxState::F64(acc) => update_float!(acc, Float64Array), + MinMaxState::Decimal128 { acc, .. } => update_primitive!(acc, Decimal128Array), + MinMaxState::Date32(acc) => update_primitive!(acc, Date32Array), + MinMaxState::Time32Ms(acc) => update_primitive!(acc, Time32MillisecondArray), + MinMaxState::Timestamp { unit, acc } => match unit { + TimeUnit::Millisecond => update_primitive!(acc, TimestampMillisecondArray), + TimeUnit::Microsecond => update_primitive!(acc, TimestampMicrosecondArray), + TimeUnit::Nanosecond => update_primitive!(acc, TimestampNanosecondArray), + other => { + return Err(crate::Error::DataInvalid { + message: format!( + "Timestamp with unit {other:?} not expected for field '{field_name}'" + ), + source: None, + }); + } + }, + MinMaxState::Utf8(acc) => { + let v = downcast::(array, field_name)?.value(row_idx); + *acc = Some(match acc.take() { + None => v.to_string(), + Some(prev) => { + let take_new = if keep_smaller { + v < prev.as_str() + } else { + v > prev.as_str() + }; + if take_new { + v.to_string() + } else { + prev + } + } + }); + } + } + Ok(()) +} + +fn minmax_result(state: &MinMaxState, agg_name: &str, field_name: &str) -> crate::Result { + Ok(match state { + MinMaxState::I8(acc) => Arc::new(Int8Array::from(vec![*acc])), + MinMaxState::I16(acc) => Arc::new(Int16Array::from(vec![*acc])), + MinMaxState::I32(acc) => Arc::new(Int32Array::from(vec![*acc])), + MinMaxState::I64(acc) => Arc::new(Int64Array::from(vec![*acc])), + MinMaxState::F32(acc) => Arc::new(Float32Array::from(vec![*acc])), + MinMaxState::F64(acc) => Arc::new(Float64Array::from(vec![*acc])), + MinMaxState::Decimal128 { + precision, + scale, + acc, + } => decimal_array(*precision, *scale, *acc, agg_name, field_name)?, + MinMaxState::Date32(acc) => Arc::new(Date32Array::from(vec![*acc])), + MinMaxState::Time32Ms(acc) => Arc::new(Time32MillisecondArray::from(vec![*acc])), + MinMaxState::Timestamp { unit, acc } => match unit { + TimeUnit::Millisecond => Arc::new(TimestampMillisecondArray::from(vec![*acc])), + TimeUnit::Microsecond => Arc::new(TimestampMicrosecondArray::from(vec![*acc])), + TimeUnit::Nanosecond => Arc::new(TimestampNanosecondArray::from(vec![*acc])), + other => { + return Err(crate::Error::DataInvalid { + message: format!( + "Timestamp with unit {other:?} not expected for field '{field_name}'" + ), + source: None, + }); + } + }, + MinMaxState::Utf8(acc) => Arc::new(StringArray::from(vec![acc.clone()])), + }) +} + +fn reset_minmax(state: &mut MinMaxState) { + match state { + MinMaxState::I8(acc) => *acc = None, + MinMaxState::I16(acc) => *acc = None, + MinMaxState::I32(acc) => *acc = None, + MinMaxState::I64(acc) => *acc = None, + MinMaxState::F32(acc) => *acc = None, + MinMaxState::F64(acc) => *acc = None, + MinMaxState::Decimal128 { acc, .. } => *acc = None, + MinMaxState::Date32(acc) => *acc = None, + MinMaxState::Time32Ms(acc) => *acc = None, + MinMaxState::Timestamp { acc, .. } => *acc = None, + MinMaxState::Utf8(acc) => *acc = None, + } +} + +#[derive(Debug)] +pub(crate) struct MinAgg { + field_name: String, + state: MinMaxState, +} + +impl MinAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result { + Ok(Self { + field_name: field_name.to_string(), + state: make_minmax_state(field_name, data_type, "min")?, + }) + } +} + +impl FieldAggregator for MinAgg { + fn name(&self) -> &'static str { + "min" + } + + fn reset(&mut self) { + reset_minmax(&mut self.state); + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + agg_minmax(&mut self.state, array, row_idx, &self.field_name, true) + } + + fn result(&self) -> crate::Result { + minmax_result(&self.state, "min", &self.field_name) + } +} + +#[derive(Debug)] +pub(crate) struct MaxAgg { + field_name: String, + state: MinMaxState, +} + +impl MaxAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result { + Ok(Self { + field_name: field_name.to_string(), + state: make_minmax_state(field_name, data_type, "max")?, + }) + } +} + +impl FieldAggregator for MaxAgg { + fn name(&self) -> &'static str { + "max" + } + + fn reset(&mut self) { + reset_minmax(&mut self.state); + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + agg_minmax(&mut self.state, array, row_idx, &self.field_name, false) + } + + fn result(&self) -> crate::Result { + minmax_result(&self.state, "max", &self.field_name) + } +} + +// --------------------------------------------------------------------------- +// Count +// --------------------------------------------------------------------------- + +/// `count` accumulates the number of non-NULL inputs. The output is a +/// `BIGINT` scalar; the input column must also be declared as BIGINT so the +/// existing data-file layout can hold the i64 counter without an extra cast +/// layer. Non-BIGINT columns are rejected at construction. +/// +/// Aligns with Java `FieldCountAgg`, which likewise produces a BIGINT output. +#[derive(Debug)] +pub(crate) struct CountAgg { + field_name: String, + count: i64, +} + +impl CountAgg { + pub(crate) fn new(field_name: &str, data_type: &DataType) -> crate::Result { + // Count requires the output column to be BIGINT to hold an i64 value. + match data_type { + DataType::BigInt(_) => Ok(Self { + field_name: field_name.to_string(), + count: 0, + }), + other => Err(crate::Error::ConfigInvalid { + message: format!( + "Aggregate function 'count' requires field '{field_name}' to be \ + declared as BIGINT, found {other:?}" + ), + }), + } + } +} + +impl FieldAggregator for CountAgg { + fn name(&self) -> &'static str { + "count" + } + + fn reset(&mut self) { + self.count = 0; + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + // BIGINT input column: existing data files for the count column hold + // i64 placeholders (caller writes the per-row count, typically 1). + // We follow the Java reference and only check the null bit; the actual + // value is ignored. This lets BIGINT-typed input flow through the + // sort-merge reader without an extra cast layer. + if !array.is_null(row_idx) { + self.count = self + .count + .checked_add(1) + .ok_or_else(|| crate::Error::DataInvalid { + message: format!( + "Aggregate function 'count' overflowed i64 for field '{}'", + self.field_name + ), + source: None, + })?; + } + Ok(()) + } + + fn result(&self) -> crate::Result { + Ok(Arc::new(Int64Array::from(vec![self.count]))) + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn downcast<'a, T: Array + 'static>( + array: &'a dyn Array, + field_name: &str, +) -> crate::Result<&'a T> { + array + .as_any() + .downcast_ref::() + .ok_or_else(|| crate::Error::DataInvalid { + message: format!( + "Aggregate column '{field_name}' received Arrow array of unexpected \ + type {:?}; expected {}", + array.data_type(), + std::any::type_name::() + ), + source: None, + }) +} + +fn decimal_precision(precision: u32, field_name: &str) -> crate::Result { + u8::try_from(precision).map_err(|_| crate::Error::Unsupported { + message: format!( + "Decimal precision {precision} on field '{field_name}' exceeds u8 (Arrow limit)" + ), + }) +} + +fn decimal_scale(scale: u32, field_name: &str) -> crate::Result { + i8::try_from(scale as i32).map_err(|_| crate::Error::Unsupported { + message: format!( + "Decimal scale {scale} on field '{field_name}' is out of i8 range (Arrow limit)" + ), + }) +} + +fn overflow_error(agg_name: &str, field_name: &str) -> crate::Error { + crate::Error::DataInvalid { + message: format!("Aggregate function '{agg_name}' overflowed on field '{field_name}'"), + source: None, + } +} + +fn decimal_array( + precision: u8, + scale: i8, + value: Option, + agg_name: &str, + field_name: &str, +) -> crate::Result { + let mut builder = Decimal128Builder::with_capacity(1) + .with_precision_and_scale(precision, scale) + .map_err(|e| crate::Error::DataInvalid { + message: format!( + "Aggregate function '{agg_name}' failed to build Decimal128 array for \ + field '{field_name}': {e}" + ), + source: Some(Box::new(e)), + })?; + match value { + Some(v) => builder.append_value(v), + None => builder.append_null(), + } + Ok(Arc::new(builder.finish())) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::spec::{ + BigIntType, CharType, DateType, DecimalType, DoubleType, FloatType, IntType, SmallIntType, + TimeType, TimestampType, TinyIntType, VarCharType, + }; + use arrow_array::builder::Decimal128Builder; + + fn sum_agg(dt: DataType) -> SumAgg { + SumAgg::new("v", &dt).unwrap() + } + fn min_agg(dt: DataType) -> MinAgg { + MinAgg::new("v", &dt).unwrap() + } + fn max_agg(dt: DataType) -> MaxAgg { + MaxAgg::new("v", &dt).unwrap() + } + + fn collect_i32(arr: ArrayRef) -> Option { + let a = arr.as_any().downcast_ref::().unwrap(); + if a.is_null(0) { + None + } else { + Some(a.value(0)) + } + } + + fn collect_i64(arr: ArrayRef) -> Option { + let a = arr.as_any().downcast_ref::().unwrap(); + if a.is_null(0) { + None + } else { + Some(a.value(0)) + } + } + + fn collect_string(arr: ArrayRef) -> Option { + let a = arr.as_any().downcast_ref::().unwrap(); + if a.is_null(0) { + None + } else { + Some(a.value(0).to_string()) + } + } + + #[test] + fn test_sum_int_aggregates_non_null_values() { + let mut agg = sum_agg(DataType::Int(IntType::new())); + let arr = Int32Array::from(vec![Some(1), None, Some(2), Some(3)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i32(agg.result().unwrap()), Some(6)); + } + + #[test] + fn test_sum_all_null_returns_null() { + let mut agg = sum_agg(DataType::BigInt(BigIntType::new())); + let arr = Int64Array::from(vec![None::, None]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i64(agg.result().unwrap()), None); + } + + #[test] + fn test_sum_rejects_overflow() { + let mut agg = sum_agg(DataType::Int(IntType::new())); + let arr = Int32Array::from(vec![i32::MAX, 1]); + agg.agg(&arr, 0).unwrap(); + let err = agg.agg(&arr, 1).unwrap_err(); + assert!( + matches!(err, crate::Error::DataInvalid { message, .. } if message.contains("overflowed")) + ); + } + + #[test] + fn test_sum_rejects_non_numeric_type() { + let err = SumAgg::new("v", &DataType::VarChar(VarCharType::new(255).unwrap())).unwrap_err(); + assert!(matches!(err, crate::Error::ConfigInvalid { message } if message.contains("sum"))); + } + + #[test] + fn test_sum_reset_clears_state() { + let mut agg = sum_agg(DataType::Int(IntType::new())); + let arr = Int32Array::from(vec![Some(10)]); + agg.agg(&arr, 0).unwrap(); + agg.reset(); + assert_eq!(collect_i32(agg.result().unwrap()), None); + } + + #[test] + fn test_sum_float_skips_null_and_handles_partial() { + let mut agg = sum_agg(DataType::Double(DoubleType::new())); + let arr = Float64Array::from(vec![Some(1.5), None, Some(2.5)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + let a = agg.result().unwrap(); + let v = a.as_any().downcast_ref::().unwrap().value(0); + assert!((v - 4.0).abs() < 1e-9); + } + + #[test] + fn test_sum_decimal_aggregates_raw_values() { + let mut agg = sum_agg(DataType::Decimal(DecimalType::new(10, 2).unwrap())); + let mut b = Decimal128Builder::with_capacity(2) + .with_precision_and_scale(10, 2) + .unwrap(); + b.append_value(100); // 1.00 + b.append_value(250); // 2.50 + let arr = b.finish(); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + let out = agg.result().unwrap(); + let out_arr = out.as_any().downcast_ref::().unwrap(); + assert_eq!(out_arr.value(0), 350); // 3.50 + } + + #[test] + fn test_product_int_aggregates() { + let mut agg = ProductAgg::new("v", &DataType::Int(IntType::new())).unwrap(); + let arr = Int32Array::from(vec![Some(2), None, Some(3), Some(4)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i32(agg.result().unwrap()), Some(24)); + } + + #[test] + fn test_product_rejects_overflow() { + let mut agg = ProductAgg::new("v", &DataType::SmallInt(SmallIntType::new())).unwrap(); + let arr = Int16Array::from(vec![i16::MAX, 2]); + agg.agg(&arr, 0).unwrap(); + let err = agg.agg(&arr, 1).unwrap_err(); + assert!(matches!(err, crate::Error::DataInvalid { .. })); + } + + #[test] + fn test_product_all_null_returns_null() { + let mut agg = ProductAgg::new("v", &DataType::Int(IntType::new())).unwrap(); + let arr = Int32Array::from(vec![None::, None]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i32(agg.result().unwrap()), None); + } + + #[test] + fn test_product_rejects_decimal_until_scale_handling_lands() { + // DECIMAL multiplication needs BigDecimal-style scale rebasing; the + // basic mode rejects it explicitly instead of silently shifting the + // implied scale. + let err = + ProductAgg::new("v", &DataType::Decimal(DecimalType::new(10, 2).unwrap())).unwrap_err(); + assert!( + matches!(err, crate::Error::ConfigInvalid { ref message } if message.contains("DECIMAL")) + ); + } + + #[test] + fn test_min_int_picks_smallest_skipping_null() { + let mut agg = min_agg(DataType::Int(IntType::new())); + let arr = Int32Array::from(vec![Some(3), None, Some(1), Some(2)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i32(agg.result().unwrap()), Some(1)); + } + + #[test] + fn test_max_string_picks_lex_largest() { + let mut agg = max_agg(DataType::Char(CharType::new(8).unwrap())); + let arr = StringArray::from(vec![Some("ant"), None, Some("zebra"), Some("bee")]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!( + collect_string(agg.result().unwrap()), + Some("zebra".to_string()) + ); + } + + #[test] + fn test_min_max_skip_nan_floats() { + let mut agg = min_agg(DataType::Float(FloatType::new())); + let arr = Float32Array::from(vec![Some(f32::NAN), Some(1.0), Some(0.5)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + let v = agg.result().unwrap(); + let v = v.as_any().downcast_ref::().unwrap().value(0); + assert!((v - 0.5).abs() < 1e-6); + } + + #[test] + fn test_min_max_all_null_returns_null() { + let mut agg = max_agg(DataType::Int(IntType::new())); + let arr = Int32Array::from(vec![None::, None]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i32(agg.result().unwrap()), None); + } + + #[test] + fn test_min_rejects_unsupported_type() { + // Boolean has no <, > defined for min/max in Paimon basic mode. + let err = + MinAgg::new("v", &DataType::Boolean(crate::spec::BooleanType::new())).unwrap_err(); + assert!(matches!(err, crate::Error::ConfigInvalid { message } if message.contains("min"))); + } + + #[test] + fn test_min_max_date_and_timestamp_supported() { + // Date32 + let mut agg = min_agg(DataType::Date(DateType::new())); + let arr = Date32Array::from(vec![Some(100), Some(50), Some(200)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + let v = agg.result().unwrap(); + let v = v.as_any().downcast_ref::().unwrap().value(0); + assert_eq!(v, 50); + + // Timestamp(6) → Microsecond + let mut agg = max_agg(DataType::Timestamp(TimestampType::new(6).unwrap())); + let arr = + TimestampMicrosecondArray::from(vec![Some(1_000_000), Some(2_000_000), Some(500_000)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + let v = agg.result().unwrap(); + let v = v + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + assert_eq!(v, 2_000_000); + } + + #[test] + fn test_min_max_time_supported() { + // Paimon `TIME` is always stored as Arrow Time32(Millisecond) by + // `paimon_type_to_arrow`, so milliseconds is the only carrier here. + let mut agg = min_agg(DataType::Time(TimeType::new(3).unwrap())); + let arr = Time32MillisecondArray::from(vec![Some(60_000), Some(30_000), Some(90_000)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + let v = agg.result().unwrap(); + let v = v + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + assert_eq!(v, 30_000); + } + + #[test] + fn test_count_counts_non_null_and_outputs_bigint() { + let mut agg = CountAgg::new("c", &DataType::BigInt(BigIntType::new())).unwrap(); + let arr = Int64Array::from(vec![Some(1), None, Some(1), Some(1)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i64(agg.result().unwrap()), Some(3)); + } + + #[test] + fn test_count_reset_clears_state() { + let mut agg = CountAgg::new("c", &DataType::BigInt(BigIntType::new())).unwrap(); + let arr = Int64Array::from(vec![Some(1), Some(1)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + agg.reset(); + assert_eq!(collect_i64(agg.result().unwrap()), Some(0)); + } + + #[test] + fn test_count_rejects_non_bigint_column() { + let err = CountAgg::new("c", &DataType::Int(IntType::new())).unwrap_err(); + assert!( + matches!(err, crate::Error::ConfigInvalid { message } if message.contains("BIGINT")) + ); + } + + #[test] + fn test_count_empty_group_returns_zero() { + let agg = CountAgg::new("c", &DataType::BigInt(BigIntType::new())).unwrap(); + assert_eq!(collect_i64(agg.result().unwrap()), Some(0)); + } + + #[test] + fn test_tinyint_sum_supported() { + let mut agg = sum_agg(DataType::TinyInt(TinyIntType::new())); + let arr = Int8Array::from(vec![Some(1i8), Some(2)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + let out = agg.result().unwrap(); + let v = out.as_any().downcast_ref::().unwrap().value(0); + assert_eq!(v, 3); + } +} diff --git a/crates/paimon/src/table/aggregator/value.rs b/crates/paimon/src/table/aggregator/value.rs new file mode 100644 index 00000000..499c85fe --- /dev/null +++ b/crates/paimon/src/table/aggregator/value.rs @@ -0,0 +1,240 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Generic "pick a row" aggregators: `last_value`, `first_value`, +//! `last_non_null_value`, `first_non_null_value`. +//! +//! These accept any Paimon type because the accumulator stores a 1-row Arrow +//! slice of the source array rather than a typed scalar. The merge function +//! feeds rows in user-sequence ascending order, so "last" means the row with +//! the highest sequence and "first" the lowest. +//! +//! Reference: Java `FieldLastValueAgg`, `FieldFirstValueAgg`, +//! `FieldLastNonNullValueAgg`, `FieldFirstNonNullValueAgg` under +//! `org.apache.paimon.mergetree.compact.aggregate`. + +use arrow_array::{new_null_array, Array, ArrayRef}; +use arrow_schema::DataType as ArrowDataType; + +use super::FieldAggregator; +use crate::arrow::paimon_type_to_arrow; +use crate::spec::DataType; + +/// What constitutes "a winning row" for a given pick-style aggregator. +#[derive(Clone, Copy, Debug)] +enum PickPolicy { + /// Replace the winner on every call, including NULL inputs. + Last, + /// Keep only the first call; subsequent calls (NULL or otherwise) are + /// ignored. + First, + /// Replace on every non-NULL input. + LastNonNull, + /// Keep only the first non-NULL input; later inputs are ignored. + FirstNonNull, +} + +/// Internal accumulator shared by all four pick-style aggregators. Only the +/// outer typed wrappers (e.g. [`LastValueAgg`]) implement [`FieldAggregator`]; +/// this struct exposes inherent methods that the wrappers delegate to. +#[derive(Debug)] +struct PickValueAgg { + policy: PickPolicy, + arrow_type: ArrowDataType, + /// 1-row Arrow array holding the currently-winning value; `None` means + /// no winning row has been observed yet for the current group. + winner: Option, +} + +impl PickValueAgg { + fn new(policy: PickPolicy, data_type: &DataType) -> crate::Result { + Ok(Self { + policy, + arrow_type: paimon_type_to_arrow(data_type)?, + winner: None, + }) + } + + fn should_replace(&self, is_null: bool) -> bool { + match self.policy { + PickPolicy::Last => true, + PickPolicy::First => self.winner.is_none(), + PickPolicy::LastNonNull => !is_null, + PickPolicy::FirstNonNull => self.winner.is_none() && !is_null, + } + } + + fn reset(&mut self) { + self.winner = None; + } + + fn agg(&mut self, array: &dyn Array, row_idx: usize) { + if self.should_replace(array.is_null(row_idx)) { + self.winner = Some(array.slice(row_idx, 1)); + } + } + + fn result(&self) -> ArrayRef { + match &self.winner { + Some(arr) => arr.clone(), + None => new_null_array(&self.arrow_type, 1), + } + } +} + +macro_rules! pick_agg { + ($struct_name:ident, $factory_name:literal, $policy:expr) => { + #[derive(Debug)] + pub(crate) struct $struct_name(PickValueAgg); + + impl $struct_name { + pub(crate) fn new(_field_name: &str, data_type: &DataType) -> crate::Result { + Ok(Self(PickValueAgg::new($policy, data_type)?)) + } + } + + impl FieldAggregator for $struct_name { + fn name(&self) -> &'static str { + $factory_name + } + fn reset(&mut self) { + self.0.reset(); + } + fn agg(&mut self, array: &dyn Array, row_idx: usize) -> crate::Result<()> { + self.0.agg(array, row_idx); + Ok(()) + } + fn result(&self) -> crate::Result { + Ok(self.0.result()) + } + } + }; +} + +pick_agg!(LastValueAgg, "last_value", PickPolicy::Last); +pick_agg!(FirstValueAgg, "first_value", PickPolicy::First); +pick_agg!( + LastNonNullValueAgg, + "last_non_null_value", + PickPolicy::LastNonNull +); +pick_agg!( + FirstNonNullValueAgg, + "first_non_null_value", + PickPolicy::FirstNonNull +); + +#[cfg(test)] +mod tests { + use super::*; + use crate::spec::{IntType, VarCharType}; + use arrow_array::{Int32Array, StringArray}; + + fn collect_i32(arr: ArrayRef) -> Option { + let a = arr.as_any().downcast_ref::().unwrap(); + if a.is_null(0) { + None + } else { + Some(a.value(0)) + } + } + + fn collect_str(arr: ArrayRef) -> Option { + let a = arr.as_any().downcast_ref::().unwrap(); + if a.is_null(0) { + None + } else { + Some(a.value(0).to_string()) + } + } + + #[test] + fn test_last_value_includes_trailing_null() { + let mut agg = LastValueAgg::new("v", &DataType::Int(IntType::new())).unwrap(); + let arr = Int32Array::from(vec![Some(1), Some(2), None]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + // Last row was NULL; last_value preserves it. + assert_eq!(collect_i32(agg.result().unwrap()), None); + } + + #[test] + fn test_first_value_locks_first_row_including_null() { + let mut agg = FirstValueAgg::new("v", &DataType::Int(IntType::new())).unwrap(); + let arr = Int32Array::from(vec![None, Some(2), Some(3)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i32(agg.result().unwrap()), None); + } + + #[test] + fn test_last_non_null_value_skips_trailing_null() { + let mut agg = LastNonNullValueAgg::new("v", &DataType::Int(IntType::new())).unwrap(); + let arr = Int32Array::from(vec![Some(1), Some(2), None]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i32(agg.result().unwrap()), Some(2)); + } + + #[test] + fn test_first_non_null_value_locks_first_non_null() { + let mut agg = FirstNonNullValueAgg::new("v", &DataType::Int(IntType::new())).unwrap(); + let arr = Int32Array::from(vec![None, Some(5), Some(7)]); + for i in 0..arr.len() { + agg.agg(&arr, i).unwrap(); + } + assert_eq!(collect_i32(agg.result().unwrap()), Some(5)); + } + + #[test] + fn test_pick_aggregators_handle_string() { + let dt = DataType::VarChar(VarCharType::new(255).unwrap()); + let arr = StringArray::from(vec![Some("a"), None, Some("c")]); + + let mut last = LastNonNullValueAgg::new("v", &dt).unwrap(); + for i in 0..arr.len() { + last.agg(&arr, i).unwrap(); + } + assert_eq!(collect_str(last.result().unwrap()), Some("c".to_string())); + + let mut first = FirstNonNullValueAgg::new("v", &dt).unwrap(); + for i in 0..arr.len() { + first.agg(&arr, i).unwrap(); + } + assert_eq!(collect_str(first.result().unwrap()), Some("a".to_string())); + } + + #[test] + fn test_empty_group_returns_null_array_of_correct_type() { + let agg = LastValueAgg::new("v", &DataType::Int(IntType::new())).unwrap(); + let out = agg.result().unwrap(); + assert!(out.is_null(0)); + assert_eq!(out.data_type(), &ArrowDataType::Int32); + } + + #[test] + fn test_reset_clears_state() { + let mut agg = LastValueAgg::new("v", &DataType::Int(IntType::new())).unwrap(); + let arr = Int32Array::from(vec![Some(99)]); + agg.agg(&arr, 0).unwrap(); + agg.reset(); + assert!(agg.result().unwrap().is_null(0)); + } +} diff --git a/crates/paimon/src/table/bucket_assigner_cross.rs b/crates/paimon/src/table/bucket_assigner_cross.rs index a021761b..c3f1acc2 100644 --- a/crates/paimon/src/table/bucket_assigner_cross.rs +++ b/crates/paimon/src/table/bucket_assigner_cross.rs @@ -192,6 +192,13 @@ impl GlobalPartitionIndex { message: "CrossPartitionAssigner does not support merge-engine=partial-update yet".to_string(), }); } + MergeEngine::Aggregation => { + return Err(crate::Error::Unsupported { + message: + "CrossPartitionAssigner does not support merge-engine=aggregation yet" + .to_string(), + }); + } } } diff --git a/crates/paimon/src/table/kv_file_reader.rs b/crates/paimon/src/table/kv_file_reader.rs index 128b2dab..c6e21f5d 100644 --- a/crates/paimon/src/table/kv_file_reader.rs +++ b/crates/paimon/src/table/kv_file_reader.rs @@ -25,7 +25,8 @@ use super::data_file_reader::DataFileReader; use super::sort_merge::{ - DeduplicateMergeFunction, PartialUpdateMergeFunction, SortMergeReaderBuilder, + AggregateMergeFunction, DeduplicateMergeFunction, PartialUpdateMergeFunction, + SortMergeReaderBuilder, }; use crate::arrow::build_target_arrow_schema; use crate::io::FileIO; @@ -102,6 +103,9 @@ impl KeyValueFileReader { merge_engine: MergeEngine, table_options: &HashMap, table_name: &str, + merge_output_fields: &[DataField], + primary_keys: &[String], + sequence_fields: &[String], ) -> crate::Result> { match merge_engine { MergeEngine::Deduplicate => Ok(Box::new(DeduplicateMergeFunction)), @@ -112,6 +116,13 @@ impl KeyValueFileReader { MergeEngine::FirstRow => Err(Error::Unsupported { message: "KeyValueFileReader does not support merge-engine=first-row; first-row reads should use the non-KV path".to_string(), }), + MergeEngine::Aggregation => Ok(Box::new(AggregateMergeFunction::new( + table_options, + table_name, + merge_output_fields, + primary_keys, + sequence_fields, + )?)), } } @@ -264,6 +275,8 @@ impl KeyValueFileReader { let table_name = self.config.table_name; let table_options = self.config.table_options; let predicates = self.config.predicates; + let primary_keys = self.config.primary_keys; + let sequence_fields = self.config.sequence_fields; // Build the merge output schema (keys + values, no system columns). let mut merge_output_fields: Vec = Vec::new(); @@ -328,7 +341,14 @@ impl KeyValueFileReader { user_sequence_indices.clone(), value_indices.clone(), merge_output_schema.clone(), - Self::new_merge_function(merge_engine, &table_options, &table_name)?, + Self::new_merge_function( + merge_engine, + &table_options, + &table_name, + &merge_output_fields, + &primary_keys, + &sequence_fields, + )?, ) .build()?; diff --git a/crates/paimon/src/table/kv_file_writer.rs b/crates/paimon/src/table/kv_file_writer.rs index a6d71800..f534dbab 100644 --- a/crates/paimon/src/table/kv_file_writer.rs +++ b/crates/paimon/src/table/kv_file_writer.rs @@ -30,8 +30,9 @@ use crate::arrow::format::create_format_writer; use crate::io::FileIO; use crate::spec::stats::{compute_column_stats, BinaryTableStats}; use crate::spec::{ - extract_datum_from_arrow, BinaryRowBuilder, DataFileMeta, DataType, MergeEngine, - PartialUpdateConfig, EMPTY_SERIALIZED_ROW, SEQUENCE_NUMBER_FIELD_NAME, VALUE_KIND_FIELD_NAME, + extract_datum_from_arrow, AggregationConfig, BinaryRowBuilder, DataFileMeta, DataType, + MergeEngine, PartialUpdateConfig, EMPTY_SERIALIZED_ROW, SEQUENCE_NUMBER_FIELD_NAME, + VALUE_KIND_FIELD_NAME, }; use crate::Result; use arrow_array::{Int64Array, Int8Array, RecordBatch}; @@ -101,6 +102,20 @@ impl KeyValueFileWriter { } } + if config.merge_engine == MergeEngine::Aggregation { + AggregationConfig::new(&config.table_options) + .validate_runtime_mode(true, &config.table_name)?; + + if config.deletion_vectors_enabled { + return Err(crate::Error::Unsupported { + message: format!( + "Table '{}' uses merge-engine=aggregation with deletion-vectors.enabled=true, which is not supported yet", + config.table_name + ), + }); + } + } + Ok(Self { file_io, config, @@ -392,7 +407,9 @@ impl KeyValueFileWriter { MergeEngine::Deduplicate | MergeEngine::FirstRow => { self.dedup_sorted_indices(batch, sorted_indices) } - MergeEngine::PartialUpdate => Ok((0..sorted_indices.len()) + // Aggregation, like PartialUpdate, keeps every row on flush and + // performs the per-field merge on the read side. + MergeEngine::PartialUpdate | MergeEngine::Aggregation => Ok((0..sorted_indices.len()) .map(|idx| sorted_indices.value(idx)) .collect()), } @@ -452,8 +469,9 @@ impl KeyValueFileWriter { MergeEngine::Deduplicate => group_winner = cur, // FirstRow: keep first (lowest seq), so don't update. MergeEngine::FirstRow => {} - MergeEngine::PartialUpdate => unreachable!( - "partial-update should use select_flush_indices and skip dedup" + MergeEngine::PartialUpdate | MergeEngine::Aggregation => unreachable!( + "{:?} should use select_flush_indices and skip dedup", + self.config.merge_engine ), } } else { @@ -524,8 +542,14 @@ mod tests { fn test_write_config(merge_engine: MergeEngine) -> KeyValueWriteConfig { let mut table_options = HashMap::new(); - if merge_engine == MergeEngine::PartialUpdate { - table_options.insert("merge-engine".to_string(), "partial-update".to_string()); + match merge_engine { + MergeEngine::PartialUpdate => { + table_options.insert("merge-engine".to_string(), "partial-update".to_string()); + } + MergeEngine::Aggregation => { + table_options.insert("merge-engine".to_string(), "aggregation".to_string()); + } + MergeEngine::Deduplicate | MergeEngine::FirstRow => {} } KeyValueWriteConfig { @@ -647,4 +671,68 @@ mod tests { if message.contains("fields.price.aggregate-function") )); } + + #[test] + fn test_select_flush_indices_keeps_all_rows_for_aggregation_engine() { + let schema = Arc::new(ArrowSchema::new(vec![ + Arc::new(ArrowField::new("id", ArrowDataType::Int32, false)), + Arc::new(ArrowField::new("seq", ArrowDataType::Int64, false)), + ])); + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int32Array::from(vec![1, 1])) as Arc, + Arc::new(Int64Array::from(vec![10, 20])) as Arc, + ], + ) + .unwrap(); + let sorted_indices = UInt32Array::from(vec![0, 1]); + let writer = KeyValueFileWriter::new( + FileIOBuilder::new("memory").build().unwrap(), + test_write_config(MergeEngine::Aggregation), + 0, + ) + .unwrap(); + + let selected = writer + .select_flush_indices(&batch, &sorted_indices) + .unwrap(); + + assert_eq!(selected, vec![0, 1]); + } + + #[test] + fn test_new_rejects_aggregation_with_deletion_vectors() { + let mut config = test_write_config(MergeEngine::Aggregation); + config.deletion_vectors_enabled = true; + + let err = KeyValueFileWriter::new(FileIOBuilder::new("memory").build().unwrap(), config, 0) + .err() + .unwrap(); + + assert!(matches!( + err, + crate::Error::Unsupported { message } + if message.contains("deletion-vectors.enabled=true") + )); + } + + #[test] + fn test_new_rejects_unsupported_aggregation_options() { + let mut config = test_write_config(MergeEngine::Aggregation); + config.table_options.insert( + "fields.price.ignore-retract".to_string(), + "true".to_string(), + ); + + let err = KeyValueFileWriter::new(FileIOBuilder::new("memory").build().unwrap(), config, 0) + .err() + .unwrap(); + + assert!(matches!( + err, + crate::Error::Unsupported { message } + if message.contains("fields.price.ignore-retract") + )); + } } diff --git a/crates/paimon/src/table/mod.rs b/crates/paimon/src/table/mod.rs index c00fa786..7c3bb513 100644 --- a/crates/paimon/src/table/mod.rs +++ b/crates/paimon/src/table/mod.rs @@ -17,6 +17,7 @@ //! Table API for Apache Paimon +pub(crate) mod aggregator; pub(crate) mod bin_pack; mod blob_file_writer; mod branch_manager; diff --git a/crates/paimon/src/table/sort_merge.rs b/crates/paimon/src/table/sort_merge.rs index 8a9ece4f..710b0429 100644 --- a/crates/paimon/src/table/sort_merge.rs +++ b/crates/paimon/src/table/sort_merge.rs @@ -26,7 +26,8 @@ //! - DataFusion: `SortPreservingMergeStream` (LoserTree layout) //! - Arrow-row: `RowConverter` for efficient key comparison -use crate::spec::{PartialUpdateConfig, RowKind}; +use crate::spec::{AggregationConfig, DataField, PartialUpdateConfig, RowKind}; +use crate::table::aggregator::{new_aggregator, FieldAggregator}; use crate::table::ArrowRecordBatchStream; use crate::Error; use arrow_array::{new_null_array, ArrayRef, Int64Array, Int8Array, RecordBatch}; @@ -37,6 +38,8 @@ use async_stream::try_stream; use futures::StreamExt; use std::cmp::Ordering; use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Mutex; // --------------------------------------------------------------------------- // MergeFunction @@ -272,6 +275,212 @@ impl MergeFunction for PartialUpdateMergeFunction { } } +// --------------------------------------------------------------------------- +// AggregateMergeFunction +// --------------------------------------------------------------------------- + +/// Basic aggregation merge: for each non-key, non-sequence column, apply a +/// per-field aggregator across all rows sharing the same primary key. +/// +/// The merge function honors the following contract: +/// +/// - Primary-key columns are copied from any row of the group (Paimon +/// guarantees they all share the same value); no aggregator is constructed +/// for them. +/// - Columns listed in the `sequence_fields` constructor argument (which the +/// reader populates from the `sequence.field` table option) are forced to +/// `last_value` regardless of any per-field configuration, matching Java +/// `AggregateMergeFunction#createFieldAggregators`. +/// - Every other output column requires either +/// `fields..aggregate-function` or a fall-back +/// `fields.default-aggregate-function`; otherwise construction fails with +/// [`Error::ConfigInvalid`]. +/// - DELETE / UPDATE_BEFORE rows are rejected at runtime; retract handling +/// is left to a follow-up commit. +/// +/// `aggregators` is held behind a `Mutex` so the implementation can mutate +/// per-key accumulators inside `MergeFunction::merge(&self, ...)` without +/// changing the trait signature shared with the other merge functions. The +/// merge function is invoked sequentially by `sort_merge_stream`, so the +/// lock is effectively uncontended. +/// +/// Reference: Java `org.apache.paimon.mergetree.compact.aggregate.AggregateMergeFunction`. +/// +/// [`Error::ConfigInvalid`]: crate::Error::ConfigInvalid +#[derive(Debug)] +pub(crate) struct AggregateMergeFunction { + /// One slot per output column. `None` marks primary-key columns that are + /// copied through; `Some` holds the aggregator that owns the column. + aggregators: Mutex>>>, +} + +impl AggregateMergeFunction { + pub(crate) fn new( + table_options: &HashMap, + table_name: &str, + output_fields: &[DataField], + primary_keys: &[String], + sequence_fields: &[String], + ) -> crate::Result { + let config = AggregationConfig::new(table_options); + config.validate_runtime_mode(true, table_name)?; + + let pk_set: HashSet<&str> = primary_keys.iter().map(String::as_str).collect(); + let seq_set: HashSet<&str> = sequence_fields.iter().map(String::as_str).collect(); + + let aggregators: Vec>> = output_fields + .iter() + .map(|field| -> crate::Result>> { + let name = field.name(); + if pk_set.contains(name) { + return Ok(None); + } + // Sequence fields are forced to last_value, mirroring Java + // AggregateMergeFunction#createFieldAggregators. + let agg_name: String = if seq_set.contains(name) { + "last_value".to_string() + } else if let Some(per_field) = config.agg_function_for_field(name) { + per_field.to_string() + } else if let Some(default) = config.default_agg_function() { + default.to_string() + } else { + return Err(crate::Error::ConfigInvalid { + message: format!( + "Field '{name}' has no aggregate-function configured for \ + merge-engine=aggregation on table '{table_name}'; set \ + fields.{name}.aggregate-function or fields.default-aggregate-function" + ), + }); + }; + Ok(Some(new_aggregator( + agg_name.as_str(), + name, + field.data_type(), + table_options, + )?)) + }) + .collect::>>()?; + + Ok(Self { + aggregators: Mutex::new(aggregators), + }) + } +} + +impl MergeFunction for AggregateMergeFunction { + fn merge( + &self, + rows: &[MergeRow], + batch_buffer: &[BufferedBatch], + source_output_col_indices: &[usize], + output_schema: &SchemaRef, + ) -> crate::Result { + if rows.is_empty() { + return Err(Error::UnexpectedError { + message: "merge called with empty rows".to_string(), + source: None, + }); + } + + // Reject retract rows up-front so partial accumulation cannot leak + // into the output if a DELETE shows up mid-group. + for row in rows { + if !RowKind::from_value(row.value_kind)?.is_add() { + return Err(crate::Error::Unsupported { + message: "merge-engine=aggregation basic mode does not support DELETE or UPDATE_BEFORE rows".to_string(), + }); + } + } + + // Sort row indices by user sequence (if configured), then system + // sequence, so per-field aggregators see the canonical "ascending + // sequence" order documented in their contracts. + let mut ordered_row_indices: Vec = (0..rows.len()).collect(); + ordered_row_indices.sort_by(|&lhs_idx, &rhs_idx| { + compare_sequence_order(&rows[lhs_idx], &rows[rhs_idx]) + .then_with(|| lhs_idx.cmp(&rhs_idx)) + }); + + let mut aggregators = self + .aggregators + .lock() + .map_err(|e| Error::UnexpectedError { + message: format!("AggregateMergeFunction aggregator mutex poisoned: {e}"), + source: None, + })?; + for slot in aggregators.iter_mut() { + if let Some(agg) = slot.as_mut() { + agg.reset(); + } + } + + for &row_idx in &ordered_row_indices { + let row = &rows[row_idx]; + for (col_idx, slot) in aggregators.iter_mut().enumerate() { + if let Some(agg) = slot.as_mut() { + let source_array = batch_buffer[row.batch_idx] + .column_for_output(col_idx, source_output_col_indices); + agg.agg(source_array, row.row_idx)?; + } + } + } + + // Use the last sorted row to source primary-key column values: every + // row in the group shares the same PK by construction, so any row + // works; picking the last one keeps the slice cheap to compute. + let pk_source = &rows[*ordered_row_indices.last().unwrap()]; + + let output_columns: Vec = aggregators + .iter() + .enumerate() + .map(|(col_idx, slot)| -> crate::Result { + match slot { + Some(agg) => agg.result(), + None => Ok(batch_buffer[pk_source.batch_idx] + .column_for_output(col_idx, source_output_col_indices) + .slice(pk_source.row_idx, 1)), + } + }) + .collect::>>()?; + + // Defensive check: non-nullable fields must not contain NULL on the + // merged output (e.g. `min` on an all-NULL value group, or a NULL + // primary-key cell on the source row). Split the message so the + // operator knows whether to look at the aggregator config or at the + // upstream data. + for (col_idx, field) in output_schema.fields().iter().enumerate() { + if !field.is_nullable() && output_columns[col_idx].is_null(0) { + let message = match aggregators[col_idx].as_ref() { + Some(agg) => format!( + "merge-engine=aggregation: aggregator '{}' produced NULL for \ + non-nullable field '{}'", + agg.name(), + field.name() + ), + None => format!( + "merge-engine=aggregation: primary-key column '{}' contains NULL on a \ + source row; declare the column nullable or fix the upstream data", + field.name() + ), + }; + return Err(Error::DataInvalid { + message, + source: None, + }); + } + } + + let batch = RecordBatch::try_new(output_schema.clone(), output_columns).map_err(|e| { + Error::UnexpectedError { + message: format!("Failed to build aggregation materialized row: {e}"), + source: Some(Box::new(e)), + } + })?; + + Ok(MergeResult::MaterializedRow(batch)) + } +} + // --------------------------------------------------------------------------- // SortMergeCursor // --------------------------------------------------------------------------- @@ -1719,4 +1928,327 @@ mod tests { if message.contains("fields.price.aggregate-function") )); } + + // ---------- AggregateMergeFunction ---------- + + use crate::spec::{DataType as PaimonDataType, IntType, VarCharType}; + + fn aggregation_output_fields() -> Vec { + vec![ + DataField::new(0, "pk".into(), PaimonDataType::Int(IntType::new())), + DataField::new(1, "amount".into(), PaimonDataType::Int(IntType::new())), + DataField::new( + 2, + "tag".into(), + PaimonDataType::VarChar(VarCharType::new(255).unwrap()), + ), + ] + } + + fn aggregation_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("pk", DataType::Int32, false), + Field::new("_SEQUENCE_NUMBER", DataType::Int64, false), + Field::new("_VALUE_KIND", DataType::Int8, false), + Field::new("amount", DataType::Int32, true), + Field::new("tag", DataType::Utf8, true), + ])) + } + + fn aggregation_output_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("pk", DataType::Int32, false), + Field::new("amount", DataType::Int32, true), + Field::new("tag", DataType::Utf8, true), + ])) + } + + fn agg_options(pairs: &[(&str, &str)]) -> HashMap { + let mut opts = HashMap::from([("merge-engine".to_string(), "aggregation".to_string())]); + for (k, v) in pairs { + opts.insert((*k).to_string(), (*v).to_string()); + } + opts + } + + fn build_agg_function(options: HashMap) -> Box { + Box::new( + AggregateMergeFunction::new( + &options, + "test_table", + &aggregation_output_fields(), + &["pk".to_string()], + &[], + ) + .unwrap(), + ) + } + + #[tokio::test] + async fn test_aggregate_merge_sum_and_listagg() { + let schema = aggregation_schema(); + let output_schema = aggregation_output_schema(); + let s0 = stream_from_batches(vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2])), + Arc::new(Int64Array::from(vec![1, 1])), + Arc::new(Int8Array::from(vec![0, 0])), + Arc::new(Int32Array::from(vec![Some(10), Some(20)])), + Arc::new(StringArray::from(vec![Some("a"), Some("x")])), + ], + ) + .unwrap()]); + let s1 = stream_from_batches(vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int64Array::from(vec![2, 2, 1])), + Arc::new(Int8Array::from(vec![0, 0, 0])), + Arc::new(Int32Array::from(vec![Some(5), Some(7), Some(99)])), + Arc::new(StringArray::from(vec![Some("b"), None, Some("solo")])), + ], + ) + .unwrap()]); + + let options = agg_options(&[ + ("fields.amount.aggregate-function", "sum"), + ("fields.tag.aggregate-function", "listagg"), + ("fields.tag.list-agg-delimiter", "|"), + ]); + + let batches = SortMergeReaderBuilder::new( + vec![s0, s1], + schema, + vec![0], + 1, + 2, + vec![], + vec![3, 4], + output_schema, + build_agg_function(options), + ) + .build() + .unwrap() + .try_collect::>() + .await + .unwrap(); + + let mut rows: Vec<(i32, Option, Option)> = Vec::new(); + for batch in &batches { + let pks = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let amounts = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let tags = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + for i in 0..batch.num_rows() { + rows.push(( + pks.value(i), + if amounts.is_null(i) { + None + } else { + Some(amounts.value(i)) + }, + if tags.is_null(i) { + None + } else { + Some(tags.value(i).to_string()) + }, + )); + } + } + rows.sort_by_key(|row| row.0); + + assert_eq!( + rows, + vec![ + (1, Some(15), Some("a|b".to_string())), + (2, Some(27), Some("x".to_string())), + (3, Some(99), Some("solo".to_string())), + ] + ); + } + + #[tokio::test] + async fn test_aggregate_merge_rejects_delete_rows() { + let schema = aggregation_schema(); + let output_schema = aggregation_output_schema(); + let s0 = stream_from_batches(vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int64Array::from(vec![1])), + Arc::new(Int8Array::from(vec![0])), + Arc::new(Int32Array::from(vec![Some(10)])), + Arc::new(StringArray::from(vec![Some("a")])), + ], + ) + .unwrap()]); + // Row kind 3 = DELETE + let s1 = stream_from_batches(vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int64Array::from(vec![2])), + Arc::new(Int8Array::from(vec![3])), + Arc::new(Int32Array::from(vec![Some(99)])), + Arc::new(StringArray::from(vec![Some("b")])), + ], + ) + .unwrap()]); + + let options = agg_options(&[ + ("fields.amount.aggregate-function", "sum"), + ("fields.tag.aggregate-function", "last_value"), + ]); + + let err = SortMergeReaderBuilder::new( + vec![s0, s1], + schema, + vec![0], + 1, + 2, + vec![], + vec![3, 4], + output_schema, + build_agg_function(options), + ) + .build() + .unwrap() + .try_collect::>() + .await + .unwrap_err(); + + assert!(matches!( + err, + Error::Unsupported { ref message } + if message.contains("aggregation basic mode does not support DELETE") + )); + } + + #[test] + fn test_aggregate_merge_function_requires_agg_function_per_field() { + // amount has no per-field nor default aggregate-function configured. + let options = HashMap::from([("merge-engine".to_string(), "aggregation".to_string())]); + let err = AggregateMergeFunction::new( + &options, + "test_table", + &aggregation_output_fields(), + &["pk".to_string()], + &[], + ) + .unwrap_err(); + assert!( + matches!(err, Error::ConfigInvalid { message } if message.contains("aggregate-function")) + ); + } + + #[test] + fn test_aggregate_merge_function_sequence_field_forced_last_value() { + // 'tag' is the sequence field; even though user configured listagg, + // it should be forced to last_value (so the latest tag survives). + let schema = aggregation_schema(); + let output_schema = aggregation_output_schema(); + let options = agg_options(&[ + ("fields.amount.aggregate-function", "sum"), + ("fields.tag.aggregate-function", "listagg"), + ("sequence.field", "tag"), + ]); + let mf = AggregateMergeFunction::new( + &options, + "test_table", + &aggregation_output_fields(), + &["pk".to_string()], + &["tag".to_string()], + ) + .unwrap(); + + let s0 = stream_from_batches(vec![RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 1])), + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(Int8Array::from(vec![0, 0])), + Arc::new(Int32Array::from(vec![Some(3), Some(4)])), + Arc::new(StringArray::from(vec![Some("first"), Some("second")])), + ], + ) + .unwrap()]); + + // Use the merge function via builder. + let batches = futures::executor::block_on(async { + SortMergeReaderBuilder::new( + vec![s0], + schema, + vec![0], + 1, + 2, + vec![], + vec![3, 4], + output_schema, + Box::new(mf), + ) + .build() + .unwrap() + .try_collect::>() + .await + .unwrap() + }); + + let batch = &batches[0]; + let tags = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + let amounts = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(tags.value(0), "second"); // last_value, not listagg. + assert_eq!(amounts.value(0), 7); // sum of 3 + 4. + } + + #[test] + fn test_aggregate_merge_function_default_function_applies_when_per_field_absent() { + let options = agg_options(&[("fields.default-aggregate-function", "last_non_null_value")]); + let mf = AggregateMergeFunction::new( + &options, + "test_table", + &aggregation_output_fields(), + &["pk".to_string()], + &[], + ) + .unwrap(); + // Construction must succeed: amount and tag fall back to the default. + // Tag is VarChar — last_non_null_value supports any type. + let _ = mf; + } + + #[test] + fn test_aggregate_merge_function_rejects_unsupported_options() { + let options = agg_options(&[("fields.amount.ignore-retract", "true")]); + let err = AggregateMergeFunction::new( + &options, + "test_table", + &aggregation_output_fields(), + &["pk".to_string()], + &[], + ) + .unwrap_err(); + assert!( + matches!(err, Error::Unsupported { message } if message.contains("ignore-retract")) + ); + } } diff --git a/crates/paimon/src/table/table_read.rs b/crates/paimon/src/table/table_read.rs index 54939383..b957aefa 100644 --- a/crates/paimon/src/table/table_read.rs +++ b/crates/paimon/src/table/table_read.rs @@ -79,10 +79,12 @@ impl<'a> TableRead<'a> { // PK table with Deduplicate engine: splits containing level-0 files // need KeyValueFileReader for sort-merge dedup; splits with only // compacted files (level > 0) can use the faster DataFileReader. + // PartialUpdate / Aggregation always go through KeyValueFileReader so + // that per-key materialization can run on the read side. if has_primary_keys && matches!( merge_engine, - MergeEngine::Deduplicate | MergeEngine::PartialUpdate + MergeEngine::Deduplicate | MergeEngine::PartialUpdate | MergeEngine::Aggregation ) { return self.read_pk(data_splits, &core_options); @@ -95,14 +97,20 @@ impl<'a> TableRead<'a> { } } - /// Read PK table with Deduplicate engine: level-0 splits go through - /// KeyValueFileReader for sort-merge dedup, compacted splits use DataFileReader. + /// Read a PK table. For `Deduplicate`, splits containing level-0 files go + /// through `KeyValueFileReader` (sort-merge dedup) while compacted-only + /// splits short-cut through `DataFileReader`. `PartialUpdate` and + /// `Aggregation` always go through `KeyValueFileReader` because their merge + /// semantics require per-key materialization even for compacted runs. fn read_pk( &self, data_splits: &[DataSplit], core_options: &CoreOptions, ) -> crate::Result { - if core_options.merge_engine()? == MergeEngine::PartialUpdate { + if matches!( + core_options.merge_engine()?, + MergeEngine::PartialUpdate | MergeEngine::Aggregation + ) { return self.read_kv(data_splits, core_options); } diff --git a/crates/paimon/src/table/table_write.rs b/crates/paimon/src/table/table_write.rs index 49c00198..ee1a98a9 100644 --- a/crates/paimon/src/table/table_write.rs +++ b/crates/paimon/src/table/table_write.rs @@ -233,6 +233,14 @@ impl TableWrite { }); } + if is_dynamic_cross_partition && merge_engine == MergeEngine::Aggregation { + return Err(crate::Error::Unsupported { + message: + "merge-engine=aggregation with cross-partition update is not supported yet" + .to_string(), + }); + } + if has_primary_keys && core_options.rowkind_field().is_some() { return Err(crate::Error::Unsupported { message: "KeyValueFileWriter does not support rowkind.field".to_string(), @@ -1936,6 +1944,40 @@ mod tests { )); } + #[test] + fn test_rejects_cross_partition_aggregation() { + let file_io = test_file_io(); + let table_path = "memory:/test_cross_aggregation"; + let schema = Schema::builder() + .column("pt", DataType::VarChar(VarCharType::string_type())) + .column("id", DataType::Int(IntType::new())) + .column("value", DataType::Int(IntType::new())) + .primary_key(["id"]) + .partition_keys(["pt"]) + .option("merge-engine", "aggregation") + .build() + .unwrap(); + let table = Table::new( + file_io, + Identifier::new("default", "test_cross_aggregation"), + table_path.to_string(), + TableSchema::new(0, &schema), + None, + ); + + let err = match TableWrite::new(&table, "test-user".to_string()) { + Ok(_) => panic!("cross-partition aggregation should be rejected"), + Err(err) => err, + }; + + assert!(matches!( + err, + crate::Error::Unsupported { message } + if message.contains("merge-engine=aggregation") + && message.contains("cross-partition update") + )); + } + #[tokio::test] async fn test_cross_partition_write_same_partition() { let file_io = test_file_io();