Skip to content

Support Sum/Avg in Android SDK #4735

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ public void testTerminateDoesNotCrashWithFlyingCountQuery() {
}

@Test
public void testSnapshotEquals() {
public void testCountSnapshotEquals() {
CollectionReference collection =
testCollectionWithDocs(
map(
Expand Down Expand Up @@ -167,7 +167,7 @@ public void testSnapshotEquals() {
}

@Test
public void testCanRunCollectionGroupQuery() {
public void testCanRunCountCollectionGroupQuery() {
FirebaseFirestore db = testFirestore();
// Use .document() to get a random collection group name to use but ensure it starts with 'b'
// for predictable ordering.
Expand Down Expand Up @@ -201,7 +201,7 @@ public void testCanRunCollectionGroupQuery() {
}

@Test
public void testCanRunCountWithFiltersAndLimits() {
public void testCanRunCountAggregateWithFiltersAndLimits() {
CollectionReference collection =
testCollectionWithDocs(
map(
Expand Down Expand Up @@ -241,7 +241,7 @@ public void testCanRunCountOnNonExistentCollection() {
}

@Test
public void testFailWithoutNetwork() {
public void testCountFailWithoutNetwork() {
CollectionReference collection =
testCollectionWithDocs(
map(
Expand All @@ -261,7 +261,7 @@ public void testFailWithoutNetwork() {
}

@Test
public void testFailWithGoodMessageIfMissingIndex() {
public void testCountFailWithGoodMessageIfMissingIndex() {
assumeFalse(
"Skip this test when running against the Firestore emulator because the Firestore emulator "
+ "does not use indexes and never fails with a 'missing index' error",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.firebase.firestore;

import androidx.annotation.NonNull;
import androidx.annotation.Nullable;
import androidx.annotation.RestrictTo;
import java.util.Objects;

// TODO(sumavg): Remove the `hide` and scope annotations.
/** @hide */
@RestrictTo(RestrictTo.Scope.LIBRARY)
public abstract class AggregateField {
@Nullable private final FieldPath fieldPath;

@NonNull private final String operator;

@NonNull private final String alias;

private AggregateField(@Nullable FieldPath fieldPath, @NonNull String operator) {
this.fieldPath = fieldPath;
this.operator = operator;

// Use $operator_$field format if it's an aggregation of a specific field. For example: sum_foo.
// Use $operator format if there's no field. For example: count.
this.alias = operator + (fieldPath == null ? "" : "_" + fieldPath);
}

/**
* Returns the field on which the aggregation takes place. Returns an empty string if there's no
* field (e.g. for count).
*/
@RestrictTo(RestrictTo.Scope.LIBRARY)
@NonNull
public String getFieldPath() {
return fieldPath == null ? "" : fieldPath.toString();
}

/** Returns the alias used internally for this aggregate field. */
@RestrictTo(RestrictTo.Scope.LIBRARY)
@NonNull
public String getAlias() {
return alias;
}

/** Returns a string representation of this aggregation's operator. For example: "sum" */
@RestrictTo(RestrictTo.Scope.LIBRARY)
@NonNull
public String getOperator() {
return operator;
}

/**
* Returns true if the given object is equal to this object. Two `AggregateField` objects are
* considered equal if they have the same operator and operate on the same field.
*/
@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof AggregateField)) {
return false;
}

AggregateField otherAggregateField = (AggregateField) other;
if (fieldPath == null || otherAggregateField.fieldPath == null) {
return fieldPath == null && otherAggregateField.fieldPath == null;
}
return operator.equals(otherAggregateField.getOperator())
&& getFieldPath().equals(otherAggregateField.getFieldPath());
}

/** Calculates and returns the hash code for this object. */
@Override
public int hashCode() {
return Objects.hash(getOperator(), getFieldPath());
}

@NonNull
public static CountAggregateField count() {
return new CountAggregateField();
}

@NonNull
public static SumAggregateField sum(@NonNull String field) {
return new SumAggregateField(FieldPath.fromDotSeparatedPath(field));
}

@NonNull
public static SumAggregateField sum(@NonNull FieldPath fieldPath) {
return new SumAggregateField(fieldPath);
}

@NonNull
public static AverageAggregateField average(@NonNull String field) {
return new AverageAggregateField(FieldPath.fromDotSeparatedPath(field));
}

@NonNull
public static AverageAggregateField average(@NonNull FieldPath fieldPath) {
return new AverageAggregateField(fieldPath);
}

public static class CountAggregateField extends AggregateField {
private CountAggregateField() {
super(null, "count");
}
}

public static class SumAggregateField extends AggregateField {
private SumAggregateField(@NonNull FieldPath fieldPath) {
super(fieldPath, "sum");
}
}

public static class AverageAggregateField extends AggregateField {
private AverageAggregateField(@NonNull FieldPath fieldPath) {
super(fieldPath, "average");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
package com.google.firebase.firestore;

import androidx.annotation.NonNull;
import androidx.annotation.RestrictTo;
import com.google.android.gms.tasks.Task;
import com.google.android.gms.tasks.TaskCompletionSource;
import com.google.firebase.firestore.util.Executors;
import com.google.firebase.firestore.util.Preconditions;
import java.util.List;
import java.util.Objects;

/**
* A query that calculates aggregations over an underlying query.
Expand All @@ -29,10 +32,13 @@
*/
public class AggregateQuery {

private final Query query;
@NonNull private final Query query;

AggregateQuery(@NonNull Query query) {
@NonNull private final List<AggregateField> aggregateFieldList;

AggregateQuery(@NonNull Query query, @NonNull List<AggregateField> aggregateFieldList) {
this.query = query;
this.aggregateFieldList = aggregateFieldList;
}

/** Returns the query whose aggregations will be calculated by this object. */
Expand All @@ -41,6 +47,15 @@ public Query getQuery() {
return query;
}

/** Returns the AggregateFields included inside this object. */
// TODO(sumavg): Remove the `hide` and scope annotations.
/** @hide */
@RestrictTo(RestrictTo.Scope.LIBRARY)
@NonNull
public List<AggregateField> getAggregateFields() {
return aggregateFieldList;
}

/**
* Executes this query.
*
Expand All @@ -54,7 +69,7 @@ public Task<AggregateQuerySnapshot> get(@NonNull AggregateSource source) {
query
.firestore
.getClient()
.runCountQuery(query.query)
.runAggregateQuery(query.query, aggregateFieldList)
.continueWith(
Executors.DIRECT_EXECUTOR,
(task) -> {
Expand Down Expand Up @@ -90,7 +105,7 @@ public boolean equals(Object object) {
if (this == object) return true;
if (!(object instanceof AggregateQuery)) return false;
AggregateQuery other = (AggregateQuery) object;
return query.equals(other.query);
return query.equals(other.query) && aggregateFieldList.equals(other.aggregateFieldList);
}

/**
Expand All @@ -100,6 +115,6 @@ public boolean equals(Object object) {
*/
@Override
public int hashCode() {
return query.hashCode();
return Objects.hash(query, aggregateFieldList);
}
}
Loading