Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit bd2b115

Browse files
committed
HHH-19710 Add vector support for SAP HANA Cloud
1 parent c0cc6a2 commit bd2b115

File tree

5 files changed

+273
-1
lines changed

5 files changed

+273
-1
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.boot.model.FunctionContributions;
8+
import org.hibernate.boot.model.FunctionContributor;
9+
import org.hibernate.dialect.Dialect;
10+
import org.hibernate.dialect.HANADialect;
11+
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
12+
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
13+
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
14+
import org.hibernate.type.spi.TypeConfiguration;
15+
16+
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER;
17+
18+
public class HANAVectorFunctionContributor implements FunctionContributor {
19+
20+
@Override
21+
public void contributeFunctions(FunctionContributions functionContributions) {
22+
final Dialect dialect = functionContributions.getDialect();
23+
if ( dialect instanceof HANADialect hanaDialect && hanaDialect.isCloud() ) {
24+
final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions );
25+
26+
vectorFunctionFactory.cosineDistance( "cosine_similarity(?1,?2)" );
27+
vectorFunctionFactory.euclideanDistance( "l2distance(?1,?2)" );
28+
vectorFunctionFactory.euclideanSquaredDistance( "power(l2distance(?1,?2),2)" );
29+
30+
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
31+
vectorFunctionFactory.registerPatternVectorFunction(
32+
"vector_dims",
33+
"cardinality(?1)",
34+
typeConfiguration.getBasicTypeForJavaType( Integer.class ),
35+
1
36+
);
37+
vectorFunctionFactory.registerNamedVectorFunction(
38+
"l2norm",
39+
typeConfiguration.getBasicTypeForJavaType( Double.class ),
40+
1
41+
);
42+
functionContributions.getFunctionRegistry().registerAlternateKey( "vector_norm", "l2norm" );
43+
functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "l2norm" );
44+
45+
functionContributions.getFunctionRegistry().namedDescriptorBuilder( "subvector" )
46+
.setArgumentsValidator( StandardArgumentsValidators.composite(
47+
StandardArgumentsValidators.exactly( 3 ),
48+
VectorArgumentValidator.INSTANCE
49+
) )
50+
.setArgumentTypeResolver( StandardFunctionArgumentTypeResolvers.byArgument(
51+
VectorArgumentTypeResolver.INSTANCE,
52+
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER ),
53+
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER )
54+
) )
55+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) )
56+
.register();
57+
functionContributions.getFunctionRegistry().namedDescriptorBuilder( "l2normalize" )
58+
.setArgumentsValidator( VectorArgumentValidator.INSTANCE )
59+
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
60+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) )
61+
.register();
62+
functionContributions.getFunctionRegistry().registerAlternateKey( "l2_normalize", "l2normalize" );
63+
}
64+
}
65+
66+
@Override
67+
public int ordinal() {
68+
return 200;
69+
}
70+
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.checkerframework.checker.nullness.qual.Nullable;
8+
import org.hibernate.dialect.Dialect;
9+
import org.hibernate.engine.jdbc.Size;
10+
import org.hibernate.metamodel.mapping.JdbcMapping;
11+
import org.hibernate.sql.ast.spi.SqlAppender;
12+
import org.hibernate.type.descriptor.ValueExtractor;
13+
import org.hibernate.type.descriptor.WrapperOptions;
14+
import org.hibernate.type.descriptor.java.JavaType;
15+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
16+
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
17+
import org.hibernate.type.descriptor.jdbc.JdbcType;
18+
import org.hibernate.type.spi.TypeConfiguration;
19+
20+
import java.sql.CallableStatement;
21+
import java.sql.ResultSet;
22+
import java.sql.SQLException;
23+
24+
import static org.hibernate.vector.internal.VectorHelper.parseFloatVector;
25+
26+
public class HANAVectorJdbcType extends ArrayJdbcType {
27+
28+
private final int sqlType;
29+
private final String typeName;
30+
31+
public HANAVectorJdbcType(JdbcType elementJdbcType, int sqlType, String typeName) {
32+
super( elementJdbcType );
33+
this.sqlType = sqlType;
34+
this.typeName = typeName;
35+
}
36+
37+
@Override
38+
public int getDefaultSqlTypeCode() {
39+
return sqlType;
40+
}
41+
42+
@Override
43+
public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
44+
Integer precision,
45+
Integer scale,
46+
TypeConfiguration typeConfiguration) {
47+
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
48+
}
49+
50+
@Override
51+
public @Nullable String castToPattern(JdbcMapping targetJdbcMapping, @Nullable Size size) {
52+
final JdbcType jdbcType = targetJdbcMapping.getJdbcType();
53+
return jdbcType.isString()
54+
? jdbcType.isLob() ? "to_nclob(?1)" : "to_nvarchar(?1)"
55+
: null;
56+
}
57+
58+
@Override
59+
public void appendWriteExpression(
60+
String writeExpression,
61+
@Nullable Size size,
62+
SqlAppender appender,
63+
Dialect dialect) {
64+
appender.append( "to_" );
65+
appender.append( typeName );
66+
appender.append( '(');
67+
appender.append( writeExpression );
68+
appender.append( ')' );
69+
}
70+
71+
@Override
72+
public boolean isWriteExpressionTyped(Dialect dialect) {
73+
return true;
74+
}
75+
76+
@Override
77+
public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) {
78+
return sourceMapping.getJdbcType().isStringLike() ? "to_" + typeName + "(?1)" : null;
79+
}
80+
81+
@Override
82+
public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
83+
return new BasicExtractor<>( javaTypeDescriptor, this ) {
84+
@Override
85+
protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
86+
return javaTypeDescriptor.wrap( parseFloatVector( rs.getString( paramIndex ) ), options );
87+
}
88+
89+
@Override
90+
protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
91+
return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( index ) ), options );
92+
}
93+
94+
@Override
95+
protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException {
96+
return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( name ) ), options );
97+
}
98+
};
99+
}
100+
101+
@Override
102+
public boolean equals(Object that) {
103+
return super.equals( that )
104+
&& that instanceof HANAVectorJdbcType vectorJdbcType
105+
&& sqlType == vectorJdbcType.sqlType;
106+
}
107+
108+
@Override
109+
public int hashCode() {
110+
return sqlType + 31 * super.hashCode();
111+
}
112+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.boot.model.TypeContributions;
8+
import org.hibernate.boot.model.TypeContributor;
9+
import org.hibernate.dialect.Dialect;
10+
import org.hibernate.dialect.HANADialect;
11+
import org.hibernate.engine.jdbc.spi.JdbcServices;
12+
import org.hibernate.service.ServiceRegistry;
13+
import org.hibernate.type.BasicArrayType;
14+
import org.hibernate.type.BasicType;
15+
import org.hibernate.type.BasicTypeRegistry;
16+
import org.hibernate.type.SqlTypes;
17+
import org.hibernate.type.StandardBasicTypes;
18+
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
19+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
20+
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
21+
import org.hibernate.type.spi.TypeConfiguration;
22+
23+
public class HANAVectorTypeContributor implements TypeContributor {
24+
25+
@Override
26+
public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
27+
final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect();
28+
if ( dialect instanceof HANADialect hanaDialect && hanaDialect.isCloud() ) {
29+
final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
30+
final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry();
31+
final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry();
32+
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
33+
final BasicType<Float> floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT );
34+
final ArrayJdbcType genericVectorJdbcType = new HANAVectorJdbcType(
35+
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
36+
SqlTypes.VECTOR,
37+
"real_vector"
38+
);
39+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType );
40+
final ArrayJdbcType floatVectorJdbcType = new HANAVectorJdbcType(
41+
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
42+
SqlTypes.VECTOR_FLOAT32,
43+
"real_vector"
44+
);
45+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType );
46+
final ArrayJdbcType float16VectorJdbcType = new HANAVectorJdbcType(
47+
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
48+
SqlTypes.VECTOR_FLOAT16,
49+
"half_vector"
50+
);
51+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT16, float16VectorJdbcType );
52+
53+
basicTypeRegistry.register(
54+
new BasicArrayType<>(
55+
floatBasicType,
56+
genericVectorJdbcType,
57+
javaTypeRegistry.getDescriptor( float[].class )
58+
),
59+
StandardBasicTypes.VECTOR.getName()
60+
);
61+
basicTypeRegistry.register(
62+
new BasicArrayType<>(
63+
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
64+
floatVectorJdbcType,
65+
javaTypeRegistry.getDescriptor( float[].class )
66+
),
67+
StandardBasicTypes.VECTOR_FLOAT32.getName()
68+
);
69+
basicTypeRegistry.register(
70+
new BasicArrayType<>(
71+
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
72+
float16VectorJdbcType,
73+
javaTypeRegistry.getDescriptor( float[].class )
74+
),
75+
StandardBasicTypes.VECTOR_FLOAT16.getName()
76+
);
77+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
78+
new VectorDdlType( SqlTypes.VECTOR, "real_vector($l)", "real_vector", dialect )
79+
);
80+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
81+
new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "real_vector($l)", "real_vector", dialect )
82+
);
83+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
84+
new VectorDdlType( SqlTypes.VECTOR_FLOAT16, "half_vector($l)", "half_vector", dialect )
85+
);
86+
}
87+
}
88+
}

‎hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ org.hibernate.vector.internal.OracleVectorFunctionContributor
33
org.hibernate.vector.internal.MariaDBFunctionContributor
44
org.hibernate.vector.internal.MySQLFunctionContributor
55
org.hibernate.vector.internal.DB2VectorFunctionContributor
6-
org.hibernate.vector.internal.CockroachFunctionContributor
6+
org.hibernate.vector.internal.CockroachFunctionContributor
7+
org.hibernate.vector.internal.HANAVectorFunctionContributor

‎hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ org.hibernate.vector.internal.MariaDBTypeContributor
44
org.hibernate.vector.internal.MySQLTypeContributor
55
org.hibernate.vector.internal.DB2VectorTypeContributor
66
org.hibernate.vector.internal.CockroachTypeContributor
7+
org.hibernate.vector.internal.HANAVectorTypeContributor

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /