Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit b064009

Browse files
committed
HHH-19710 Add vector support for SAP HANA Cloud
1 parent 0d05965 commit b064009

File tree

7 files changed

+352
-2
lines changed

7 files changed

+352
-2
lines changed

‎hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/ArrayJdbcType.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
8181
}
8282
}
8383

84-
private static JavaType<?> elementJavaType(JavaType<?> javaTypeDescriptor) {
84+
protected static JavaType<?> elementJavaType(JavaType<?> javaTypeDescriptor) {
8585
if ( javaTypeDescriptor instanceof ByteArrayJavaType ) {
8686
// Special handling needed for Byte[], because that would conflict with the VARBINARY mapping
8787
return ByteJavaType.INSTANCE;
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.dialect.Dialect;
8+
import org.hibernate.sql.ast.spi.SqlAppender;
9+
import org.hibernate.type.descriptor.WrapperOptions;
10+
import org.hibernate.type.descriptor.java.JavaType;
11+
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
12+
import org.hibernate.type.descriptor.jdbc.spi.BasicJdbcLiteralFormatter;
13+
14+
public class HANAJdbcLiteralFormatterVector<T> extends BasicJdbcLiteralFormatter<T> {
15+
16+
private final JdbcLiteralFormatter<Object> elementFormatter;
17+
private final String typeName;
18+
19+
public HANAJdbcLiteralFormatterVector(JavaType<T> javaType, JdbcLiteralFormatter<?> elementFormatter, String typeName) {
20+
super( javaType );
21+
//noinspection unchecked
22+
this.elementFormatter = (JdbcLiteralFormatter<Object>) elementFormatter;
23+
this.typeName = typeName;
24+
}
25+
26+
@Override
27+
public void appendJdbcLiteral(SqlAppender appender, T value, Dialect dialect, WrapperOptions wrapperOptions) {
28+
appender.appendSql( "to_" );
29+
appender.appendSql( typeName );
30+
appender.appendSql( "('" );
31+
final Object[] objects = unwrap( value, Object[].class, wrapperOptions );
32+
appender.appendSql( "cast('" );
33+
char separator = '[';
34+
for ( Object o : objects ) {
35+
appender.appendSql( separator );
36+
elementFormatter.appendJdbcLiteral( appender, o, dialect, wrapperOptions );
37+
separator = ',';
38+
}
39+
appender.appendSql( "]')" );
40+
}
41+
42+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.boot.model.FunctionContributions;
8+
import org.hibernate.boot.model.FunctionContributor;
9+
import org.hibernate.dialect.Dialect;
10+
import org.hibernate.dialect.HANADialect;
11+
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
12+
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
13+
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
14+
import org.hibernate.type.spi.TypeConfiguration;
15+
16+
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER;
17+
18+
public class HANAVectorFunctionContributor implements FunctionContributor {
19+
20+
@Override
21+
public void contributeFunctions(FunctionContributions functionContributions) {
22+
final Dialect dialect = functionContributions.getDialect();
23+
if ( dialect instanceof HANADialect hanaDialect && hanaDialect.isCloud() ) {
24+
final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions );
25+
26+
vectorFunctionFactory.cosineDistance( "cosine_similarity(?1,?2)" );
27+
vectorFunctionFactory.euclideanDistance( "l2distance(?1,?2)" );
28+
vectorFunctionFactory.euclideanSquaredDistance( "power(l2distance(?1,?2),2)" );
29+
30+
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
31+
vectorFunctionFactory.registerPatternVectorFunction(
32+
"vector_dims",
33+
"cardinality(?1)",
34+
typeConfiguration.getBasicTypeForJavaType( Integer.class ),
35+
1
36+
);
37+
vectorFunctionFactory.registerNamedVectorFunction(
38+
"l2norm",
39+
typeConfiguration.getBasicTypeForJavaType( Double.class ),
40+
1
41+
);
42+
functionContributions.getFunctionRegistry().registerAlternateKey( "vector_norm", "l2norm" );
43+
functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "l2norm" );
44+
45+
functionContributions.getFunctionRegistry().namedDescriptorBuilder( "subvector" )
46+
.setArgumentsValidator( StandardArgumentsValidators.composite(
47+
StandardArgumentsValidators.exactly( 3 ),
48+
VectorArgumentValidator.INSTANCE
49+
) )
50+
.setArgumentTypeResolver( StandardFunctionArgumentTypeResolvers.byArgument(
51+
VectorArgumentTypeResolver.INSTANCE,
52+
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER ),
53+
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER )
54+
) )
55+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) )
56+
.register();
57+
functionContributions.getFunctionRegistry().namedDescriptorBuilder( "l2normalize" )
58+
.setArgumentsValidator( VectorArgumentValidator.INSTANCE )
59+
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
60+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) )
61+
.register();
62+
functionContributions.getFunctionRegistry().registerAlternateKey( "l2_normalize", "l2normalize" );
63+
}
64+
}
65+
66+
@Override
67+
public int ordinal() {
68+
return 200;
69+
}
70+
}
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.checkerframework.checker.nullness.qual.Nullable;
8+
import org.hibernate.dialect.Dialect;
9+
import org.hibernate.engine.jdbc.Size;
10+
import org.hibernate.metamodel.mapping.JdbcMapping;
11+
import org.hibernate.sql.ast.spi.SqlAppender;
12+
import org.hibernate.type.descriptor.ValueBinder;
13+
import org.hibernate.type.descriptor.ValueExtractor;
14+
import org.hibernate.type.descriptor.WrapperOptions;
15+
import org.hibernate.type.descriptor.java.JavaType;
16+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
17+
import org.hibernate.type.descriptor.jdbc.BasicBinder;
18+
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
19+
import org.hibernate.type.descriptor.jdbc.JdbcLiteralFormatter;
20+
import org.hibernate.type.descriptor.jdbc.JdbcType;
21+
import org.hibernate.type.spi.TypeConfiguration;
22+
23+
import java.sql.CallableStatement;
24+
import java.sql.PreparedStatement;
25+
import java.sql.ResultSet;
26+
import java.sql.SQLException;
27+
import java.util.Arrays;
28+
29+
import static org.hibernate.vector.internal.VectorHelper.parseFloatVector;
30+
31+
public class HANAVectorJdbcType extends ArrayJdbcType {
32+
33+
private final int sqlType;
34+
private final String typeName;
35+
36+
public HANAVectorJdbcType(JdbcType elementJdbcType, int sqlType, String typeName) {
37+
super( elementJdbcType );
38+
this.sqlType = sqlType;
39+
this.typeName = typeName;
40+
}
41+
42+
@Override
43+
public int getDefaultSqlTypeCode() {
44+
return sqlType;
45+
}
46+
47+
@Override
48+
public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
49+
Integer precision,
50+
Integer scale,
51+
TypeConfiguration typeConfiguration) {
52+
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
53+
}
54+
55+
@Override
56+
public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
57+
return new HANAJdbcLiteralFormatterVector<>(
58+
javaTypeDescriptor,
59+
getElementJdbcType().getJdbcLiteralFormatter( elementJavaType( javaTypeDescriptor ) ),
60+
typeName
61+
);
62+
}
63+
64+
@Override
65+
public @Nullable String castToPattern(JdbcMapping targetJdbcMapping, @Nullable Size size) {
66+
final JdbcType jdbcType = targetJdbcMapping.getJdbcType();
67+
return jdbcType.isString()
68+
? jdbcType.isLob() ? "to_nclob(?1)" : "to_nvarchar(?1)"
69+
: null;
70+
}
71+
72+
@Override
73+
public void appendWriteExpression(
74+
String writeExpression,
75+
@Nullable Size size,
76+
SqlAppender appender,
77+
Dialect dialect) {
78+
appender.append( "to_" );
79+
appender.append( typeName );
80+
appender.append( '(');
81+
appender.append( writeExpression );
82+
appender.append( ')' );
83+
}
84+
85+
@Override
86+
public boolean isWriteExpressionTyped(Dialect dialect) {
87+
return true;
88+
}
89+
90+
@Override
91+
public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) {
92+
return sourceMapping.getJdbcType().isStringLike() ? "to_" + typeName + "(?1)" : null;
93+
}
94+
95+
@Override
96+
public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
97+
return new BasicExtractor<>( javaTypeDescriptor, this ) {
98+
@Override
99+
protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
100+
return javaTypeDescriptor.wrap( parseFloatVector( rs.getString( paramIndex ) ), options );
101+
}
102+
103+
@Override
104+
protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
105+
return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( index ) ), options );
106+
}
107+
108+
@Override
109+
protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException {
110+
return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( name ) ), options );
111+
}
112+
};
113+
}
114+
115+
@Override
116+
public <X> ValueBinder<X> getBinder(final JavaType<X> javaTypeDescriptor) {
117+
return new BasicBinder<>( javaTypeDescriptor, this ) {
118+
119+
@Override
120+
protected void doBind(PreparedStatement st, X value, int index, WrapperOptions options) throws SQLException {
121+
st.setString( index, getBindValue( value, options ) );
122+
}
123+
124+
@Override
125+
protected void doBind(CallableStatement st, X value, String name, WrapperOptions options)
126+
throws SQLException {
127+
st.setString( name, getBindValue( value, options ) );
128+
}
129+
130+
@Override
131+
public String getBindValue(X value, WrapperOptions options) {
132+
return Arrays.toString( getJavaType().unwrap( value, float[].class, options ) );
133+
}
134+
};
135+
}
136+
137+
@Override
138+
public boolean equals(Object that) {
139+
return super.equals( that )
140+
&& that instanceof HANAVectorJdbcType vectorJdbcType
141+
&& sqlType == vectorJdbcType.sqlType;
142+
}
143+
144+
@Override
145+
public int hashCode() {
146+
return sqlType + 31 * super.hashCode();
147+
}
148+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.boot.model.TypeContributions;
8+
import org.hibernate.boot.model.TypeContributor;
9+
import org.hibernate.dialect.Dialect;
10+
import org.hibernate.dialect.HANADialect;
11+
import org.hibernate.engine.jdbc.spi.JdbcServices;
12+
import org.hibernate.service.ServiceRegistry;
13+
import org.hibernate.type.BasicArrayType;
14+
import org.hibernate.type.BasicType;
15+
import org.hibernate.type.BasicTypeRegistry;
16+
import org.hibernate.type.SqlTypes;
17+
import org.hibernate.type.StandardBasicTypes;
18+
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
19+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
20+
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
21+
import org.hibernate.type.spi.TypeConfiguration;
22+
23+
public class HANAVectorTypeContributor implements TypeContributor {
24+
25+
@Override
26+
public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
27+
final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect();
28+
if ( dialect instanceof HANADialect hanaDialect && hanaDialect.isCloud() ) {
29+
final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
30+
final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry();
31+
final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry();
32+
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
33+
final BasicType<Float> floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT );
34+
final ArrayJdbcType genericVectorJdbcType = new HANAVectorJdbcType(
35+
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
36+
SqlTypes.VECTOR,
37+
"real_vector"
38+
);
39+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType );
40+
final ArrayJdbcType floatVectorJdbcType = new HANAVectorJdbcType(
41+
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
42+
SqlTypes.VECTOR_FLOAT32,
43+
"real_vector"
44+
);
45+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType );
46+
final ArrayJdbcType float16VectorJdbcType = new HANAVectorJdbcType(
47+
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
48+
SqlTypes.VECTOR_FLOAT16,
49+
"half_vector"
50+
);
51+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT16, float16VectorJdbcType );
52+
53+
basicTypeRegistry.register(
54+
new BasicArrayType<>(
55+
floatBasicType,
56+
genericVectorJdbcType,
57+
javaTypeRegistry.getDescriptor( float[].class )
58+
),
59+
StandardBasicTypes.VECTOR.getName()
60+
);
61+
basicTypeRegistry.register(
62+
new BasicArrayType<>(
63+
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
64+
floatVectorJdbcType,
65+
javaTypeRegistry.getDescriptor( float[].class )
66+
),
67+
StandardBasicTypes.VECTOR_FLOAT32.getName()
68+
);
69+
basicTypeRegistry.register(
70+
new BasicArrayType<>(
71+
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
72+
float16VectorJdbcType,
73+
javaTypeRegistry.getDescriptor( float[].class )
74+
),
75+
StandardBasicTypes.VECTOR_FLOAT16.getName()
76+
);
77+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
78+
new VectorDdlType( SqlTypes.VECTOR, "real_vector($l)", "real_vector", dialect )
79+
);
80+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
81+
new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "real_vector($l)", "real_vector", dialect )
82+
);
83+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
84+
new VectorDdlType( SqlTypes.VECTOR_FLOAT16, "half_vector($l)", "half_vector", dialect )
85+
);
86+
}
87+
}
88+
}

‎hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ org.hibernate.vector.internal.OracleVectorFunctionContributor
33
org.hibernate.vector.internal.MariaDBFunctionContributor
44
org.hibernate.vector.internal.MySQLFunctionContributor
55
org.hibernate.vector.internal.DB2VectorFunctionContributor
6-
org.hibernate.vector.internal.CockroachFunctionContributor
6+
org.hibernate.vector.internal.CockroachFunctionContributor
7+
org.hibernate.vector.internal.HANAVectorFunctionContributor

‎hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ org.hibernate.vector.internal.MariaDBTypeContributor
44
org.hibernate.vector.internal.MySQLTypeContributor
55
org.hibernate.vector.internal.DB2VectorTypeContributor
66
org.hibernate.vector.internal.CockroachTypeContributor
7+
org.hibernate.vector.internal.HANAVectorTypeContributor

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /