Skip to content

Commit 17c0c15

Browse files
committed
Support vectors in client side index.
1 parent 7ee1955 commit 17c0c15

File tree

5 files changed

+193
-2
lines changed

5 files changed

+193
-2
lines changed

packages/firestore/src/index/firestore_index_value_writer.ts

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717

1818
import { DocumentKey } from '../model/document_key';
19+
import { isVectorValue, VECTOR_MAP_VECTORS_KEY } from '../model/map_type';
1920
import {
2021
normalizeByteString,
2122
normalizeNumber,
@@ -41,6 +42,7 @@ const INDEX_TYPE_BLOB = 30;
4142
const INDEX_TYPE_REFERENCE = 37;
4243
const INDEX_TYPE_GEOPOINT = 45;
4344
const INDEX_TYPE_ARRAY = 50;
45+
const INDEX_TYPE_VECTOR = 53;
4446
const INDEX_TYPE_MAP = 55;
4547
const INDEX_TYPE_REFERENCE_SEGMENT = 60;
4648

@@ -121,6 +123,8 @@ export class FirestoreIndexValueWriter {
121123
} else if ('mapValue' in indexValue) {
122124
if (isMaxValue(indexValue)) {
123125
this.writeValueTypeLabel(encoder, Number.MAX_SAFE_INTEGER);
126+
} else if (isVectorValue(indexValue)) {
127+
this.writeIndexVector(indexValue.mapValue!, encoder);
124128
} else {
125129
this.writeIndexMap(indexValue.mapValue!, encoder);
126130
this.writeTruncationMarker(encoder);
@@ -160,6 +164,24 @@ export class FirestoreIndexValueWriter {
160164
}
161165
}
162166

167+
private writeIndexVector(
168+
mapIndexValue: MapValue,
169+
encoder: DirectionalIndexByteEncoder
170+
): void {
171+
const map = mapIndexValue.fields || {};
172+
this.writeValueTypeLabel(encoder, INDEX_TYPE_VECTOR);
173+
174+
// Vectors sort first by length
175+
const key = VECTOR_MAP_VECTORS_KEY;
176+
const length = map[key].arrayValue?.values?.length || 0;
177+
this.writeValueTypeLabel(encoder, INDEX_TYPE_NUMBER);
178+
encoder.writeNumber(normalizeNumber(length));
179+
180+
// Vectors then sort by position value
181+
this.writeIndexString(key, encoder);
182+
this.writeIndexValueAux(map[key], encoder);
183+
}
184+
163185
private writeIndexArray(
164186
arrayIndexValue: ArrayValue,
165187
encoder: DirectionalIndexByteEncoder

packages/firestore/test/integration/api/database.test.ts

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ import {
7676
withTestDoc,
7777
withTestDocAndInitialData,
7878
withNamedTestDbsOrSkipUnlessUsingEmulator,
79-
toDataArray
79+
toDataArray,
80+
checkOnlineAndOfflineResultsMatch
8081
} from '../util/helpers';
8182
import { DEFAULT_SETTINGS, DEFAULT_PROJECT_ID } from '../util/settings';
8283

@@ -743,6 +744,63 @@ apiDescribe('Database', persistence => {
743744
expect(toDataArray(watchSnapshot)).to.deep.equal(docsInOrder);
744745
});
745746
});
747+
748+
// eslint-disable-next-line no-restricted-properties
749+
(persistence.gc === 'lru' ? describe : describe.skip)('From Cache', () => {
750+
it('SDK orders vector field the same way online and offline', async () => {
751+
// Test data in the order that we expect the SDK to sort it.
752+
const docsInOrder = [
753+
{ embedding: [1, 2, 3, 4, 5, 6] },
754+
{ embedding: [100] },
755+
{ embedding: vector([Number.NEGATIVE_INFINITY]) },
756+
{ embedding: vector([-100]) },
757+
{ embedding: vector([100]) },
758+
{ embedding: vector([Number.POSITIVE_INFINITY]) },
759+
{ embedding: vector([1, 2]) },
760+
{ embedding: vector([2, 2]) },
761+
{ embedding: vector([1, 2, 3]) },
762+
{ embedding: vector([1, 2, 3, 4]) },
763+
{ embedding: vector([1, 2, 3, 4, 5]) },
764+
{ embedding: vector([1, 2, 100, 4, 4]) },
765+
{ embedding: vector([100, 2, 3, 4, 5]) },
766+
{ embedding: { HELLO: 'WORLD' } },
767+
{ embedding: { hello: 'world' } }
768+
];
769+
770+
const documentIds: string[] = [];
771+
const docs = docsInOrder.reduce((obj, doc, index) => {
772+
const documentId = index.toString();
773+
documentIds.push(documentId);
774+
obj[documentId] = doc;
775+
return obj;
776+
}, {} as { [i: string]: DocumentData });
777+
778+
return withTestCollection(persistence, docs, async randomCol => {
779+
const orderedQuery = query(randomCol, orderBy('embedding'));
780+
await checkOnlineAndOfflineResultsMatch(orderedQuery, ...documentIds);
781+
782+
const orderedQueryLessThan = query(
783+
randomCol,
784+
orderBy('embedding'),
785+
where('embedding', '<', vector([1, 2, 100, 4, 4]))
786+
);
787+
await checkOnlineAndOfflineResultsMatch(
788+
orderedQueryLessThan,
789+
...documentIds.slice(2, 11)
790+
);
791+
792+
const orderedQueryGreaterThan = query(
793+
randomCol,
794+
orderBy('embedding'),
795+
where('embedding', '>', vector([1, 2, 100, 4, 4]))
796+
);
797+
await checkOnlineAndOfflineResultsMatch(
798+
orderedQueryGreaterThan,
799+
...documentIds.slice(12, 13)
800+
);
801+
});
802+
});
803+
});
746804
});
747805

