Skip to content

Commit addbd2d

Browse files
ncordon and jan-elastic
authored and committed
Propagates filter() to aggregation functions' surrogates (elastic#134461)
--------- Co-authored-by: Jan Kuipers <[email protected]> Co-authored-by: Jan Kuipers <[email protected]>
1 parent 539987c commit addbd2d

File tree

11 files changed

+247
-10
lines changed

11 files changed

+247
-10
lines changed

docs/changelog/134461.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 134461
2+
summary: Propagates filter() to aggregation functions' surrogates
3+
area: Aggregations
4+
type: bug
5+
issues:
6+
- 134380

x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3145,3 +3145,183 @@ FROM employees
31453145
m:datetime | x:integer | d:boolean
31463146
1999-04-30T00:00:00.000Z | 2 | true
31473147
;
3148+
3149+
sumWithConditions
3150+
required_capability: stats_with_filtered_surrogate_fixed
3151+
required_capability: aggregate_metric_double_convert_to
3152+
3153+
FROM employees
3154+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(1)
3155+
| STATS sum1 = SUM(1),
3156+
sum2 = SUM(1) WHERE emp_no == 10080,
3157+
sum3 = SUM(1) WHERE emp_no < 10080,
3158+
sum4 = SUM(1) WHERE emp_no >= 10080,
3159+
sum5 = SUM(agg_metric),
3160+
sum6 = SUM(agg_metric) WHERE emp_no == 10080
3161+
;
3162+
3163+
sum1:long | sum2:long | sum3:long | sum4:long | sum5:double | sum6:double
3164+
100 | 1 | 79 | 21 | 100.0 | 1.0
3165+
;
3166+
3167+
weightedAvgWithConditions
3168+
required_capability: stats_with_filtered_surrogate_fixed
3169+
3170+
ROW x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
3171+
| MV_EXPAND x
3172+
| STATS w_avg1 = WEIGHTED_AVG(x, 1) WHERE x == 5,
3173+
w_avg2 = WEIGHTED_AVG(x, x) WHERE x == 5,
3174+
w_avg3 = WEIGHTED_AVG(x, 2) WHERE x <= 5,
3175+
w_avg4 = WEIGHTED_AVG(x, x) WHERE x > 5,
3176+
w_avg5 = WEIGHTED_AVG([1,2,3], 1),
3177+
w_avg6 = WEIGHTED_AVG([1,2,3], 1) WHERE x == 5
3178+
;
3179+
3180+
w_avg1:double | w_avg2:double | w_avg3:double | w_avg4:double | w_avg5:double | w_avg6:double
3181+
5.0 | 5.0 | 3.0 | 8.25 | 2.0 | 2.0
3182+
;
3183+
3184+
maxWithConditions
3185+
required_capability: stats_with_filtered_surrogate_fixed
3186+
required_capability: aggregate_metric_double_convert_to
3187+
3188+
ROW x = [1, 2, 3, 4, 5]
3189+
| MV_EXPAND x
3190+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3191+
| STATS max1 = MAX(agg_metric) WHERE x <= 3,
3192+
max2 = MAX(agg_metric),
3193+
max3 = MAX(x),
3194+
max4 = MAX(x) WHERE x > 3
3195+
;
3196+
3197+
max1:double | max2:double | max3:integer | max4:integer
3198+
3.0 | 5.0 | 5 | 5
3199+
;
3200+
3201+
minWithConditions
3202+
required_capability: stats_with_filtered_surrogate_fixed
3203+
required_capability: aggregate_metric_double_convert_to
3204+
3205+
ROW x = [1, 2, 3, 4, 5]
3206+
| MV_EXPAND x
3207+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3208+
| STATS min1 = MIN(agg_metric) WHERE x <= 3,
3209+
min2 = MIN(agg_metric),
3210+
min3 = MIN(x),
3211+
min4 = MIN(x) WHERE x > 3
3212+
;
3213+
3214+
min1:double | min2:double | min3:integer | min4:integer
3215+
1.0 | 1.0 | 1 | 4
3216+
;
3217+
3218+
countWithConditions
3219+
required_capability: stats_with_filtered_surrogate_fixed
3220+
required_capability: aggregate_metric_double_convert_to
3221+
3222+
ROW x = [1, 2, 3, 4, 5]
3223+
| MV_EXPAND x
3224+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3225+
| STATS count1 = COUNT(x) WHERE x >= 3,
3226+
count2 = COUNT(x),
3227+
count3 = COUNT(agg_metric),
3228+
count4 = COUNT(agg_metric) WHERE x >=3,
3229+
count5 = COUNT(4) WHERE x >= 3,
3230+
count6 = COUNT(*) WHERE x >= 3,
3231+
count7 = COUNT([1,2,3]) WHERE x >= 3,
3232+
count8 = COUNT([1,2,3])
3233+
;
3234+
3235+
count1:long | count2:long | count3:long | count4:long | count5:long | count6:long | count7:long | count8:long
3236+
3 | 5 | 5 | 3 | 3 | 3 | 9 | 15
3237+
;
3238+
3239+
countDistinctWithConditions
3240+
required_capability: stats_with_filtered_surrogate_fixed
3241+
3242+
ROW x = [1, 2, 3, 4, 5]
3243+
| MV_EXPAND x
3244+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3245+
| STATS count1 = COUNT_DISTINCT(x) WHERE x <= 3,
3246+
count2 = COUNT_DISTINCT(x),
3247+
count3 = COUNT_DISTINCT(1) WHERE x <= 3,
3248+
count4 = COUNT_DISTINCT(1)
3249+
;
3250+
3251+
count1:long | count2:long | count3:long | count4:long
3252+
3 | 5 | 1 | 1
3253+
;
3254+
3255+
avgWithConditions
3256+
required_capability: stats_with_filtered_surrogate_fixed
3257+
required_capability: aggregate_metric_double_convert_to
3258+
3259+
ROW x = [1, 2, 3, 4, 5]
3260+
| MV_EXPAND x
3261+
| EVAL agg_metric = TO_AGGREGATE_METRIC_DOUBLE(x)
3262+
| STATS avg1 = AVG(x) WHERE x <= 3,
3263+
avg2 = AVG(x),
3264+
avg3 = AVG(agg_metric) WHERE x <=3,
3265+
avg4 = AVG(agg_metric)
3266+
;
3267+
3268+
avg1:double | avg2:double | avg3:double | avg4:double
3269+
2.0 | 3.0 | 2.0 | 3.0
3270+
;
3271+
3272+
percentileWithConditions
3273+
required_capability: stats_with_filtered_surrogate_fixed
3274+
3275+
ROW x = [1, 2, 3, 4, 5]
3276+
| MV_EXPAND x
3277+
| STATS percentile1 = PERCENTILE(x, 50) WHERE x <= 3,
3278+
percentile2 = PERCENTILE(x, 50)
3279+
;
3280+
3281+
percentile1:double | percentile2:double
3282+
2.0 | 3.0
3283+
;
3284+
3285+
medianWithConditions
3286+
required_capability: stats_with_filtered_surrogate_fixed
3287+
3288+
ROW x = [1, 2, 3, 4, 5]
3289+
| MV_EXPAND x
3290+
| STATS median1 = MEDIAN(x) WHERE x <= 3,
3291+
median2 = MEDIAN(x),
3292+
median3 = MEDIAN([5,6,7,8,9]) WHERE x <= 3,
3293+
median4 = MEDIAN([5,6,7,8,9])
3294+
;
3295+
3296+
median1:double | median2:double | median3:double | median4:double
3297+
2.0 | 3.0 | 7.0 | 7.0
3298+
;
3299+
3300+
medianAbsoluteDeviationWithConditions
3301+
required_capability: stats_with_filtered_surrogate_fixed
3302+
3303+
ROW x = [1, 3, 4, 7, 11, 18]
3304+
| MV_EXPAND x
3305+
| STATS median_dev1 = MEDIAN_ABSOLUTE_DEVIATION(x) WHERE x <= 3,
3306+
median_dev2 = MEDIAN_ABSOLUTE_DEVIATION(x),
3307+
median_dev3 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25]) WHERE x <= 3,
3308+
median_dev4 = MEDIAN_ABSOLUTE_DEVIATION([3, 11, 14, 25])
3309+
;
3310+
3311+
median_dev1:double | median_dev2:double | median_dev3:double | median_dev4:double
3312+
1.0 | 3.5 | 5.5 | 5.5
3313+
;
3314+
3315+
topWithConditions
3316+
required_capability: stats_with_filtered_surrogate_fixed
3317+
3318+
FROM employees
3319+
| STATS min1 = TOP(emp_no, 1, "ASC") WHERE emp_no > 10010,
3320+
min2 = TOP(emp_no, 2, "ASC") WHERE emp_no > 10010,
3321+
max1 = TOP(emp_no, 1, "DESC") WHERE emp_no < 10080,
3322+
max2 = TOP(emp_no, 2, "DESC") WHERE emp_no < 10080
3323+
;
3324+
3325+
min1:integer | min2:integer | max1:integer | max2:integer
3326+
10011 | [10011, 10012] | 10079 | [10079, 10078]
3327+
;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1469,6 +1469,14 @@ public enum Cap {
14691469
*/
14701470
FN_PRESENT,
14711471

1472+
/**
1473+
* Bugfix for STATS {{expression}} WHERE {{condition}} when the
1474+
* expression is replaced by a surrogate expression during planning,
1475+
* e.g. STATS SUM(1) WHERE x == 3 is replaced by
1476+
* STATS MV_SUM(1) * COUNT(*) WHERE x == 3.
1477+
*/
1478+
STATS_WITH_FILTERED_SURROGATE_FIXED,
1479+
14721480
/**
14731481
* TO_DENSE_VECTOR function.
14741482
*/

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Count.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ public Expression surrogate() {
152152
var s = source();
153153
var field = field();
154154
if (field.dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
155-
return new Sum(s, FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT));
155+
return new Sum(
156+
s,
157+
FromAggregateMetricDouble.withMetric(source(), field, AggregateMetricDoubleBlockBuilder.Metric.COUNT),
158+
filter()
159+
);
156160
}
157161

158162
if (field.foldable()) {
@@ -169,7 +173,7 @@ public Expression surrogate() {
169173
return new Mul(
170174
s,
171175
new Coalesce(s, new MvCount(s, field), List.of(new Literal(s, 0, DataType.INTEGER))),
172-
new Count(s, Literal.keyword(s, StringUtils.WILDCARD))
176+
new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())
173177
);
174178
}
175179

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,11 @@ public final AggregatorFunctionSupplier supplier() {
160160
@Override
161161
public Expression surrogate() {
162162
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
163-
return new Max(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX));
163+
return new Max(
164+
source(),
165+
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MAX),
166+
filter()
167+
);
164168
}
165169
return field().foldable() ? new MvMax(source(), field()) : null;
166170
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Median.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,6 @@ public Expression surrogate() {
117117

118118
return field.foldable()
119119
? new MvMedian(s, new ToDouble(s, field))
120-
: new Percentile(source(), field(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
120+
: new Percentile(source(), field(), filter(), new Literal(source(), (int) QuantileStates.MEDIAN, DataType.INTEGER));
121121
}
122122
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,11 @@ public final AggregatorFunctionSupplier supplier() {
160160
@Override
161161
public Expression surrogate() {
162162
if (field().dataType() == DataType.AGGREGATE_METRIC_DOUBLE) {
163-
return new Min(source(), FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN));
163+
return new Min(
164+
source(),
165+
FromAggregateMetricDouble.withMetric(source(), field(), AggregateMetricDoubleBlockBuilder.Metric.MIN),
166+
filter()
167+
);
164168
}
165169
return field().foldable() ? new MvMin(source(), field()) : null;
166170
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ public Sum(Source source, @Param(name = "number", type = { "aggregate_metric_dou
6666
this(source, field, Literal.TRUE, SummationMode.COMPENSATED_LITERAL);
6767
}
6868

69+
public Sum(
70+
Source source,
71+
@Param(name = "number", type = { "aggregate_metric_double", "double", "integer", "long" }) Expression field,
72+
Expression filter
73+
) {
74+
this(source, field, filter, SummationMode.COMPENSATED_LITERAL);
75+
}
76+
6977
public Sum(Source source, Expression field, Expression filter, Expression summationMode) {
7078
super(source, field, filter, List.of(summationMode));
7179
this.summationMode = summationMode;
@@ -163,6 +171,6 @@ public Expression surrogate() {
163171
}
164172

165173
// SUM(const) is equivalent to MV_SUM(const)*COUNT(*).
166-
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD))) : null;
174+
return field.foldable() ? new Mul(s, new MvSum(s, field), new Count(s, Literal.keyword(s, StringUtils.WILDCARD), filter())) : null;
167175
}
168176
}

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,9 @@ public Expression surrogate() {
289289
var s = source();
290290
if (orderField() instanceof Literal && limitField() instanceof Literal && limitValue() == 1) {
291291
if (orderValue()) {
292-
return new Min(s, field());
292+
return new Min(s, field(), filter());
293293
} else {
294-
return new Max(s, field());
294+
return new Max(s, field(), filter());
295295
}
296296
}
297297
return null;

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,9 @@ public Expression surrogate() {
160160
return new MvAvg(s, field);
161161
}
162162
if (weight.foldable()) {
163-
return new Div(s, new Sum(s, field), new Count(s, field), dataType());
163+
return new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType());
164164
} else {
165-
return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType());
165+
return new Div(s, new Sum(s, new Mul(s, field, weight), filter()), new Sum(s, weight, filter()), dataType());
166166
}
167167
}
168168

0 commit comments

Comments
 (0)