Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 47d8a98

Browse files
fix: Make sure the vector search feature is null-safe, too. (#3044)
This change introduces the necessary checks and annotations to make a build utilising Nullaway succeed: ``` ./mvnw clean verify -Pnullaway ``` Also, `VectorSearchFragment` has been made `public`, as it is exposed via constructor of the public class `QueryFragmentsAndParameters`. Signed-off-by: Michael Simons <michael@simons.ac>
1 parent f7c355c commit 47d8a98

File tree

7 files changed

+25
-32
lines changed

7 files changed

+25
-32
lines changed

‎src/main/java/org/springframework/data/neo4j/core/Neo4jTemplate.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ public <T> T save(T instance) {
449449
});
450450
}
451451

452-
private <T> T saveImpl(T instance, Collection<PropertyFilter.ProjectedPath> includedProperties,
452+
private <T> T saveImpl(T instance, @NullableCollection<PropertyFilter.ProjectedPath> includedProperties,
453453
@Nullable NestedRelationshipProcessingStateMachine stateMachine) {
454454

455455
if (stateMachine != null && stateMachine.hasProcessedValue(instance)) {
@@ -1489,12 +1489,7 @@ private Optional<Neo4jClient.RecordFetchSpec<T>> createFetchSpec() {
14891489
statement = nodesAndRelationshipsById.toStatement(entityMetaData);
14901490
}
14911491
else {
1492-
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1493-
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1494-
}
1495-
else {
1496-
statement = queryFragments.toStatement();
1497-
}
1492+
statement = queryFragmentsAndParameters.toStatement();
14981493
}
14991494
cypherQuery = Neo4jTemplate.this.renderer.render(statement);
15001495
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);

‎src/main/java/org/springframework/data/neo4j/core/ReactiveNeo4jTemplate.java

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -469,13 +469,13 @@ <T, R> Flux<R> doSave(Iterable<R> instances, Class<T> domainType) {
469469
});
470470
}
471471

