Expand Up
@@ -16,27 +16,20 @@
package com.mongodb.client.model.search;
import com.mongodb.MongoInterruptedException;
import com.mongodb.MongoNamespace;
import com.mongodb.client.model.Aggregates;
import com.mongodb.client.model.SearchIndexType;
import com.mongodb.client.test.CollectionHelper;
import com.mongodb.internal.operation.SearchIndexRequest;
import org.bson.BinaryVector;
import org.bson.BsonDocument;
import org.bson.Document;
import org.bson.codecs.DocumentCodec;
import org.bson.conversions.Bson;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Stream;
Expand All
@@ -57,7 +50,6 @@
import static com.mongodb.client.model.search.SearchPath.fieldPath;
import static com.mongodb.client.model.search.VectorSearchOptions.approximateVectorSearchOptions;
import static com.mongodb.client.model.search.VectorSearchOptions.exactVectorSearchOptions;
import static java.lang.String.format;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertAll;
Expand All
@@ -67,127 +59,73 @@
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.junit.jupiter.params.provider.Arguments.arguments;
class AggregatesBinaryVectorSearchIntegrationTest {
// Failure message template used by awaitIndexCreation() once its polling attempts run out.
// The single %s placeholder receives the JSON of the last index document seen, or "null".
private static final String EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE =
"Exceeded maximum attempts waiting for Search Index creation in Atlas cluster. Index document: %s";
/**
* This test runs on an Atlas QA cluster in the `javaExtraTests.binaryVectorTests` namespace,
* using a user with read-only permissions.
* <p>
* With the following index:
* <code>
* {
* "name": "vector_search_index", "type": "vectorSearch",
* "definition": {"fields": [
* {"path": "int8Vector", "numDimensions": 5, "similarity": "cosine", "type": "vector"},
* {"path": "float32Vector", "numDimensions": 5, "similarity": "cosine", "type": "vector"},
* {"path": "legacyDoubleVector", "numDimensions": 5, "similarity": "cosine", "type": "vector"},
* {"path": "year", "type": "filter"}]}
* }
* </code>
* <p>
* And the following test data:
* <code>
* [{"_id":0, "int8Vector":{"$binary":{"base64":"AwAAAQIDBA==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwAXt9E4Ns2PPwgDD0B1H1ZA8Z2OQA==", "subType":"09"}},
* "legacyDoubleVector":[0.0001,1.12345,2.23456,3.34567,4.45678], "year":2016},
* {"_id":1, "int8Vector":{"$binary":{"base64":"AwABAgMEBQ==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwBHA4A/m+YHQAgDT0C7D4tA8Z2uQA==", "subType":"09"}},
* "legacyDoubleVector":[1.0001,2.12345,3.23456,4.34567,5.45678], "year":2017},
* {"_id":2, "int8Vector":{"$binary":{"base64":"AwACAwQFBg==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwBHAwBAm+ZHQISBh0C7D6tA8Z3OQA==", "subType":"09"}},
* "legacyDoubleVector":[2.0002,3.12345,4.23456,5.34567,6.45678], "year":2018},
* {"_id":3, "int8Vector":{"$binary":{"base64":"AwADBAUGBw==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwDqBEBATfODQISBp0C7D8tA8Z3uQA==", "subType":"09"}},
* "legacyDoubleVector":[3.0003,4.12345,5.23456,6.34567,7.45678], "year":2019},
* {"_id":4, "int8Vector":{"$binary":{"base64":"AwAEBQYHCA==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwBHA4BATfOjQISBx0C7D+tA+U4HQQ==", "subType":"09"}},
* "legacyDoubleVector":[4.0004,5.12345,6.23456,7.34567,8.45678], "year":2020},
* {"_id":5, "int8Vector":{"$binary":{"base64":"AwAFBgcICQ==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwAZBKBATfPDQISB50DdhwVB+U4XQQ==", "subType":"09"}},
* "legacyDoubleVector":[5.0005,6.12345,7.23456,8.34567,9.45678], "year":2021},
* {"_id":6, "int8Vector":{"$binary":{"base64":"AwAGBwgJCg==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwDqBMBATfPjQMLAA0HdhxVB+U4nQQ==", "subType":"09"}},
* "legacyDoubleVector":[6.0006,7.12345,8.23456,9.34567,10.45678], "year":2022},
* {"_id":7, "int8Vector":{"$binary":{"base64":"AwAHCAkKCw==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwC8BeBAp/kBQcLAE0HdhyVB+U43QQ==", "subType":"09"}},
* "legacyDoubleVector":[7.0007,8.12345,9.23456,10.34567,11.45678], "year":2023},
* {"_id":8, "int8Vector":{"$binary":{"base64":"AwAICQoLDA==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwBHAwBBp/kRQcLAI0HdhzVB+U5HQQ==", "subType":"09"}},
* "legacyDoubleVector":[8.0008,9.12345,10.23456,11.34567,12.45678], "year":2024},
* {"_id":9, "int8Vector":{"$binary":{"base64":"AwAJCgsMDQ==", "subType":"09"}},
* "float32Vector":{"$binary":{"base64":"JwCwAxBBp/khQcLAM0Hdh0VB+U5XQQ==", "subType":"09"}},
* "legacyDoubleVector":[9.0009,10.12345,11.23456,12.34567,13.45678], "year":2025}]
* </code>
*/
class AggregatesBinaryVectorSearchIntegrationTest {
// Namespace made of database "javaExtraTests" and collection "binaryVectorTests".
private static final MongoNamespace BINARY_VECTOR_NAMESPACE = new MongoNamespace("javaExtraTests", "binaryVectorTests");
// Name of the vector search index.
private static final String VECTOR_INDEX = "vector_search_index";
// Document field names for the three vector representations covered by the index definition.
private static final String VECTOR_FIELD_INT_8 = "int8Vector";
private static final String VECTOR_FIELD_FLOAT_32 = "float32Vector";
private static final String VECTOR_FIELD_LEGACY_DOUBLE_LIST = "legacyDoubleVector";
// NOTE(review): presumably the result limit passed to the vector search stage; its usage is outside this view - confirm.
private static final int LIMIT = 5;
// Document field declared with type 'filter' in the index definition.
private static final String FIELD_YEAR = "year";
// Shared helper for the collection under test; assigned during class-level setup.
private static CollectionHelper<Document> collectionHelper;
// Vector search index definition, parsed once from a JSON-like string.
// It declares three 'vector' fields (int8, float32 and legacy double list), each with
// numDimensions 5 and 'cosine' similarity, plus the year field as a 'filter' field.
private static final BsonDocument VECTOR_SEARCH_INDEX_DEFINITION = BsonDocument.parse(
"{"
+ " fields: ["
+ " {"
+ " path: '" + VECTOR_FIELD_INT_8 + "',"
+ " numDimensions: 5,"
+ " similarity: 'cosine',"
+ " type: 'vector',"
+ " },"
+ " {"
+ " path: '" + VECTOR_FIELD_FLOAT_32 + "',"
+ " numDimensions: 5,"
+ " similarity: 'cosine',"
+ " type: 'vector',"
+ " },"
+ " {"
+ " path: '" + VECTOR_FIELD_LEGACY_DOUBLE_LIST + "',"
+ " numDimensions: 5,"
+ " similarity: 'cosine',"
+ " type: 'vector',"
+ " },"
+ " {"
+ " path: '" + FIELD_YEAR + "',"
+ " type: 'filter',"
+ " },"
+ " ]"
+ "}");
/**
 * Class-level setup: seeds the test collection and creates the vector search index.
 * <p>
 * The whole class is skipped unless {@code isAtlasSearchTest()} holds and the server version is at least 6.0.
 * Otherwise the collection is dropped, ten documents (ids 0-9, years 2016-2025) are inserted, the
 * {@code VECTOR_INDEX} search index is created, and the method blocks in {@code awaitIndexCreation()}.
 */
@BeforeAll
static void beforeAll() {
    assumeTrue(isAtlasSearchTest());
    assumeTrue(serverVersionAtLeast(6, 0));

    MongoNamespace testNamespace = new MongoNamespace("javaVectorSearchTest",
            AggregatesBinaryVectorSearchIntegrationTest.class.getSimpleName());
    collectionHelper = new CollectionHelper<>(new DocumentCodec(), testNamespace);
    collectionHelper.drop();

    collectionHelper.insertDocuments(
            seedDocument(0, new byte[]{0, 1, 2, 3, 4},
                    new float[]{0.0001f, 1.12345f, 2.23456f, 3.34567f, 4.45678f},
                    new double[]{0.0001, 1.12345, 2.23456, 3.34567, 4.45678}, 2016),
            seedDocument(1, new byte[]{1, 2, 3, 4, 5},
                    new float[]{1.0001f, 2.12345f, 3.23456f, 4.34567f, 5.45678f},
                    new double[]{1.0001, 2.12345, 3.23456, 4.34567, 5.45678}, 2017),
            seedDocument(2, new byte[]{2, 3, 4, 5, 6},
                    new float[]{2.0002f, 3.12345f, 4.23456f, 5.34567f, 6.45678f},
                    new double[]{2.0002, 3.12345, 4.23456, 5.34567, 6.45678}, 2018),
            seedDocument(3, new byte[]{3, 4, 5, 6, 7},
                    new float[]{3.0003f, 4.12345f, 5.23456f, 6.34567f, 7.45678f},
                    new double[]{3.0003, 4.12345, 5.23456, 6.34567, 7.45678}, 2019),
            seedDocument(4, new byte[]{4, 5, 6, 7, 8},
                    new float[]{4.0004f, 5.12345f, 6.23456f, 7.34567f, 8.45678f},
                    new double[]{4.0004, 5.12345, 6.23456, 7.34567, 8.45678}, 2020),
            seedDocument(5, new byte[]{5, 6, 7, 8, 9},
                    new float[]{5.0005f, 6.12345f, 7.23456f, 8.34567f, 9.45678f},
                    new double[]{5.0005, 6.12345, 7.23456, 8.34567, 9.45678}, 2021),
            seedDocument(6, new byte[]{6, 7, 8, 9, 10},
                    new float[]{6.0006f, 7.12345f, 8.23456f, 9.34567f, 10.45678f},
                    new double[]{6.0006, 7.12345, 8.23456, 9.34567, 10.45678}, 2022),
            seedDocument(7, new byte[]{7, 8, 9, 10, 11},
                    new float[]{7.0007f, 8.12345f, 9.23456f, 10.34567f, 11.45678f},
                    new double[]{7.0007, 8.12345, 9.23456, 10.34567, 11.45678}, 2023),
            seedDocument(8, new byte[]{8, 9, 10, 11, 12},
                    new float[]{8.0008f, 9.12345f, 10.23456f, 11.34567f, 12.45678f},
                    new double[]{8.0008, 9.12345, 10.23456, 11.34567, 12.45678}, 2024),
            seedDocument(9, new byte[]{9, 10, 11, 12, 13},
                    new float[]{9.0009f, 10.12345f, 11.23456f, 12.34567f, 13.45678f},
                    new double[]{9.0009, 10.12345, 11.23456, 12.34567, 13.45678}, 2025)
    );

    SearchIndexRequest indexRequest =
            new SearchIndexRequest(VECTOR_SEARCH_INDEX_DEFINITION, VECTOR_INDEX, SearchIndexType.vectorSearch());
    collectionHelper.createSearchIndex(indexRequest);
    awaitIndexCreation();
}

/**
 * Builds one seed document, appending the fields in the order
 * {@code _id}, int8 vector, float32 vector, legacy double array, year.
 *
 * @param id            value stored under {@code _id}
 * @param int8Values    bytes wrapped with {@code BinaryVector.int8Vector}
 * @param float32Values floats wrapped with {@code BinaryVector.floatVector}
 * @param doubleValues  array stored as-is under the legacy double field
 * @param year          value stored under the year field
 * @return the populated document
 */
private static Document seedDocument(final int id, final byte[] int8Values, final float[] float32Values,
                                     final double[] doubleValues, final int year) {
    return new Document()
            .append("_id", id)
            .append(VECTOR_FIELD_INT_8, BinaryVector.int8Vector(int8Values))
            .append(VECTOR_FIELD_FLOAT_32, BinaryVector.floatVector(float32Values))
            .append(VECTOR_FIELD_LEGACY_DOUBLE_LIST, doubleValues)
            .append(FIELD_YEAR, year);
}
// Class-level teardown: calls drop() on the shared collection helper when setup got far enough to create it.
@AfterAll
static void afterAll() {
// The helper stays null when setup never assigned it, so guard before dropping.
if (collectionHelper != null) {
collectionHelper.drop();
}
// NOTE(review): the helper is re-created for BINARY_VECTOR_NAMESPACE after the drop and is not used
// again in this method - this statement looks like it belongs to setup rather than teardown; confirm.
collectionHelper = new CollectionHelper<>(new DocumentCodec(), BINARY_VECTOR_NAMESPACE);
}
private static Stream<Arguments> provideSupportedVectors() {
Expand Down
Expand Up
@@ -268,7 +206,7 @@ void shouldSearchByVector(final BinaryVector vector,
final FieldSearchPath fieldSearchPath,
final VectorSearchOptions vectorSearchOptions) {
//given
List<Bson> pipeline = asList (
List<Bson> pipeline = singletonList (
Aggregates.vectorSearch(
fieldSearchPath,
vector,
Expand Down
Expand Up
@@ -327,27 +265,4 @@ private static void assertScoreIsDecreasing(final List<Document> aggregate) {
}
}
/**
 * Polls the search index named {@code VECTOR_INDEX} until its document reports {@code queryable: true}.
 * <p>
 * Makes up to 10 attempts and sleeps 5 seconds after each unsuccessful one. If the index is still not
 * queryable afterwards, the test fails with a message containing the JSON of the last index document
 * seen, or {@code "null"} when the index was never listed.
 *
 * @throws MongoInterruptedException if the thread is interrupted while sleeping; the interrupt flag is
 *                                   restored before the exception is thrown
 */
private static void awaitIndexCreation() {
    int attempts = 10;
    Optional<Document> searchIndex = Optional.empty();
    while (attempts-- > 0) {
        searchIndex = collectionHelper.listSearchIndex(VECTOR_INDEX);
        // Use the defaulting overload: with the single-argument getBoolean a missing "queryable"
        // field yields null, which would be unboxed inside the predicate and throw NullPointerException.
        if (searchIndex.filter(document -> document.getBoolean("queryable", false)).isPresent()) {
            return;
        }
        try {
            TimeUnit.SECONDS.sleep(5);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new MongoInterruptedException(null, e);
        }
    }
    // Single failure path: report the last index document seen, or "null" if none was ever returned.
    Assertions.fail(format(EXCEED_WAIT_ATTEMPTS_ERROR_MESSAGE,
            searchIndex.map(Document::toJson).orElse("null")));
}
}