Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d1b7692

Browse files
GH-3002 - Introduce vector search on repository level
Signed-off-by: Gerrit Meier <meistermeier@gmail.com> Closes #3002
1 parent aed6ab6 commit d1b7692

22 files changed

+877
-12
lines changed

‎src/main/antora/modules/ROOT/nav.adoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
** xref:repositories/sdn-extension.adoc[]
2727
** xref:repositories/query-keywords-reference.adoc[]
2828
** xref:repositories/query-return-types-reference.adoc[]
29+
** xref:repositories/vector-search.adoc[]
2930
3031
* xref:repositories/projections.adoc[]
3132
** xref:projections/sdn-projections.adoc[]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
[[sdn-vector-search]]
2+
= Neo4j Vector Search
3+
4+
== The `@VectorSearch` annotation
5+
Spring Data Neo4j supports Neo4j's vector search on the repository level by using the `@VectorSearch` annotation.
6+
For this to work, Neo4j needs to have a vector index in place.
7+
How to create a vector index is explained in the https://neo4j.com/docs/cypher-manual/current/indexes/search-performance-indexes/managing-indexes/[Neo4j documentation].
8+
9+
NOTE: It's not required to have any (Spring Data) Vector typed property be defined in the domain entities for this to work
10+
because the search operates exclusively on the index.
11+
12+
The `@VectorSearch` annotation requires two arguments:
13+
The name of the vector index to be used and the number of nearest neighbours.
14+
15+
For a general vector search over the whole domain, it's possible to use a derived finder method without any property.
16+
[source,java,indent=0,tabsize=4]
17+
----
18+
include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findall]
19+
----
20+
21+
The vector index can be combined with any property-based finder method to filter down the results.
22+
23+
NOTE: For technical reasons, the vector search will always be executed before the property search gets invoked.
24+
E.g. if the property filter looks for a person named "Helge",
25+
but the vector search only yields "Hannes", there won't be a result.
26+
27+
[source,java,indent=0,tabsize=4]
28+
----
29+
include::example$integration/imperative/VectorSearchIT.java[tags=sdn-vector-search.usage;sdn-vector-search.usage.findbyproperty]
30+
----

‎src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1489,7 +1489,12 @@ private Optional<Neo4jClient.RecordFetchSpec<T>> createFetchSpec() {
14891489
statement = nodesAndRelationshipsById.toStatement(entityMetaData);
14901490
}
14911491
else {
1492-
statement = queryFragments.toStatement();
1492+
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1493+
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1494+
}
1495+
else {
1496+
statement = queryFragments.toStatement();
1497+
}
14931498
}
14941499
cypherQuery = Neo4jTemplate.this.renderer.render(statement);
14951500
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);

‎src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,8 +1408,13 @@ public <T> Mono<ExecutableQuery<T>> toExecutableQuery(PreparedQuery<T> preparedQ
14081408
return new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec);
14091409
});
14101410
}
1411-
1412-
Statement statement = queryFragments.toStatement();
1411+
Statement statement = null;
1412+
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1413+
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1414+
}
1415+
else {
1416+
statement = queryFragments.toStatement();
1417+
}
14131418
cypherQuery = this.renderer.render(statement);
14141419
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);
14151420
}

