Skip to content

Commit a510fd6

Browse files
Still open: binaryvector failures in integration tests
1 parent 4ac2283 commit a510fd6

File tree

5 files changed

+107
-64
lines changed

5 files changed

+107
-64
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MappingMongoJsonSchemaCreator.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,10 @@ private JsonSchemaProperty computeSchemaForProperty(List<MongoPersistentProperty
185185
Class<?> rawTargetType = computeTargetType(property); // target type before conversion
186186
Class<?> targetType = converter.getTypeMapper().getWriteTargetTypeFor(rawTargetType); // conversion target type
187187

188+
if((rawTargetType.isPrimitive() || ClassUtils.isPrimitiveArray(rawTargetType)) && targetType == Object.class) {
189+
targetType = rawTargetType;
190+
}
191+
188192
if (!isCollection(property) && ObjectUtils.nullSafeEquals(rawTargetType, targetType)) {
189193
if (property.isEntity() || mergeProperties.containsKey(stringPath)) {
190194
List<JsonSchemaProperty> targetProperties = new ArrayList<>();

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/MongoConverters.java

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import java.util.Collection;
2828
import java.util.Currency;
2929
import java.util.List;
30-
import java.util.Set;
3130
import java.util.UUID;
3231
import java.util.concurrent.atomic.AtomicInteger;
3332
import java.util.concurrent.atomic.AtomicLong;
@@ -51,7 +50,6 @@
5150
import org.springframework.core.convert.ConversionFailedException;
5251
import org.springframework.core.convert.TypeDescriptor;
5352
import org.springframework.core.convert.converter.ConditionalConverter;
54-
import org.springframework.core.convert.converter.ConditionalGenericConverter;
5553
import org.springframework.core.convert.converter.Converter;
5654
import org.springframework.core.convert.converter.ConverterFactory;
5755
import org.springframework.data.convert.ReadingConverter;
@@ -61,7 +59,6 @@
6159
import org.springframework.data.mongodb.core.mapping.MongoVector;
6260
import org.springframework.data.mongodb.core.query.Term;
6361
import org.springframework.data.mongodb.core.script.NamedMongoScript;
64-
import org.springframework.lang.Nullable;
6562
import org.springframework.util.Assert;
6663
import org.springframework.util.NumberUtils;
6764
import org.springframework.util.StringUtils;
@@ -121,6 +118,8 @@ static Collection<Object> getConvertersToRegister() {
121118
converters.add(reading(BsonUndefined.class, Object.class, it -> null));
122119
converters.add(reading(String.class, URI.class, URI::create).andWriting(URI::toString));
123120

121+
converters.add(ByteArrayConverterFactory.INSTANCE);
122+
124123
return converters;
125124
}
126125

@@ -475,35 +474,47 @@ public Vector convert(BinaryVector source) {
475474
}
476475
}
477476

478-
// @WritingConverter
479-
// enum BytesToBinaryVectorConverter implements ConditionalGenericConverter {
480-
// INSTANCE;
481-
//
482-
// @Nullable
483-
// public BinaryVector convert(byte[] source) {
484-
// return BinaryVector.int8Vector(source);
485-
// }
486-
//
487-
// @Override
488-
// public boolean matches(TypeDescriptor sourceType, TypeDescriptor targetType) {
489-
// return sourceType.getType() == byte[].class && targetType.getType() == BinaryVector.class;
490-
// }
491-
//
492-
// @Nullable
493-
// @Override
494-
// public Set<ConvertiblePair> getConvertibleTypes() {
495-
// return Set.of(new ConvertiblePair(byte[].class, BinaryVector.class));
496-
// }
497-
//
498-
// @Nullable
499-
// @Override
500-
// public Object convert(@Nullable Object source, TypeDescriptor sourceType, TypeDescriptor targetType) {
501-
// if(!matches(sourceType, targetType)) {
502-
// return source;
503-
// }
504-
// return convert((byte[]) source);
505-
// }
506-
// }
477+
@WritingConverter
478+
enum ByteArrayConverterFactory implements ConverterFactory<byte[], Object>, ConditionalConverter {
479+
480+
INSTANCE;
481+
482+
@Override
483+
public <T> Converter<byte[], T> getConverter(Class<T> targetType) {
484+
return new ByteArrayConverter<>(targetType);
485+
}
486+
487+
@Override
488+
public boolean matches(TypeDescriptor sourceType, TypeDescriptor targetType) {
489+
return targetType.getType() != Object.class && !sourceType.equals(targetType);
490+
}
491+
492+
private final static class ByteArrayConverter<T> implements Converter<byte[], T> {
493+
494+
private final Class<T> targetType;
495+
496+
/**
497+
* Creates a new {@link ByteArrayConverter} for the given target type.
498+
*
499+
* @param targetType must not be {@literal null}.
500+
*/
501+
public ByteArrayConverter(Class<T> targetType) {
502+
503+
Assert.notNull(targetType, "Target type must not be null");
504+
505+
this.targetType = targetType;
506+
}
507+
508+
@Override
509+
public T convert(byte[] source) {
510+
511+
if (this.targetType == BinaryVector.class) {
512+
return (T) BinaryVector.int8Vector(source);
513+
}
514+
return (T) source;
515+
}
516+
}
517+
}
507518

508519
/**
509520
* {@link ConverterFactory} implementation converting {@link AtomicLong} into {@link Long}.

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchTests.java

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,14 @@
4646
/**
4747
* @author Christoph Strobl
4848
*/
49-
@Testcontainers
49+
@Testcontainers(disabledWithoutDocker = true)
5050
public class VectorSearchTests {
5151

52-
static final String COLLECTION_NAME = "movies";
5352
public static final String SCORE_FIELD = "vector-search-tests";
54-
55-
private static @Container AtlasContainer atlasLocal = new AtlasContainer();
53+
static final String COLLECTION_NAME = "collection-1";
5654
static MongoClient client;
5755
static MongoTestTemplate template;
56+
private static @Container AtlasContainer atlasLocal = new AtlasContainer();
5857

5958
@BeforeAll
6059
static void beforeAll() throws InterruptedException {
@@ -73,21 +72,6 @@ static void afterAll() {
7372
template.dropCollection(WithVectorFields.class);
7473
}
7574

76-
@ParameterizedTest // GH-4706
77-
@MethodSource("aggregations")
78-
void searchWithScoreTests(VectorSearchOperation searchOperation) {
79-
80-
VectorSearchOperation $search = searchOperation.withSearchScore(SCORE_FIELD);
81-
82-
AggregationResults<Document> results = template.aggregate(Aggregation.newAggregation($search),
83-
WithVectorFields.class, Document.class);
84-
85-
assertThat(results).hasSize(10);
86-
assertScoreIsDecreasing(results);
87-
assertThat(results.iterator().next()).containsKey(SCORE_FIELD)
88-
.extracting(it -> it.get(SCORE_FIELD), InstanceOfAssertFactories.DOUBLE).isEqualByComparingTo(1D);
89-
}
90-
9175
private static Stream<Arguments> aggregations() {
9276

9377
return Stream.of(//
@@ -105,6 +89,21 @@ private static Stream<Arguments> aggregations() {
10589
.vector(new double[] { 1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d }) //
10690
.limit(10)//
10791
.numCandidates(20) //
92+
.searchType(SearchType.ANN)),
93+
Arguments.arguments(VectorSearchOperation.search("wrapper-index").path("int8vector") //
94+
.vector(BinaryVector.int8Vector(new byte[] { 0, 1, 2, 3, 4 })) //
95+
.limit(10)//
96+
.numCandidates(20) //
97+
.searchType(SearchType.ANN)),
98+
Arguments.arguments(VectorSearchOperation.search("wrapper-index").path("float32vector") //
99+
.vector(BinaryVector.floatVector(new float[] { 0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f })) //
100+
.limit(10)//
101+
.numCandidates(20) //
102+
.searchType(SearchType.ANN)),
103+
Arguments.arguments(VectorSearchOperation.search("wrapper-index").path("float64vector") //
104+
.vector(Vector.of(1.0001d, 2.12345d, 3.23456d, 4.34567d, 5.45678d)) //
105+
.limit(10)//
106+
.numCandidates(20) //
108107
.searchType(SearchType.ANN)));
109108
}
110109

@@ -145,6 +144,21 @@ private static void assertScoreIsDecreasing(final Iterable<Document> documents)
145144
}
146145
}
147146

147+
@ParameterizedTest // GH-4706
148+
@MethodSource("aggregations")
149+
void searchWithScoreTests(VectorSearchOperation searchOperation) {
150+
151+
VectorSearchOperation $search = searchOperation.withSearchScore(SCORE_FIELD);
152+
153+
AggregationResults<Document> results = template.aggregate(Aggregation.newAggregation($search),
154+
WithVectorFields.class, Document.class);
155+
156+
assertThat(results).hasSize(10);
157+
assertScoreIsDecreasing(results);
158+
assertThat(results.iterator().next()).containsKey(SCORE_FIELD)
159+
.extracting(it -> it.get(SCORE_FIELD), InstanceOfAssertFactories.DOUBLE).isEqualByComparingTo(1D);
160+
}
161+
148162
@org.springframework.data.mongodb.core.mapping.Document(COLLECTION_NAME)
149163
static class WithVectorFields {
150164

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2260,18 +2260,6 @@ void mapsDateAsObjectIdWhenAnnotatedWithFieldTargetType() {
22602260
.isEqualTo(new ObjectId(source.dateAsObjectId).getTimestamp());
22612261
}
22622262

2263-
@Test // DATAMONGO-2328
2264-
void mapsByteArrayAsVectorWhenAnnotatedWithFieldTargetType() {
2265-
2266-
WithExplicitTargetTypes source = new WithExplicitTargetTypes();
2267-
source.asVector = new byte[] { 0, 1, 2 };
2268-
2269-
org.bson.Document target = new org.bson.Document();
2270-
converter.write(source, target);
2271-
2272-
assertThatNoException().isThrownBy(() -> target.get("asVector", BinaryVector.class));
2273-
}
2274-
22752263
@Test // DATAMONGO-2328
22762264
void readsObjectIdAsDateWhenAnnotatedWithFieldTargetType() {
22772265

@@ -3392,6 +3380,31 @@ void shouldReadVectorValues() {
33923380
assertThat(withVector.embeddings.toDoubleArray()).contains(1.1d, 2.2d, 3.3d);
33933381
}
33943382

3383+
@Test // GH-4706
3384+
void mapsByteArrayAsVectorWhenAnnotatedWithFieldTargetType() {
3385+
3386+
WithExplicitTargetTypes source = new WithExplicitTargetTypes();
3387+
source.asVector = new byte[] { 0, 1, 2 };
3388+
3389+
org.bson.Document target = new org.bson.Document();
3390+
converter.write(source, target);
3391+
3392+
assertThatNoException().isThrownBy(() -> target.get("asVector", BinaryVector.class));
3393+
}
3394+
3395+
@Test // GH-4706
3396+
void writesByteArrayAsIsIfNoFieldInstructionsGiven() {
3397+
3398+
WithArrays source = new WithArrays();
3399+
source.arrayOfPrimitiveBytes = new byte[] { 0, 1, 2 };
3400+
3401+
org.bson.Document target = new org.bson.Document();
3402+
converter.write(source, target);
3403+
3404+
assertThat(target.get("arrayOfPrimitiveBytes", byte[].class)).isSameAs(source.arrayOfPrimitiveBytes);
3405+
3406+
}
3407+
33953408
org.bson.Document write(Object source) {
33963409

33973410
org.bson.Document target = new org.bson.Document();
@@ -3935,6 +3948,7 @@ public WithArrayInConstructor(String[] array) {
39353948

39363949
static class WithArrays {
39373950
String[] arrayOfStrings;
3951+
byte[] arrayOfPrimitiveBytes;
39383952
}
39393953

39403954
// DATAMONGO-1898

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/index/VectorIndexIntegrationTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
* @author Christoph Strobl
4747
* @author Mark Paluch
4848
*/
49-
@Testcontainers
49+
@Testcontainers(disabledWithoutDocker = true)
5050
class VectorIndexIntegrationTests {
5151

5252
private static @Container AtlasContainer atlasLocal = new AtlasContainer();

0 commit comments

Comments
 (0)