748806
describe('documents: ', () => {

packages/firestore/test/integration/util/helpers.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ function apiDescribeInternal(
182182
testSuite: (persistence: PersistenceMode) => void
183183
): void {
184184
const persistenceModes: PersistenceMode[] = [
185-
new MemoryEagerPersistenceMode()
185+
new MemoryEagerPersistenceMode(),
186+
new MemoryLruPersistenceMode()
186187
];
187188
if (isPersistenceAvailable()) {
188189
persistenceModes.push(new IndexedDbPersistenceMode());

packages/firestore/test/unit/index/firestore_index_value_writer.test.ts

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,5 +169,82 @@ describe('Firestore Index Value Writer', () => {
169169
compareIndexEncodedValues(value3, value4, IndexKind.DESCENDING)
170170
).to.equal(-1);
171171
});
172+
173+
it('sorts vector as a different type from array and map, with unique rules', () => {
174+
const vector1 = {
175+
mapValue: {
176+
fields: {
177+
'__type__': { stringValue: '__vector__' },
178+
'value': {
179+
arrayValue: { values: [{ doubleValue: 100 }] }
180+
}
181+
}
182+
}
183+
};
184+
const vector2 = {
185+
mapValue: {
186+
fields: {
187+
'__type__': { stringValue: '__vector__' },
188+
'value': {
189+
arrayValue: { values: [{ doubleValue: 1 }, { doubleValue: 2 }] }
190+
}
191+
}
192+
}
193+
};
194+
const vector3 = {
195+
mapValue: {
196+
fields: {
197+
'__type__': { stringValue: '__vector__' },
198+
'value': {
199+
arrayValue: { values: [{ doubleValue: 1 }, { doubleValue: 3 }] }
200+
}
201+
}
202+
}
203+
};
204+
const map1 = {
205+
mapValue: {
206+
fields: {
207+
'value': {
208+
arrayValue: { values: [{ doubleValue: 1 }, { doubleValue: 2 }] }
209+
}
210+
}
211+
}
212+
};
213+
const array1 = {
214+
arrayValue: { values: [{ doubleValue: 1 }, { doubleValue: 2 }] }
215+
};
216+
217+
// Array sorts before vector
218+
expect(
219+
compareIndexEncodedValues(array1, vector1, IndexKind.ASCENDING)
220+
).to.equal(-1);
221+
expect(
222+
compareIndexEncodedValues(array1, vector1, IndexKind.DESCENDING)
223+
).to.equal(1);
224+
225+
// Vector sorts before map
226+
expect(
227+
compareIndexEncodedValues(vector3, map1, IndexKind.ASCENDING)
228+
).to.equal(-1);
229+
expect(
230+
compareIndexEncodedValues(vector3, map1, IndexKind.DESCENDING)
231+
).to.equal(1);
232+
233+
// Shorter vectors sort before longer vectors
234+
expect(
235+
compareIndexEncodedValues(vector1, vector2, IndexKind.ASCENDING)
236+
).to.equal(-1);
237+
expect(
238+
compareIndexEncodedValues(vector1, vector2, IndexKind.DESCENDING)
239+
).to.equal(1);
240+
241+
// Vectors of the same length sort by value
242+
expect(
243+
compareIndexEncodedValues(vector2, vector3, IndexKind.ASCENDING)
244+
).to.equal(-1);
245+
expect(
246+
compareIndexEncodedValues(vector2, vector3, IndexKind.DESCENDING)
247+
).to.equal(1);
248+
});
172249
});
173250
});