472-
private <T> Mono<T> saveImpl(T instance, Collection<PropertyFilter.ProjectedPath> includedProperties,
472+
private <T> Mono<T> saveImpl(T instance, @NullableCollection<PropertyFilter.ProjectedPath> includedProperties,
473473
@Nullable NestedRelationshipProcessingStateMachine stateMachine) {
474474
return saveImpl(instance, includedProperties, stateMachine, new HashSet<>());
475475
}
476476

477477
@SuppressWarnings("deprecation")
478-
private <T> Mono<T> saveImpl(T instance, Collection<PropertyFilter.ProjectedPath> includedProperties,
478+
private <T> Mono<T> saveImpl(T instance, @NullableCollection<PropertyFilter.ProjectedPath> includedProperties,
479479
@Nullable NestedRelationshipProcessingStateMachine stateMachine, Collection<Object> knownRelationshipsIds) {
480480

481481
if (stateMachine != null && stateMachine.hasProcessedValue(instance)) {
@@ -1408,13 +1408,7 @@ public <T> Mono<ExecutableQuery<T>> toExecutableQuery(PreparedQuery<T> preparedQ
14081408
return new DefaultReactiveExecutableQuery<>(preparedQuery, fetchSpec);
14091409
});
14101410
}
1411-
Statement statement = null;
1412-
if (queryFragmentsAndParameters.hasVectorSearchFragment()) {
1413-
statement = queryFragments.toStatement(queryFragmentsAndParameters.getVectorSearchFragment());
1414-
}
1415-
else {
1416-
statement = queryFragments.toStatement();
1417-
}
1411+
Statement statement = queryFragmentsAndParameters.toStatement();
14181412
cypherQuery = this.renderer.render(statement);
14191413
finalParameters = TemplateSupport.mergeParameters(statement, finalParameters);
14201414
}

‎src/main/java/org/springframework/data/neo4j/core/TemplateSupport.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ else if (candidate != type) {
137137
return candidate;
138138
}
139139

140-
static PropertyFilter computeIncludePropertyPredicate(Collection<PropertyFilter.ProjectedPath> includedProperties,
141-
NodeDescription<?> nodeDescription) {
140+
static PropertyFilter computeIncludePropertyPredicate(
141+
@NullableCollection<PropertyFilter.ProjectedPath> includedProperties, NodeDescription<?> nodeDescription) {
142142

143-
return PropertyFilter.from(includedProperties, nodeDescription);
143+
return PropertyFilter.from(Objects.requireNonNullElseGet(includedProperties, List::of), nodeDescription);
144144
}
145145

146146
static void updateVersionPropertyIfPossible(Neo4jPersistentEntity<?> entityMetaData,
@@ -207,8 +207,8 @@ static Map<String, Object> mergeParameters(Statement statement, Map<String, Obje
207207
* @return a new binder function that only works on the included properties.
208208
*/
209209
static <T> FilteredBinderFunction<T> createAndApplyPropertyFilter(
210-
Collection<PropertyFilter.ProjectedPath> includedProperties, Neo4jPersistentEntity<?> entityMetaData,
211-
Function<T, Map<String, Object>> binderFunction) {
210+
@NullableCollection<PropertyFilter.ProjectedPath> includedProperties,
211+
Neo4jPersistentEntity<?> entityMetaData, Function<T, Map<String, Object>> binderFunction) {
212212

213213
PropertyFilter includeProperty = TemplateSupport.computeIncludePropertyPredicate(includedProperties,
214214
entityMetaData);

‎src/main/java/org/springframework/data/neo4j/repository/query/CypherQueryCreator.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ protected QueryFragmentsAndParameters complete(@Nullable Condition condition, So
208208
if (this.keysetRequiresSort && theSort.isUnsorted()) {
209209
throw new UnsupportedOperationException("Unsorted keyset based scrolling is not supported.");
210210
}
211-
if (this.queryMethod.hasVectorSearchAnnotation()) {
211+
if (this.queryMethod.hasVectorSearchAnnotation() && this.vectorSearchParameter != null) {
212212
var vectorSearchAnnotation = this.queryMethod.getVectorSearchAnnotation().orElseThrow();
213213
var indexName = vectorSearchAnnotation.indexName();
214214
var numberOfNodes = vectorSearchAnnotation.numberOfNodes();
@@ -219,9 +219,8 @@ protected QueryFragmentsAndParameters complete(@Nullable Condition condition, So
219219
}
220220
var vectorSearchFragment = new VectorSearchFragment(indexName, numberOfNodes,
221221
(this.scoreParameter != null) ? this.scoreParameter.getValue() : null);
222-
var queryFragmentsAndParameters = new QueryFragmentsAndParameters(this.nodeDescription, queryFragments,
223-
vectorSearchFragment, convertedParameters, theSort);
224-
return queryFragmentsAndParameters;
222+
return new QueryFragmentsAndParameters(this.nodeDescription, queryFragments, vectorSearchFragment,
223+
convertedParameters, theSort);
225224
}
226225
return new QueryFragmentsAndParameters(this.nodeDescription, queryFragments, convertedParameters, theSort);
227226
}

‎src/main/java/org/springframework/data/neo4j/repository/query/QueryFragments.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ public Statement toStatement(VectorSearchFragment vectorSearchFragment) {
208208
.with(Cypher.name("node").as(((Node) this.matchOn.get(0)).getRequiredSymbolicName().getValue()),
209209
Cypher.name("score").as(Constants.NAME_OF_SCORE));
210210

211-
StatementBuilder.OngoingReadingWithoutWhere match = null;
211+
StatementBuilder.OngoingReadingWithoutWhere match;
212212
if (vectorSearchFragment.hasScore()) {
213213
match = vectorSearch
214214
.where(Cypher.raw(Constants.NAME_OF_SCORE)

‎src/main/java/org/springframework/data/neo4j/repository/query/QueryFragmentsAndParameters.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.neo4j.cypherdsl.core.PatternElement;
3434
import org.neo4j.cypherdsl.core.RelationshipPattern;
3535
import org.neo4j.cypherdsl.core.SortItem;
36+
import org.neo4j.cypherdsl.core.Statement;
3637

3738
import org.springframework.data.domain.Example;
3839
import org.springframework.data.domain.KeysetScrollPosition;
@@ -63,6 +64,7 @@ public final class QueryFragmentsAndParameters {
6364

6465
private final QueryFragments queryFragments;
6566

67+
@Nullable
6668
private final VectorSearchFragment vectorSearchFragment;
6769

6870
@Nullable
@@ -76,7 +78,7 @@ public final class QueryFragmentsAndParameters {
7678
private NodeDescription<?> nodeDescription;
7779

7880
public QueryFragmentsAndParameters(@Nullable NodeDescription<?> nodeDescription, QueryFragments queryFragments,
79-
VectorSearchFragment vectorSearchFragment, Map<String, Object> parameters, @Nullable Sort sort) {
81+
@NullableVectorSearchFragment vectorSearchFragment, Map<String, Object> parameters, @Nullable Sort sort) {
8082
this.nodeDescription = nodeDescription;
8183
this.queryFragments = queryFragments;
8284
this.vectorSearchFragment = vectorSearchFragment;
@@ -402,10 +404,6 @@ public boolean hasVectorSearchFragment() {
402404
return this.vectorSearchFragment != null;
403405
}
404406

405-
public VectorSearchFragment getVectorSearchFragment() {
406-
return this.vectorSearchFragment;
407-
}
408-
409407
@Nullable public String getCypherQuery() {
410408
return this.cypherQuery;
411409
}
@@ -418,4 +416,11 @@ public Sort getSort() {
418416
return this.sort;
419417
}
420418

419+
public Statement toStatement() {
420+
if (this.hasVectorSearchFragment()) {
421+
return this.queryFragments.toStatement(Objects.requireNonNull(this.vectorSearchFragment));
422+
}
423+
return this.queryFragments.toStatement();
424+
}
425+
421426
}

‎src/main/java/org/springframework/data/neo4j/repository/query/VectorSearchFragment.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
* @param numberOfNodes number of nodes to fetch from the index search
2626
* @param score score filter
2727
*/
28-
record VectorSearchFragment(String indexName, int numberOfNodes, @Nullable Double score) {
28+
publicrecord VectorSearchFragment(String indexName, int numberOfNodes, @Nullable Double score) {
2929

3030
boolean hasScore() {
3131
return this.score != null;

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /