Commit e647bfa

HHH-18973 Cleanup vector module and add MySQL vector support
Also add support for optional cast patterns to JdbcType to avoid having to touch Dialect for new JdbcType and DdlType.
1 parent 9bf7773 commit e647bfa

33 files changed, with 971 additions and 1,034 deletions

‎documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,21 @@ The Hibernate ORM Vector module contains support for mathematical vector types a
1212
This is useful for AI/ML topics like vector similarity search and Retrieval-Augmented Generation (RAG).
1313
The module comes with support for a special `vector` data type that essentially represents an array of bytes, floats, or doubles.
1414

15-
So far, both the PostgreSQL extension `pgvector` and the Oracle database 23ai+ `AI Vector Search` feature are supported, but in theory,
16-
the vector specific functions could be implemented to work with every database that supports arrays.
15+
Currently, the following databases are supported:
1716

18-
For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation] or the https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[AI Vector Search documentation].
17+
* PostgreSQL 13+ through the https://github.com/pgvector/pgvector#querying[`pgvector` extension]
18+
* https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[Oracle database 23ai+]
19+
* https://mariadb.com/docs/server/reference/sql-structure/vectors/vector-overview[MariaDB 11.7+]
20+
* https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html[MySQL 9.0+]
21+
22+
In theory, the vector-specific functions could be implemented to work with every database that supports arrays.
23+
24+
[WARNING]
25+
====
26+
Per the https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html#function_distance[MySQL documentation],
27+
the various vector distance functions for MySQL only work on MySQL cloud offerings like
28+
https://dev.mysql.com/doc/heatwave/en/mys-hw-about-heatwave.html[HeatWave MySQL on OCI].
29+
====
1930

2031
[[vector-module-setup]]
2132
=== Setup
@@ -57,7 +68,7 @@ As Oracle AI Vector Search supports different types of elements (to ensure bette
5768
====
5869
[source, java, indent=0]
5970
----
60-
include::{example-dir-vector}/PGVectorTest.java[tags=usage-example]
71+
include::{example-dir-vector}/FloatVectorTest.java[tags=usage-example]
6172
----
6273
====
6374

@@ -113,7 +124,7 @@ which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 )
113124
====
114125
[source, java, indent=0]
115126
----
116-
include::{example-dir-vector}/PGVectorTest.java[tags=cosine-distance-example]
127+
include::{example-dir-vector}/FloatVectorTest.java[tags=cosine-distance-example]
117128
----
118129
====
119130

@@ -128,7 +139,7 @@ The `l2_distance()` function is an alias.
128139
====
129140
[source, java, indent=0]
130141
----
131-
include::{example-dir-vector}/PGVectorTest.java[tags=euclidean-distance-example]
142+
include::{example-dir-vector}/FloatVectorTest.java[tags=euclidean-distance-example]
132143
----
133144
====
134145

@@ -143,7 +154,7 @@ The `l1_distance()` function is an alias.
143154
====
144155
[source, java, indent=0]
145156
----
146-
include::{example-dir-vector}/PGVectorTest.java[tags=taxicab-distance-example]
157+
include::{example-dir-vector}/FloatVectorTest.java[tags=taxicab-distance-example]
147158
----
148159
====
149160

@@ -158,7 +169,7 @@ and the `inner_product()` function as well, but multiplies the result time `-1`.
158169
====
159170
[source, java, indent=0]
160171
----
161-
include::{example-dir-vector}/PGVectorTest.java[tags=inner-product-example]
172+
include::{example-dir-vector}/FloatVectorTest.java[tags=inner-product-example]
162173
----
163174
====
164175

@@ -171,7 +182,7 @@ Determines the dimensions of a vector.
171182
====
172183
[source, java, indent=0]
173184
----
174-
include::{example-dir-vector}/PGVectorTest.java[tags=vector-dims-example]
185+
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-dims-example]
175186
----
176187
====
177188

@@ -185,7 +196,7 @@ which is `sqrt( sum( v_i^2 ) )`.
185196
====
186197
[source, java, indent=0]
187198
----
188-
include::{example-dir-vector}/PGVectorTest.java[tags=vector-norm-example]
199+
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-norm-example]
189200
----
190201
====
191202
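
The formulas quoted in the hunk headers above (`1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 ) )` for the cosine distance and `sqrt( sum( v_i^2 ) )` for the norm) can be checked outside the database with a minimal plain-Java sketch. The class and method names below are hypothetical and are not part of the Hibernate vector module; the Euclidean and taxicab formulas are the standard L2 and L1 definitions, which is an assumption about what the database functions compute.

```java
// Hypothetical helper, NOT part of hibernate-vector: a plain-Java rendering
// of the formulas quoted in the documentation diff.
public final class VectorMathSketch {

    private VectorMathSketch() {
    }

    // inner_product(v1, v2): sum of the element-wise products.
    public static double innerProduct(float[] v1, float[] v2) {
        requireSameDims(v1, v2);
        double sum = 0;
        for (int i = 0; i < v1.length; i++) {
            sum += (double) v1[i] * v2[i];
        }
        return sum;
    }

    // vector_norm(v): sqrt( sum( v_i^2 ) ).
    public static double vectorNorm(float[] v) {
        return Math.sqrt(innerProduct(v, v));
    }

    // cosine_distance(v1, v2): 1 - inner_product / (norm(v1) * norm(v2)).
    public static double cosineDistance(float[] v1, float[] v2) {
        return 1 - innerProduct(v1, v2) / (vectorNorm(v1) * vectorNorm(v2));
    }

    // euclidean_distance / l2_distance: standard L2 distance (assumed definition).
    public static double euclideanDistance(float[] v1, float[] v2) {
        requireSameDims(v1, v2);
        double sum = 0;
        for (int i = 0; i < v1.length; i++) {
            double diff = (double) v1[i] - v2[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    // taxicab_distance / l1_distance: standard L1 distance (assumed definition).
    public static double taxicabDistance(float[] v1, float[] v2) {
        requireSameDims(v1, v2);
        double sum = 0;
        for (int i = 0; i < v1.length; i++) {
            sum += Math.abs((double) v1[i] - v2[i]);
        }
        return sum;
    }

    private static void requireSameDims(float[] v1, float[] v2) {
        if (v1.length != v2.length) {
            throw new IllegalArgumentException("Vectors must have the same number of dimensions");
        }
    }
}
```

A helper like this is only useful for cross-checking expected values in tests; the actual queries delegate to the database functions.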

‎hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -77,8 +77,14 @@ public void render(
7777
renderCastArrayToString( sqlAppender, arguments.get( 0 ), dialect, walker );
7878
}
7979
else {
80-
new PatternRenderer( dialect.castPattern( sourceType, targetType ) )
81-
.render( sqlAppender, arguments, walker );
80+
String castPattern = targetJdbcMapping.getJdbcType().castFromPattern( sourceMapping );
81+
if ( castPattern == null ) {
82+
castPattern = sourceMapping.getJdbcType().castToPattern( targetJdbcMapping );
83+
if ( castPattern == null ) {
84+
castPattern = dialect.castPattern( sourceType, targetType );
85+
}
86+
}
87+
new PatternRenderer( castPattern ).render( sqlAppender, arguments, walker );
8288
}
8389
}
8490
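
The lookup order introduced in this hunk (target type first, then source type, then the dialect) can be illustrated with a small self-contained sketch. `TypeSketch` and `DialectSketch` are hypothetical stand-ins for `JdbcType` and `Dialect`, not the real Hibernate interfaces, and the `cast(?1 as ?2)` dialect pattern is a placeholder chosen for the example. The `from_vector(?1 returning ?2)` pattern is the one `AbstractOracleVectorJdbcType` returns in this commit.

```java
// Hypothetical stand-ins, NOT the Hibernate API: they only mirror the
// three-step fallback that CastFunction.render() performs after this change.
public final class CastPatternFallbackSketch {

    private CastPatternFallbackSketch() {
    }

    public interface TypeSketch {
        boolean isStringLike();

        // Pattern for casting from the given source type to this type, or null.
        default String castFromPattern(TypeSketch source) {
            return null;
        }

        // Pattern for casting from this type to the given target type, or null.
        default String castToPattern(TypeSketch target) {
            return null;
        }
    }

    public interface DialectSketch {
        String castPattern(TypeSketch source, TypeSketch target);
    }

    // A string-like type that contributes no cast patterns of its own.
    public static final TypeSketch STRING = () -> true;

    // A vector type that knows how to cast itself to string-like targets.
    public static final TypeSketch VECTOR = new TypeSketch() {
        @Override
        public boolean isStringLike() {
            return false;
        }

        @Override
        public String castToPattern(TypeSketch target) {
            return target.isStringLike() ? "from_vector(?1 returning ?2)" : null;
        }
    };

    // Placeholder dialect pattern, used only when neither type provides one.
    public static final DialectSketch DIALECT = (source, target) -> "cast(?1 as ?2)";

    public static String resolveCastPattern(TypeSketch source, TypeSketch target, DialectSketch dialect) {
        String castPattern = target.castFromPattern(source);
        if (castPattern == null) {
            castPattern = source.castToPattern(target);
            if (castPattern == null) {
                castPattern = dialect.castPattern(source, target);
            }
        }
        return castPattern;
    }
}
```

Under these assumptions, casting `VECTOR` to `STRING` resolves to the type-provided pattern, while casting `STRING` to `VECTOR` falls through to the dialect. This is what lets a new type supply its own cast SQL without the `Dialect` being modified, as the commit message describes.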

‎hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ public ReturnableType<?> resolveFunctionReturnType(
9090
case NUMERIC:
9191
return BigInteger.class.isAssignableFrom( basicType.getJavaType() ) ? bigIntegerType : bigDecimalType;
9292
case VECTOR:
93+
case VECTOR_FLOAT32:
94+
case VECTOR_FLOAT64:
95+
case VECTOR_INT8:
9396
return basicType;
9497
}
9598
return bigDecimalType;
@@ -123,6 +126,9 @@ public BasicValuedMapping resolveFunctionReturnType(
123126
final Class<?> argTypeClass = jdbcMapping.getJavaTypeDescriptor().getJavaTypeClass();
124127
return BigInteger.class.isAssignableFrom( argTypeClass ) ? bigIntegerType : bigDecimalType;
125128
case VECTOR:
129+
case VECTOR_FLOAT32:
130+
case VECTOR_FLOAT64:
131+
case VECTOR_INT8:
126132
return (BasicValuedMapping) jdbcMapping;
127133
}
128134
return bigDecimalType;

‎hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,10 +9,12 @@
99
import java.sql.SQLException;
1010
import java.sql.Types;
1111

12+
import org.checkerframework.checker.nullness.qual.Nullable;
1213
import org.hibernate.Incubating;
1314
import org.hibernate.boot.model.relational.Database;
1415
import org.hibernate.dialect.Dialect;
1516
import org.hibernate.engine.jdbc.Size;
17+
import org.hibernate.metamodel.mapping.JdbcMapping;
1618
import org.hibernate.query.sqm.CastType;
1719
import org.hibernate.sql.ast.spi.SqlAppender;
1820
import org.hibernate.sql.ast.spi.StringBuilderSqlAppender;
@@ -367,6 +369,30 @@ default String getExtraCreateTableInfo(JavaType<?> javaType, String columnName,
367369
return "";
368370
}
369371

372+
/**
373+
* Returns the cast pattern from the given source type to this type, or {@code null} if not possible.
374+
*
375+
* @param sourceMapping The source type
376+
* @return The cast pattern or null
377+
* @since 7.1
378+
*/
379+
@Incubating
380+
default @Nullable String castFromPattern(JdbcMapping sourceMapping) {
381+
return null;
382+
}
383+
384+
/**
385+
* Returns the cast pattern from this type to the given target type, or {@code null} if not possible.
386+
*
387+
* @param targetJdbcMapping The target type
388+
* @return The cast pattern or null
389+
* @since 7.1
390+
*/
391+
@Incubating
392+
default @Nullable String castToPattern(JdbcMapping targetJdbcMapping) {
393+
return null;
394+
}
395+
370396
@Incubating
371397
default boolean isComparable() {
372398
final int code = getDefaultSqlTypeCode();

‎hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java

Lines changed: 85 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,11 @@
1313
import org.hibernate.boot.internal.MetadataBuilderImpl;
1414
import org.hibernate.boot.internal.NamedProcedureCallDefinitionImpl;
1515
import org.hibernate.boot.model.FunctionContributions;
16+
import org.hibernate.boot.model.FunctionContributor;
1617
import org.hibernate.boot.model.IdentifierGeneratorDefinition;
1718
import org.hibernate.boot.model.NamedEntityGraphDefinition;
1819
import org.hibernate.boot.model.TypeContributions;
20+
import org.hibernate.boot.model.TypeContributor;
1921
import org.hibernate.boot.model.TypeDefinition;
2022
import org.hibernate.boot.model.TypeDefinitionRegistry;
2123
import org.hibernate.boot.model.convert.spi.ConverterAutoApplyHandler;
@@ -97,6 +99,7 @@
9799
import org.hibernate.type.descriptor.java.StringJavaType;
98100
import org.hibernate.type.descriptor.jdbc.JdbcType;
99101
import org.hibernate.type.descriptor.jdbc.VarcharJdbcType;
102+
import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry;
100103
import org.hibernate.type.internal.BasicTypeImpl;
101104
import org.hibernate.type.spi.TypeConfiguration;
102105
import org.hibernate.usertype.CompositeUserType;
@@ -105,6 +108,7 @@
105108
import java.util.HashMap;
106109
import java.util.List;
107110
import java.util.Map;
111+
import java.util.ServiceLoader;
108112
import java.util.Set;
109113
import java.util.UUID;
110114
import java.util.function.Consumer;
@@ -1076,6 +1080,66 @@ public boolean apply(Dialect dialect) {
10761080
}
10771081
}
10781082

1083+
public static class SupportsVectorType implements DialectFeatureCheck {
1084+
public boolean apply(Dialect dialect) {
1085+
return definesDdlType( dialect, SqlTypes.VECTOR );
1086+
}
1087+
}
1088+
1089+
public static class SupportsDoubleVectorType implements DialectFeatureCheck {
1090+
public boolean apply(Dialect dialect) {
1091+
return definesDdlType( dialect, SqlTypes.VECTOR_FLOAT64 );
1092+
}
1093+
}
1094+
1095+
public static class SupportsByteVectorType implements DialectFeatureCheck {
1096+
public boolean apply(Dialect dialect) {
1097+
return definesDdlType( dialect, SqlTypes.VECTOR_INT8 );
1098+
}
1099+
}
1100+
1101+
public static class SupportsCosineDistance implements DialectFeatureCheck {
1102+
public boolean apply(Dialect dialect) {
1103+
return definesFunction( dialect, "cosine_distance" );
1104+
}
1105+
}
1106+
1107+
public static class SupportsEuclideanDistance implements DialectFeatureCheck {
1108+
public boolean apply(Dialect dialect) {
1109+
return definesFunction( dialect, "euclidean_distance" );
1110+
}
1111+
}
1112+
1113+
public static class SupportsTaxicabDistance implements DialectFeatureCheck {
1114+
public boolean apply(Dialect dialect) {
1115+
return definesFunction( dialect, "taxicab_distance" );
1116+
}
1117+
}
1118+
1119+
public static class SupportsHammingDistance implements DialectFeatureCheck {
1120+
public boolean apply(Dialect dialect) {
1121+
return definesFunction( dialect, "hamming_distance" );
1122+
}
1123+
}
1124+
1125+
public static class SupportsInnerProduct implements DialectFeatureCheck {
1126+
public boolean apply(Dialect dialect) {
1127+
return definesFunction( dialect, "inner_product" );
1128+
}
1129+
}
1130+
1131+
public static class SupportsVectorDims implements DialectFeatureCheck {
1132+
public boolean apply(Dialect dialect) {
1133+
return definesFunction( dialect, "vector_dims" );
1134+
}
1135+
}
1136+
1137+
public static class SupportsVectorNorm implements DialectFeatureCheck {
1138+
public boolean apply(Dialect dialect) {
1139+
return definesFunction( dialect, "vector_norm" );
1140+
}
1141+
}
1142+
10791143
public static class IsJtds implements DialectFeatureCheck {
10801144
public boolean apply(Dialect dialect) {
10811145
return dialect instanceof SybaseDialect && ( (SybaseDialect) dialect ).getDriverKind() == SybaseDriverKind.JTDS;
@@ -1141,7 +1205,7 @@ public boolean apply(Dialect dialect) {
11411205
}
11421206
}
11431207

1144-
private static final HashMap<Dialect, SqmFunctionRegistry> FUNCTION_REGISTRIES = new HashMap<>();
1208+
private static final HashMap<Dialect, FakeFunctionContributions> FUNCTION_CONTRIBUTIONS = new HashMap<>();
11451209

11461210
public static boolean definesFunction(Dialect dialect, String functionName) {
11471211
return getSqmFunctionRegistry( dialect ).findFunctionDescriptor( functionName ) != null;
@@ -1151,6 +1215,11 @@ public static boolean definesSetReturningFunction(Dialect dialect, String functi
11511215
return getSqmFunctionRegistry( dialect ).findSetReturningFunctionDescriptor( functionName ) != null;
11521216
}
11531217

1218+
public static boolean definesDdlType(Dialect dialect, int typeCode) {
1219+
final DdlTypeRegistry ddlTypeRegistry = getFunctionContributions( dialect ).typeConfiguration.getDdlTypeRegistry();
1220+
return ddlTypeRegistry.getDescriptor( typeCode ) != null;
1221+
}
1222+
11541223
public static class SupportsSubqueryInSelect implements DialectFeatureCheck {
11551224
@Override
11561225
public boolean apply(Dialect dialect) {
@@ -1172,24 +1241,33 @@ public boolean apply(Dialect dialect) {
11721241
}
11731242
}
11741243

1175-
11761244
private static SqmFunctionRegistry getSqmFunctionRegistry(Dialect dialect) {
1177-
SqmFunctionRegistry sqmFunctionRegistry = FUNCTION_REGISTRIES.get( dialect );
1178-
if ( sqmFunctionRegistry == null ) {
1245+
return getFunctionContributions( dialect ).functionRegistry;
1246+
}
1247+
1248+
private static FakeFunctionContributions getFunctionContributions(Dialect dialect) {
1249+
FakeFunctionContributions functionContributions = FUNCTION_CONTRIBUTIONS.get( dialect );
1250+
if ( functionContributions == null ) {
11791251
final TypeConfiguration typeConfiguration = new TypeConfiguration();
11801252
final SqmFunctionRegistry functionRegistry = new SqmFunctionRegistry();
11811253
typeConfiguration.scope( new FakeMetadataBuildingContext( typeConfiguration, functionRegistry ) );
11821254
final FakeTypeContributions typeContributions = new FakeTypeContributions( typeConfiguration );
1183-
final FakeFunctionContributions functionContributions = new FakeFunctionContributions(
1255+
functionContributions = new FakeFunctionContributions(
11841256
dialect,
11851257
typeConfiguration,
11861258
functionRegistry
11871259
);
11881260
dialect.contribute( typeContributions, typeConfiguration.getServiceRegistry() );
11891261
dialect.initializeFunctionRegistry( functionContributions );
1190-
FUNCTION_REGISTRIES.put( dialect, sqmFunctionRegistry = functionContributions.functionRegistry );
1262+
for ( TypeContributor typeContributor : ServiceLoader.load( TypeContributor.class ) ) {
1263+
typeContributor.contribute( typeContributions, typeConfiguration.getServiceRegistry() );
1264+
}
1265+
for ( FunctionContributor functionContributor : ServiceLoader.load( FunctionContributor.class ) ) {
1266+
functionContributor.contributeFunctions( functionContributions );
1267+
}
1268+
FUNCTION_CONTRIBUTIONS.put( dialect, functionContributions );
11911269
}
1192-
return sqmFunctionRegistry;
1270+
return functionContributions;
11931271
}
11941272

11951273
public static class FakeTypeContributions implements TypeContributions {

‎hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,9 @@
99
import java.sql.ResultSet;
1010
import java.sql.SQLException;
1111

12+
import org.checkerframework.checker.nullness.qual.Nullable;
1213
import org.hibernate.dialect.Dialect;
14+
import org.hibernate.metamodel.mapping.JdbcMapping;
1315
import org.hibernate.sql.ast.spi.SqlAppender;
1416
import org.hibernate.type.SqlTypes;
1517
import org.hibernate.type.descriptor.ValueBinder;
@@ -43,13 +45,13 @@ public AbstractOracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSu
4345
this.isVectorSupported = isVectorSupported;
4446
}
4547

46-
public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect);
47-
4848
@Override
49-
public int getDefaultSqlTypeCode() {
50-
return SqlTypes.VECTOR;
49+
public @Nullable String castToPattern(JdbcMapping targetJdbcMapping) {
50+
return targetJdbcMapping.getJdbcType().isStringLike() ? "from_vector(?1 returning ?2)" : null;
5151
}
5252

53+
public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect);
54+
5355
@Override
5456
public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
5557
final JavaType<T> elementJavaType;

‎hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ public class MariaDBFunctionContributor implements FunctionContributor {
1313
@Override
1414
public void contributeFunctions(FunctionContributions functionContributions) {
1515
final Dialect dialect = functionContributions.getDialect();
16-
if ( dialect instanceof MariaDBDialect ) {
16+
if ( dialect instanceof MariaDBDialect && dialect.getVersion().isSameOrAfter( 11, 7 ) ) {
1717
final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions );
1818

1919
vectorFunctionFactory.cosineDistance( "vec_distance_cosine(?1,?2)" );
