Skip to content

Commit 26de539

Browse files
authored
Support Sum/Avg in Android SDK (#4735)
* Tests if existing count test passes * add integration tests * Add test coverage * Clean up code * Address feedback * Address feedback2
1 parent eaefea8 commit 26de539

File tree

12 files changed

+1618
-55
lines changed

12 files changed

+1618
-55
lines changed

firebase-firestore/src/androidTest/java/com/google/firebase/firestore/AggregationTest.java

Lines changed: 1231 additions & 0 deletions
Large diffs are not rendered by default.

firebase-firestore/src/androidTest/java/com/google/firebase/firestore/CountTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ public void testTerminateDoesNotCrashWithFlyingCountQuery() {
135135
}
136136

137137
@Test
138-
public void testSnapshotEquals() {
138+
public void testCountSnapshotEquals() {
139139
CollectionReference collection =
140140
testCollectionWithDocs(
141141
map(
@@ -167,7 +167,7 @@ public void testSnapshotEquals() {
167167
}
168168

169169
@Test
170-
public void testCanRunCollectionGroupQuery() {
170+
public void testCanRunCountCollectionGroupQuery() {
171171
FirebaseFirestore db = testFirestore();
172172
// Use .document() to get a random collection group name to use but ensure it starts with 'b'
173173
// for predictable ordering.
@@ -201,7 +201,7 @@ public void testCanRunCollectionGroupQuery() {
201201
}
202202

203203
@Test
204-
public void testCanRunCountWithFiltersAndLimits() {
204+
public void testCanRunCountAggregateWithFiltersAndLimits() {
205205
CollectionReference collection =
206206
testCollectionWithDocs(
207207
map(
@@ -241,7 +241,7 @@ public void testCanRunCountOnNonExistentCollection() {
241241
}
242242

243243
@Test
244-
public void testFailWithoutNetwork() {
244+
public void testCountFailWithoutNetwork() {
245245
CollectionReference collection =
246246
testCollectionWithDocs(
247247
map(
@@ -261,7 +261,7 @@ public void testFailWithoutNetwork() {
261261
}
262262

263263
@Test
264-
public void testFailWithGoodMessageIfMissingIndex() {
264+
public void testCountFailWithGoodMessageIfMissingIndex() {
265265
assumeFalse(
266266
"Skip this test when running against the Firestore emulator because the Firestore emulator "
267267
+ "does not use indexes and never fails with a 'missing index' error",
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Copyright 2023 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package com.google.firebase.firestore;
16+
17+
import androidx.annotation.NonNull;
18+
import androidx.annotation.Nullable;
19+
import androidx.annotation.RestrictTo;
20+
import java.util.Objects;
21+
22+
// TODO(sumavg): Remove the `hide` and scope annotations.
23+
/** @hide */
24+
@RestrictTo(RestrictTo.Scope.LIBRARY)
25+
public abstract class AggregateField {
26+
@Nullable private final FieldPath fieldPath;
27+
28+
@NonNull private final String operator;
29+
30+
@NonNull private final String alias;
31+
32+
private AggregateField(@Nullable FieldPath fieldPath, @NonNull String operator) {
33+
this.fieldPath = fieldPath;
34+
this.operator = operator;
35+
36+
// Use $operator_$field format if it's an aggregation of a specific field. For example: sum_foo.
37+
// Use $operator format if there's no field. For example: count.
38+
this.alias = operator + (fieldPath == null ? "" : "_" + fieldPath);
39+
}
40+
41+
/**
42+
* Returns the field on which the aggregation takes place. Returns an empty string if there's no
43+
* field (e.g. for count).
44+
*/
45+
@RestrictTo(RestrictTo.Scope.LIBRARY)
46+
@NonNull
47+
public String getFieldPath() {
48+
return fieldPath == null ? "" : fieldPath.toString();
49+
}
50+
51+
/** Returns the alias used internally for this aggregate field. */
52+
@RestrictTo(RestrictTo.Scope.LIBRARY)
53+
@NonNull
54+
public String getAlias() {
55+
return alias;
56+
}
57+
58+
/** Returns a string representation of this aggregation's operator. For example: "sum" */
59+
@RestrictTo(RestrictTo.Scope.LIBRARY)
60+
@NonNull
61+
public String getOperator() {
62+
return operator;
63+
}
64+
65+
/**
66+
* Returns true if the given object is equal to this object. Two `AggregateField` objects are
67+
* considered equal if they have the same operator and operate on the same field.
68+
*/
69+
@Override
70+
public boolean equals(Object other) {
71+
if (this == other) {
72+
return true;
73+
}
74+
if (!(other instanceof AggregateField)) {
75+
return false;
76+
}
77+
78+
AggregateField otherAggregateField = (AggregateField) other;
79+
if (fieldPath == null || otherAggregateField.fieldPath == null) {
80+
return fieldPath == null && otherAggregateField.fieldPath == null;
81+
}
82+
return operator.equals(otherAggregateField.getOperator())
83+
&& getFieldPath().equals(otherAggregateField.getFieldPath());
84+
}
85+
86+
/** Calculates and returns the hash code for this object. */
87+
@Override
88+
public int hashCode() {
89+
return Objects.hash(getOperator(), getFieldPath());
90+
}
91+
92+
@NonNull
93+
public static CountAggregateField count() {
94+
return new CountAggregateField();
95+
}
96+
97+
@NonNull
98+
public static SumAggregateField sum(@NonNull String field) {
99+
return new SumAggregateField(FieldPath.fromDotSeparatedPath(field));
100+
}
101+
102+
@NonNull
103+
public static SumAggregateField sum(@NonNull FieldPath fieldPath) {
104+
return new SumAggregateField(fieldPath);
105+
}
106+
107+
@NonNull
108+
public static AverageAggregateField average(@NonNull String field) {
109+
return new AverageAggregateField(FieldPath.fromDotSeparatedPath(field));
110+
}
111+
112+
@NonNull
113+
public static AverageAggregateField average(@NonNull FieldPath fieldPath) {
114+
return new AverageAggregateField(fieldPath);
115+
}
116+
117+
public static class CountAggregateField extends AggregateField {
118+
private CountAggregateField() {
119+
super(null, "count");
120+
}
121+
}
122+
123+
public static class SumAggregateField extends AggregateField {
124+
private SumAggregateField(@NonNull FieldPath fieldPath) {
125+
super(fieldPath, "sum");
126+
}
127+
}
128+
129+
public static class AverageAggregateField extends AggregateField {
130+
private AverageAggregateField(@NonNull FieldPath fieldPath) {
131+
super(fieldPath, "average");
132+
}
133+
}
134+
}

firebase-firestore/src/main/java/com/google/firebase/firestore/AggregateQuery.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
package com.google.firebase.firestore;
1616

1717
import androidx.annotation.NonNull;
18+
import androidx.annotation.RestrictTo;
1819
import com.google.android.gms.tasks.Task;
1920
import com.google.android.gms.tasks.TaskCompletionSource;
2021
import com.google.firebase.firestore.util.Executors;
2122
import com.google.firebase.firestore.util.Preconditions;
23+
import java.util.List;
24+
import java.util.Objects;
2225

2326
/**
2427
* A query that calculates aggregations over an underlying query.
@@ -29,10 +32,13 @@
2932
*/
3033
public class AggregateQuery {
3134

32-
private final Query query;
35+
@NonNull private final Query query;
3336

34-
AggregateQuery(@NonNull Query query) {
37+
@NonNull private final List<AggregateField> aggregateFieldList;
38+
39+
AggregateQuery(@NonNull Query query, @NonNull List<AggregateField> aggregateFieldList) {
3540
this.query = query;
41+
this.aggregateFieldList = aggregateFieldList;
3642
}
3743

3844
/** Returns the query whose aggregations will be calculated by this object. */
@@ -41,6 +47,15 @@ public Query getQuery() {
4147
return query;
4248
}
4349

50+
/** Returns the AggregateFields included inside this object. */
51+
// TODO(sumavg): Remove the `hide` and scope annotations.
52+
/** @hide */
53+
@RestrictTo(RestrictTo.Scope.LIBRARY)
54+
@NonNull
55+
public List<AggregateField> getAggregateFields() {
56+
return aggregateFieldList;
57+
}
58+
4459
/**
4560
* Executes this query.
4661
*
@@ -54,7 +69,7 @@ public Task<AggregateQuerySnapshot> get(@NonNull AggregateSource source) {
5469
query
5570
.firestore
5671
.getClient()
57-
.runCountQuery(query.query)
72+
.runAggregateQuery(query.query, aggregateFieldList)
5873
.continueWith(
5974
Executors.DIRECT_EXECUTOR,
6075
(task) -> {
@@ -90,7 +105,7 @@ public boolean equals(Object object) {
90105
if (this == object) return true;
91106
if (!(object instanceof AggregateQuery)) return false;
92107
AggregateQuery other = (AggregateQuery) object;
93-
return query.equals(other.query);
108+
return query.equals(other.query) && aggregateFieldList.equals(other.aggregateFieldList);
94109
}
95110

96111
/**
@@ -100,6 +115,6 @@ public boolean equals(Object object) {
100115
*/
101116
@Override
102117
public int hashCode() {
103-
return query.hashCode();
118+
return Objects.hash(query, aggregateFieldList);
104119
}
105120
}

0 commit comments

Comments
 (0)