packages/firestore/test/unit/local/index_manager.test.ts

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import {
3030
queryWithLimit,
3131
queryWithStartAt
3232
} from '../../../src/core/query';
33+
import { vector } from '../../../src/lite-api/field_value_impl';
3334
import { Timestamp } from '../../../src/lite-api/timestamp';
3435
import {
3536
displayNameForIndexType,
@@ -1100,6 +1101,38 @@ describe('IndexedDbIndexManager', async () => {
11001101
await verifyResults(q, 'coll/val6', 'coll/val3', 'coll/val4', 'coll/val5');
11011102
});
11021103

1104+
it('can index VectorValue fields', async () => {
1105+
await indexManager.addFieldIndex(
1106+
fieldIndex('coll', { fields: [['embedding', IndexKind.ASCENDING]] })
1107+
);
1108+
1109+
await addDoc('coll/arr1', { 'embedding': [1, 2, 3] });
1110+
await addDoc('coll/map2', { 'embedding': {} });
1111+
await addDoc('coll/doc3', { 'embedding': vector([4, 5, 6]) });
1112+
await addDoc('coll/doc4', { 'embedding': vector([5]) });
1113+
1114+
let q = queryWithAddedOrderBy(query('coll'), orderBy('embedding'));
1115+
await verifyResults(q, 'coll/arr1', 'coll/doc4', 'coll/doc3', 'coll/map2');
1116+
1117+
q = queryWithAddedFilter(
1118+
query('coll'),
1119+
filter('embedding', '==', vector([4, 5, 6]))
1120+
);
1121+
await verifyResults(q, 'coll/doc3');
1122+
1123+
q = queryWithAddedFilter(
1124+
query('coll'),
1125+
filter('embedding', '>', vector([4, 5, 6]))
1126+
);
1127+
await verifyResults(q, 'coll/map2');
1128+
1129+
q = queryWithAddedFilter(
1130+
query('coll'),
1131+
filter('embedding', '>=', vector([4]))
1132+
);
1133+
await verifyResults(q, 'coll/doc4', 'coll/doc3', 'coll/map2');
1134+
});
1135+
11031136
it('support advances queries', async () => {
11041137
// This test compares local query results with those received from the Java
11051138
// Server SDK.

0 commit comments

Comments
 (0)