Skip to content

Commit 4925e6c

Browse files
authored
fix: dataframe function count_all with alias (#17282)
* fix: dataframe function count_all with alias
1 parent fa26515 commit 4925e6c

File tree

2 files changed

+54
-6
lines changed

2 files changed

+54
-6
lines changed

datafusion/core/src/physical_planner.rs

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1782,21 +1782,24 @@ pub fn create_aggregate_expr_and_maybe_filter(
17821782
physical_input_schema: &Schema,
17831783
execution_props: &ExecutionProps,
17841784
) -> Result<AggregateExprWithOptionalArgs> {
1785-
// unpack (nested) aliased logical expressions, e.g. "sum(col) as total"
1785+
// Unpack (potentially nested) aliased logical expressions, e.g. "sum(col) as total"
1786+
// Some functions like `count_all()` create internal aliases,
1787+
// Unwrap all alias layers to get to the underlying aggregate function
17861788
let (name, human_display, e) = match e {
1787-
Expr::Alias(Alias { expr, name, .. }) => {
1788-
(Some(name.clone()), String::default(), expr.as_ref())
1789+
Expr::Alias(Alias { name, .. }) => {
1790+
let unaliased = e.clone().unalias_nested().data;
1791+
(Some(name.clone()), e.human_display().to_string(), unaliased)
17891792
}
17901793
Expr::AggregateFunction(_) => (
17911794
Some(e.schema_name().to_string()),
17921795
e.human_display().to_string(),
1793-
e,
1796+
e.clone(),
17941797
),
1795-
_ => (None, String::default(), e),
1798+
_ => (None, String::default(), e.clone()),
17961799
};
17971800

17981801
create_aggregate_expr_with_name_and_maybe_filter(
1799-
e,
1802+
&e,
18001803
name,
18011804
human_display,
18021805
logical_input_schema,
@@ -2416,6 +2419,7 @@ mod tests {
24162419
use datafusion_expr::{
24172420
col, lit, LogicalPlanBuilder, Operator, UserDefinedLogicalNodeCore,
24182421
};
2422+
use datafusion_functions_aggregate::count::count_all;
24192423
use datafusion_functions_aggregate::expr_fn::sum;
24202424
use datafusion_physical_expr::expressions::{BinaryExpr, IsNotNullExpr};
24212425
use datafusion_physical_expr::EquivalenceProperties;
@@ -2876,6 +2880,25 @@ mod tests {
28762880
Ok(())
28772881
}
28782882

2883+
#[tokio::test]
2884+
async fn test_aggregate_count_all_with_alias() -> Result<()> {
2885+
let schema = Arc::new(Schema::new(vec![
2886+
Field::new("c1", DataType::Utf8, false),
2887+
Field::new("c2", DataType::UInt32, false),
2888+
]));
2889+
2890+
let logical_plan = scan_empty(None, schema.as_ref(), None)?
2891+
.aggregate(Vec::<Expr>::new(), vec![count_all().alias("total_rows")])?
2892+
.build()?;
2893+
2894+
let physical_plan = plan(&logical_plan).await?;
2895+
assert_eq!(
2896+
"total_rows",
2897+
physical_plan.schema().field(0).name().as_str()
2898+
);
2899+
Ok(())
2900+
}
2901+
28792902
#[tokio::test]
28802903
async fn test_explain() {
28812904
let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1316,3 +1316,28 @@ async fn test_count_wildcard() -> Result<()> {
13161316

13171317
Ok(())
13181318
}
1319+
1320+
/// Call count wildcard with alias from dataframe API
1321+
#[tokio::test]
1322+
async fn test_count_wildcard_with_alias() -> Result<()> {
1323+
let df = create_test_table().await?;
1324+
let result_df = df.aggregate(vec![], vec![count_all().alias("total_count")])?;
1325+
1326+
let schema = result_df.schema();
1327+
assert_eq!(schema.fields().len(), 1);
1328+
assert_eq!(schema.field(0).name(), "total_count");
1329+
assert_eq!(*schema.field(0).data_type(), DataType::Int64);
1330+
1331+
let batches = result_df.collect().await?;
1332+
assert_eq!(batches.len(), 1);
1333+
assert_eq!(batches[0].num_rows(), 1);
1334+
1335+
let count_array = batches[0]
1336+
.column(0)
1337+
.as_any()
1338+
.downcast_ref::<arrow::array::Int64Array>()
1339+
.unwrap();
1340+
assert_eq!(count_array.value(0), 4);
1341+
1342+
Ok(())
1343+
}

0 commit comments

Comments
 (0)