‎src/main/java/org/springframework/data/neo4j/core/mapping/Constants.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,21 @@ public final class Constants {
163163
*/
164164
public static final String TO_ID_PARAMETER_NAME = "toId";
165165

166+
/**
167+
* The name SDN uses for vector search score.
168+
*/
169+
public static final String NAME_OF_SCORE = "__score__";
170+
171+
/**
172+
* Vector search vector parameter name.
173+
*/
174+
public static final String VECTOR_SEARCH_VECTOR_PARAMETER = "vectorSearchParam";
175+
176+
/**
177+
* Vector search score parameter name.
178+
*/
179+
public static final String VECTOR_SEARCH_SCORE_PARAMETER = "scoreParam";
180+
166181
private Constants() {
167182
}
168183

‎src/main/java/org/springframework/data/neo4j/repository/query/AbstractNeo4jQuery.java

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.function.LongSupplier;
2424
import java.util.function.Supplier;
2525
import java.util.function.UnaryOperator;
26+
import java.util.stream.Collectors;
2627

2728
import org.jspecify.annotations.Nullable;
2829
import org.neo4j.driver.types.MapAccessor;
@@ -32,6 +33,8 @@
3233
import org.springframework.data.domain.Page;
3334
import org.springframework.data.domain.PageRequest;
3435
import org.springframework.data.domain.Pageable;
36+
import org.springframework.data.domain.SearchResult;
37+
import org.springframework.data.domain.SearchResults;
3538
import org.springframework.data.domain.Slice;
3639
import org.springframework.data.domain.SliceImpl;
3740
import org.springframework.data.geo.GeoPage;
@@ -98,11 +101,30 @@ boolean isGeoNearQuery() {
98101
return GeoPage.class.isAssignableFrom(returnType);
99102
}
100103

104+
boolean isVectorSearchQuery() {
105+
var repositoryMethod = this.queryMethod.getMethod();
106+
Class<?> returnType = repositoryMethod.getReturnType();
107+
108+
for (Class<?> type : Neo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
109+
if (type.isAssignableFrom(returnType)) {
110+
return true;
111+
}
112+
}
113+
114+
if (Iterable.class.isAssignableFrom(returnType)) {
115+
TypeInformation<?> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
116+
return from.getComponentType() != null && SearchResult.class.equals(from.getComponentType().getType());
117+
}
118+
119+
return false;
120+
}
121+
101122
@Override
102123
@Nullable public final Object execute(Object[] parameters) {
103124

104125
boolean incrementLimit = this.queryMethod.incrementLimit();
105126
boolean geoNearQuery = isGeoNearQuery();
127+
boolean vectorSearchQuery = isVectorSearchQuery();
106128
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
107129
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
108130

@@ -111,7 +133,7 @@ boolean isGeoNearQuery() {
111133
ReturnedType returnedType = resultProcessor.getReturnedType();
112134
PreparedQuery<?> preparedQuery = prepareQuery(returnedType.getReturnedType(),
113135
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
114-
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
136+
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
115137
incrementLimit ? l -> l + 1 : UnaryOperator.identity());
116138

117139
Object rawResult = new Neo4jQueryExecution.DefaultQueryExecution(this.neo4jOperations).execute(preparedQuery,
@@ -143,6 +165,9 @@ else if (this.queryMethod.isScrollQuery()) {
143165
else if (geoNearQuery) {
144166
rawResult = newGeoResults(rawResult);
145167
}
168+
else if (this.queryMethod.isSearchQuery()) {
169+
rawResult = createSearchResult((List<?>) rawResult, returnedType.getReturnedType());
170+
}
146171

147172
return resultProcessor.processResult(rawResult, preparingConverter);
148173
}
@@ -182,6 +207,13 @@ private Slice<?> createSlice(boolean incrementLimit, Neo4jParameterAccessor para
182207
}
183208
}
184209

210+
private <T> SearchResults<?> createSearchResult(List<?> rawResult, Class<T> returnedType) {
211+
List<SearchResult<T>> searchResults = rawResult.stream()
212+
.map(rawValue -> (SearchResult<T>) rawValue)
213+
.collect(Collectors.toUnmodifiableList());
214+
return new SearchResults<>(searchResults);
215+
}
216+
185217
protected abstract <T> PreparedQuery<T> prepareQuery(Class<T> returnedType,
186218
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jParameterAccessor parameterAccessor,
187219
@Nullable Neo4jQueryType queryType,

‎src/main/java/org/springframework/data/neo4j/repository/query/AbstractReactiveNeo4jQuery.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import org.neo4j.driver.types.MapAccessor;
2525
import org.neo4j.driver.types.TypeSystem;
2626
import reactor.core.publisher.Flux;
27+
import reactor.core.publisher.Mono;
2728

2829
import org.springframework.core.convert.converter.Converter;
30+
import org.springframework.data.domain.SearchResult;
2931
import org.springframework.data.geo.GeoResult;
3032
import org.springframework.data.neo4j.core.PreparedQuery;
3133
import org.springframework.data.neo4j.core.PropertyFilterSupport;
@@ -89,11 +91,31 @@ boolean isGeoNearQuery() {
8991
return false;
9092
}
9193

94+
boolean isVectorSearchQuery() {
95+
var repositoryMethod = this.queryMethod.getMethod();
96+
Class<?> returnType = repositoryMethod.getReturnType();
97+
98+
for (Class<?> type : ReactiveNeo4jQueryMethod.VECTOR_SEARCH_RESULTS) {
99+
if (type.isAssignableFrom(returnType)) {
100+
return true;
101+
}
102+
}
103+
104+
if (Flux.class.isAssignableFrom(returnType) || Mono.class.isAssignableFrom(returnType)) {
105+
TypeInformation<?> from = TypeInformation.fromReturnTypeOf(repositoryMethod);
106+
TypeInformation<?> componentType = from.getComponentType();
107+
return componentType != null && SearchResult.class.equals(componentType.getType());
108+
}
109+
110+
return false;
111+
}
112+
92113
@Override
93114
@Nullable public final Object execute(Object[] parameters) {
94115

95116
boolean incrementLimit = this.queryMethod.incrementLimit();
96117
boolean geoNearQuery = isGeoNearQuery();
118+
boolean vectorSearchQuery = isVectorSearchQuery();
97119
Neo4jParameterAccessor parameterAccessor = new Neo4jParameterAccessor(
98120
(Neo4jQueryMethod.Neo4jParameters) this.queryMethod.getParameters(), parameters);
99121
ResultProcessor resultProcessor = this.queryMethod.getResultProcessor()
@@ -102,7 +124,7 @@ boolean isGeoNearQuery() {
102124
ReturnedType returnedType = resultProcessor.getReturnedType();
103125
PreparedQuery<?> preparedQuery = prepareQuery(returnedType.getReturnedType(),
104126
PropertyFilterSupport.getInputProperties(resultProcessor, this.factory, this.mappingContext),
105-
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
127+
parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery, vectorSearchQuery),
106128
incrementLimit ? l -> l + 1 : UnaryOperator.identity());
107129

108130
Object rawResult = new Neo4jQueryExecution.ReactiveQueryExecution(this.neo4jOperations).execute(preparedQuery,
@@ -126,10 +148,17 @@ parameterAccessor, null, getMappingFunction(resultProcessor, geoNearQuery),
126148
.map(rawResultList -> createWindow(resultProcessor, incrementLimit, parameterAccessor, rawResultList,
127149
preparedQuery.getQueryFragmentsAndParameters()));
128150
}
151+
else if (this.queryMethod.isSearchQuery()) {
152+
rawResult = createSearchResult((Flux<?>) rawResult, returnedType.getReturnedType());
153+
}
129154

130155
return resultProcessor.processResult(rawResult, preparingConverter);
131156
}
132157

158+
private <T> Flux<SearchResult<?>> createSearchResult(Flux<?> rawResult, Class<T> returnedType) {
159+
return rawResult.map(rawValue -> (SearchResult<T>) rawValue);
160+
}
161+
133162
protected abstract <T extends Object> PreparedQuery<T> prepareQuery(Class<T> returnedType,
134163
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jParameterAccessor parameterAccessor,
135164
@Nullable Neo4jQueryType queryType, Supplier<BiFunction<TypeSystem, MapAccessor, ?>> mappingFunction,

‎src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@
4545
import org.springframework.data.domain.OffsetScrollPosition;
4646
import org.springframework.data.domain.Pageable;
4747
import org.springframework.data.domain.Range;
48+
import org.springframework.data.domain.Score;
4849
import org.springframework.data.domain.ScrollPosition;
4950
import org.springframework.data.domain.Sort;
51+
import org.springframework.data.domain.Vector;
5052
import org.springframework.data.geo.Box;
5153
import org.springframework.data.geo.Circle;
5254
import org.springframework.data.geo.Distance;
@@ -60,7 +62,6 @@
6062
import org.springframework.data.neo4j.core.mapping.Neo4jPersistentProperty;
6163
import org.springframework.data.neo4j.core.mapping.NodeDescription;
6264
import org.springframework.data.neo4j.core.mapping.PropertyFilter;
63-
import org.springframework.data.repository.query.QueryMethod;
6465
import org.springframework.data.repository.query.parser.AbstractQueryCreator;
6566
import org.springframework.data.repository.query.parser.Part;
6667
import org.springframework.data.repository.query.parser.PartTree;
@@ -119,14 +120,22 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
119120

120121
private final boolean keysetRequiresSort;
121122

122-
private final List<Expression> distanceExpressions = new ArrayList<>();
123+
private final List<Expression> additionalReturnExpression = new ArrayList<>();
123124

124125
/**
125126
* Can be used to modify the limit of a paged or sliced query.
126127
*/
127128
private final UnaryOperator<Integer> limitModifier;
128129

129-
CypherQueryCreator(Neo4jMappingContext mappingContext, QueryMethod queryMethod, Class<?> domainType,
130+
private final Neo4jQueryMethod queryMethod;
131+
132+
@Nullable
133+
private final Vector vectorSearchParameter;
134+
135+
@Nullable
136+
private final Score scoreParameter;
137+
138+
CypherQueryCreator(Neo4jMappingContext mappingContext, Neo4jQueryMethod queryMethod, Class<?> domainType,
130139
Neo4jQueryType queryType, PartTree tree, Neo4jParameterAccessor actualParameters,
131140
Collection<PropertyFilter.ProjectedPath> includedProperties,
132141
BiFunction<Object, Neo4jPersistentPropertyConverter<?>, Object> parameterConversion,
@@ -148,6 +157,8 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
148157

149158
this.pagingParameter = actualParameters.getPageable();
150159
this.scrollPosition = actualParameters.getScrollPosition();
160+
this.vectorSearchParameter = actualParameters.getVector();
161+
this.scoreParameter = actualParameters.getScore();
151162
this.limitModifier = limitModifier;
152163

153164
AtomicInteger symbolicNameIndex = new AtomicInteger();
@@ -160,6 +171,7 @@ final class CypherQueryCreator extends AbstractQueryCreator<QueryFragmentsAndPar
160171

161172
this.keysetRequiresSort = queryMethod.isScrollQuery()
162173
&& actualParameters.getScrollPosition() instanceof KeysetScrollPosition;
174+
this.queryMethod = queryMethod;
163175
}
164176

165177
@Override
@@ -196,6 +208,21 @@ protected QueryFragmentsAndParameters complete(@Nullable Condition condition, So
196208
if (this.keysetRequiresSort && theSort.isUnsorted()) {
197209
throw new UnsupportedOperationException("Unsorted keyset based scrolling is not supported.");
198210
}
211+
if (this.queryMethod.hasVectorSearchAnnotation()) {
212+
var vectorSearchAnnotation = this.queryMethod.getVectorSearchAnnotation().orElseThrow();
213+
var indexName = vectorSearchAnnotation.indexName();
214+
var numberOfNodes = vectorSearchAnnotation.numberOfNodes();
215+
convertedParameters.put(Constants.VECTOR_SEARCH_VECTOR_PARAMETER,
216+
this.vectorSearchParameter.toDoubleArray());
217+
if (this.scoreParameter != null) {
218+
convertedParameters.put(Constants.VECTOR_SEARCH_SCORE_PARAMETER, this.scoreParameter.getValue());
219+
}
220+
var vectorSearchFragment = new VectorSearchFragment(indexName, numberOfNodes,
221+
(this.scoreParameter != null) ? this.scoreParameter.getValue() : null);
222+
var queryFragmentsAndParameters = new QueryFragmentsAndParameters(this.nodeDescription, queryFragments,
223+
vectorSearchFragment, convertedParameters, theSort);
224+
return queryFragmentsAndParameters;
225+
}
199226
return new QueryFragmentsAndParameters(this.nodeDescription, queryFragments, convertedParameters, theSort);
200227
}
201228

@@ -274,11 +301,14 @@ else if (this.scrollPosition instanceof OffsetScrollPosition offsetScrollPositio
274301
? this.maxResults.intValue() : this.pagingParameter.getPageSize()));
275302
}
276303

304+
if (this.queryMethod.hasVectorSearchAnnotation()) {
305+
this.additionalReturnExpression.add(Cypher.name(Constants.NAME_OF_SCORE));
306+
}
277307
var finalSortItems = new ArrayList<>(this.sortItems);
278308
theSort.stream().map(CypherAdapterUtils.sortAdapterFor(this.nodeDescription)).forEach(finalSortItems::add);
279309

280310
queryFragments.setReturnBasedOn(this.nodeDescription, this.includedProperties, this.isDistinct,
281-
this.distanceExpressions);
311+
this.additionalReturnExpression);
282312
queryFragments.setOrderBy(finalSortItems);
283313
}
284314

@@ -438,7 +468,7 @@ else if (p2.isPresent() && p2.get().value instanceof Point) {
438468
// property to be later retrieved and mapped
439469
Neo4jPersistentEntity<?> owner = (Neo4jPersistentEntity<?>) leafProperty.getOwner();
440470
String containerName = getContainerName(path, owner);
441-
this.distanceExpressions
471+
this.additionalReturnExpression
442472
.add(distanceFunction.as("__distance_" + containerName + "_" + leafProperty.getPropertyName() + "__"));
443473

444474
this.sortItems.add(distanceFunction.ascending());

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /