Skip to content

Commit 97ce030

Browse files
committed
HHH-18973 Cleanup vector module and add MySQL vector support
Also add support for optional cast patterns to JdbcType to avoid having to touch Dialect for new JdbcType and DdlType.
1 parent 78dca5e commit 97ce030

33 files changed

+971
-1035
lines changed

documentation/src/main/asciidoc/userguide/chapters/query/extensions/Vector.adoc

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,21 @@ The Hibernate ORM Vector module contains support for mathematical vector types a
1212
This is useful for AI/ML topics like vector similarity search and Retrieval-Augmented Generation (RAG).
1313
The module comes with support for a special `vector` data type that essentially represents an array of bytes, floats, or doubles.
1414

15-
So far, both the PostgreSQL extension `pgvector` and the Oracle database 23ai+ `AI Vector Search` feature are supported, but in theory,
16-
the vector specific functions could be implemented to work with every database that supports arrays.
15+
Currently, the following databases are supported:
1716

18-
For further details, refer to the https://github.com/pgvector/pgvector#querying[pgvector documentation] or the https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[AI Vector Search documentation].
17+
* PostgreSQL 13+ through the https://github.com/pgvector/pgvector#querying[`pgvector` extension]
18+
* https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/overview-node.html[Oracle database 23ai+]
19+
* https://mariadb.com/docs/server/reference/sql-structure/vectors/vector-overview[MariaDB 11.7+]
20+
* https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html[MySQL 9.0+]
21+
22+
In theory, the vector-specific functions could be implemented to work with every database that supports arrays.
23+
24+
[WARNING]
25+
====
26+
Per the https://dev.mysql.com/doc/refman/9.4/en/vector-functions.html#function_distance[MySQL documentation],
27+
the various vector distance functions for MySQL only work on MySQL cloud offerings like
28+
https://dev.mysql.com/doc/heatwave/en/mys-hw-about-heatwave.html[HeatWave MySQL on OCI].
29+
====
1930

2031
[[vector-module-setup]]
2132
=== Setup
@@ -57,7 +68,7 @@ As Oracle AI Vector Search supports different types of elements (to ensure bette
5768
====
5869
[source, java, indent=0]
5970
----
60-
include::{example-dir-vector}/PGVectorTest.java[tags=usage-example]
71+
include::{example-dir-vector}/FloatVectorTest.java[tags=usage-example]
6172
----
6273
====
6374

@@ -113,7 +124,7 @@ which is `1 - inner_product( v1, v2 ) / ( vector_norm( v1 ) * vector_norm( v2 )
113124
====
114125
[source, java, indent=0]
115126
----
116-
include::{example-dir-vector}/PGVectorTest.java[tags=cosine-distance-example]
127+
include::{example-dir-vector}/FloatVectorTest.java[tags=cosine-distance-example]
117128
----
118129
====
119130

@@ -128,7 +139,7 @@ The `l2_distance()` function is an alias.
128139
====
129140
[source, java, indent=0]
130141
----
131-
include::{example-dir-vector}/PGVectorTest.java[tags=euclidean-distance-example]
142+
include::{example-dir-vector}/FloatVectorTest.java[tags=euclidean-distance-example]
132143
----
133144
====
134145

@@ -143,7 +154,7 @@ The `l1_distance()` function is an alias.
143154
====
144155
[source, java, indent=0]
145156
----
146-
include::{example-dir-vector}/PGVectorTest.java[tags=taxicab-distance-example]
157+
include::{example-dir-vector}/FloatVectorTest.java[tags=taxicab-distance-example]
147158
----
148159
====
149160

@@ -158,7 +169,7 @@ and the `inner_product()` function as well, but multiplies the result time `-1`.
158169
====
159170
[source, java, indent=0]
160171
----
161-
include::{example-dir-vector}/PGVectorTest.java[tags=inner-product-example]
172+
include::{example-dir-vector}/FloatVectorTest.java[tags=inner-product-example]
162173
----
163174
====
164175

@@ -171,7 +182,7 @@ Determines the dimensions of a vector.
171182
====
172183
[source, java, indent=0]
173184
----
174-
include::{example-dir-vector}/PGVectorTest.java[tags=vector-dims-example]
185+
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-dims-example]
175186
----
176187
====
177188

@@ -185,7 +196,7 @@ which is `sqrt( sum( v_i^2 ) )`.
185196
====
186197
[source, java, indent=0]
187198
----
188-
include::{example-dir-vector}/PGVectorTest.java[tags=vector-norm-example]
199+
include::{example-dir-vector}/FloatVectorTest.java[tags=vector-norm-example]
189200
----
190201
====
191202

hibernate-core/src/main/java/org/hibernate/dialect/function/CastFunction.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,14 @@ public void render(
7777
renderCastArrayToString( sqlAppender, arguments.get( 0 ), dialect, walker );
7878
}
7979
else {
80-
new PatternRenderer( dialect.castPattern( sourceType, targetType ) )
81-
.render( sqlAppender, arguments, walker );
80+
String castPattern = targetJdbcMapping.getJdbcType().castFromPattern( sourceMapping );
81+
if ( castPattern == null ) {
82+
castPattern = sourceMapping.getJdbcType().castToPattern( targetJdbcMapping );
83+
if ( castPattern == null ) {
84+
castPattern = dialect.castPattern( sourceType, targetType );
85+
}
86+
}
87+
new PatternRenderer( castPattern ).render( sqlAppender, arguments, walker );
8288
}
8389
}
8490

hibernate-core/src/main/java/org/hibernate/dialect/function/SumReturnTypeResolver.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ public ReturnableType<?> resolveFunctionReturnType(
9090
case NUMERIC:
9191
return BigInteger.class.isAssignableFrom( basicType.getJavaType() ) ? bigIntegerType : bigDecimalType;
9292
case VECTOR:
93+
case VECTOR_FLOAT32:
94+
case VECTOR_FLOAT64:
95+
case VECTOR_INT8:
9396
return basicType;
9497
}
9598
return bigDecimalType;
@@ -123,6 +126,9 @@ public BasicValuedMapping resolveFunctionReturnType(
123126
final Class<?> argTypeClass = jdbcMapping.getJavaTypeDescriptor().getJavaTypeClass();
124127
return BigInteger.class.isAssignableFrom( argTypeClass ) ? bigIntegerType : bigDecimalType;
125128
case VECTOR:
129+
case VECTOR_FLOAT32:
130+
case VECTOR_FLOAT64:
131+
case VECTOR_INT8:
126132
return (BasicValuedMapping) jdbcMapping;
127133
}
128134
return bigDecimalType;

hibernate-core/src/main/java/org/hibernate/type/descriptor/jdbc/JdbcType.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
import java.sql.SQLException;
1010
import java.sql.Types;
1111

12+
import org.checkerframework.checker.nullness.qual.Nullable;
1213
import org.hibernate.Incubating;
1314
import org.hibernate.boot.model.relational.Database;
1415
import org.hibernate.dialect.Dialect;
1516
import org.hibernate.engine.jdbc.Size;
17+
import org.hibernate.metamodel.mapping.JdbcMapping;
1618
import org.hibernate.query.sqm.CastType;
1719
import org.hibernate.sql.ast.spi.SqlAppender;
1820
import org.hibernate.sql.ast.spi.StringBuilderSqlAppender;
@@ -367,6 +369,30 @@ default String getExtraCreateTableInfo(JavaType<?> javaType, String columnName,
367369
return "";
368370
}
369371

372+
/**
373+
* Returns the cast pattern from the given source type to this type, or {@code null} if not possible.
374+
*
375+
* @param sourceMapping The source type
376+
* @return The cast pattern or null
377+
* @since 7.1
378+
*/
379+
@Incubating
380+
default @Nullable String castFromPattern(JdbcMapping sourceMapping) {
381+
return null;
382+
}
383+
384+
/**
385+
* Returns the cast pattern from this type to the given target type, or {@code null} if not possible.
386+
*
387+
* @param targetJdbcMapping The target type
388+
* @return The cast pattern or null
389+
* @since 7.1
390+
*/
391+
@Incubating
392+
default @Nullable String castToPattern(JdbcMapping targetJdbcMapping) {
393+
return null;
394+
}
395+
370396
@Incubating
371397
default boolean isComparable() {
372398
final int code = getDefaultSqlTypeCode();

hibernate-testing/src/main/java/org/hibernate/testing/orm/junit/DialectFeatureChecks.java

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
import org.hibernate.boot.internal.MetadataBuilderImpl;
1414
import org.hibernate.boot.internal.NamedProcedureCallDefinitionImpl;
1515
import org.hibernate.boot.model.FunctionContributions;
16+
import org.hibernate.boot.model.FunctionContributor;
1617
import org.hibernate.boot.model.IdentifierGeneratorDefinition;
1718
import org.hibernate.boot.model.NamedEntityGraphDefinition;
1819
import org.hibernate.boot.model.TypeContributions;
20+
import org.hibernate.boot.model.TypeContributor;
1921
import org.hibernate.boot.model.TypeDefinition;
2022
import org.hibernate.boot.model.TypeDefinitionRegistry;
2123
import org.hibernate.boot.model.convert.spi.ConverterAutoApplyHandler;
@@ -97,6 +99,7 @@
9799
import org.hibernate.type.descriptor.java.StringJavaType;
98100
import org.hibernate.type.descriptor.jdbc.JdbcType;
99101
import org.hibernate.type.descriptor.jdbc.VarcharJdbcType;
102+
import org.hibernate.type.descriptor.sql.spi.DdlTypeRegistry;
100103
import org.hibernate.type.internal.BasicTypeImpl;
101104
import org.hibernate.type.spi.TypeConfiguration;
102105
import org.hibernate.usertype.CompositeUserType;
@@ -105,6 +108,7 @@
105108
import java.util.HashMap;
106109
import java.util.List;
107110
import java.util.Map;
111+
import java.util.ServiceLoader;
108112
import java.util.Set;
109113
import java.util.UUID;
110114
import java.util.function.Consumer;
@@ -1076,6 +1080,66 @@ public boolean apply(Dialect dialect) {
10761080
}
10771081
}
10781082

1083+
public static class SupportsVectorType implements DialectFeatureCheck {
1084+
public boolean apply(Dialect dialect) {
1085+
return definesDdlType( dialect, SqlTypes.VECTOR );
1086+
}
1087+
}
1088+
1089+
public static class SupportsDoubleVectorType implements DialectFeatureCheck {
1090+
public boolean apply(Dialect dialect) {
1091+
return definesDdlType( dialect, SqlTypes.VECTOR_FLOAT64 );
1092+
}
1093+
}
1094+
1095+
public static class SupportsByteVectorType implements DialectFeatureCheck {
1096+
public boolean apply(Dialect dialect) {
1097+
return definesDdlType( dialect, SqlTypes.VECTOR_INT8 );
1098+
}
1099+
}
1100+
1101+
public static class SupportsCosineDistance implements DialectFeatureCheck {
1102+
public boolean apply(Dialect dialect) {
1103+
return definesFunction( dialect, "cosine_distance" );
1104+
}
1105+
}
1106+
1107+
public static class SupportsEuclideanDistance implements DialectFeatureCheck {
1108+
public boolean apply(Dialect dialect) {
1109+
return definesFunction( dialect, "euclidean_distance" );
1110+
}
1111+
}
1112+
1113+
public static class SupportsTaxicabDistance implements DialectFeatureCheck {
1114+
public boolean apply(Dialect dialect) {
1115+
return definesFunction( dialect, "taxicab_distance" );
1116+
}
1117+
}
1118+
1119+
public static class SupportsHammingDistance implements DialectFeatureCheck {
1120+
public boolean apply(Dialect dialect) {
1121+
return definesFunction( dialect, "hamming_distance" );
1122+
}
1123+
}
1124+
1125+
public static class SupportsInnerProduct implements DialectFeatureCheck {
1126+
public boolean apply(Dialect dialect) {
1127+
return definesFunction( dialect, "inner_product" );
1128+
}
1129+
}
1130+
1131+
public static class SupportsVectorDims implements DialectFeatureCheck {
1132+
public boolean apply(Dialect dialect) {
1133+
return definesFunction( dialect, "vector_dims" );
1134+
}
1135+
}
1136+
1137+
public static class SupportsVectorNorm implements DialectFeatureCheck {
1138+
public boolean apply(Dialect dialect) {
1139+
return definesFunction( dialect, "vector_norm" );
1140+
}
1141+
}
1142+
10791143
public static class IsJtds implements DialectFeatureCheck {
10801144
public boolean apply(Dialect dialect) {
10811145
return dialect instanceof SybaseDialect && ( (SybaseDialect) dialect ).getDriverKind() == SybaseDriverKind.JTDS;
@@ -1141,7 +1205,7 @@ public boolean apply(Dialect dialect) {
11411205
}
11421206
}
11431207

1144-
private static final HashMap<Dialect, SqmFunctionRegistry> FUNCTION_REGISTRIES = new HashMap<>();
1208+
private static final HashMap<Dialect, FakeFunctionContributions> FUNCTION_CONTRIBUTIONS = new HashMap<>();
11451209

11461210
public static boolean definesFunction(Dialect dialect, String functionName) {
11471211
return getSqmFunctionRegistry( dialect ).findFunctionDescriptor( functionName ) != null;
@@ -1151,6 +1215,11 @@ public static boolean definesSetReturningFunction(Dialect dialect, String functi
11511215
return getSqmFunctionRegistry( dialect ).findSetReturningFunctionDescriptor( functionName ) != null;
11521216
}
11531217

1218+
public static boolean definesDdlType(Dialect dialect, int typeCode) {
1219+
final DdlTypeRegistry ddlTypeRegistry = getFunctionContributions( dialect ).typeConfiguration.getDdlTypeRegistry();
1220+
return ddlTypeRegistry.getDescriptor( typeCode ) != null;
1221+
}
1222+
11541223
public static class SupportsSubqueryInSelect implements DialectFeatureCheck {
11551224
@Override
11561225
public boolean apply(Dialect dialect) {
@@ -1172,24 +1241,33 @@ public boolean apply(Dialect dialect) {
11721241
}
11731242
}
11741243

1175-
11761244
private static SqmFunctionRegistry getSqmFunctionRegistry(Dialect dialect) {
1177-
SqmFunctionRegistry sqmFunctionRegistry = FUNCTION_REGISTRIES.get( dialect );
1178-
if ( sqmFunctionRegistry == null ) {
1245+
return getFunctionContributions( dialect ).functionRegistry;
1246+
}
1247+
1248+
private static FakeFunctionContributions getFunctionContributions(Dialect dialect) {
1249+
FakeFunctionContributions functionContributions = FUNCTION_CONTRIBUTIONS.get( dialect );
1250+
if ( functionContributions == null ) {
11791251
final TypeConfiguration typeConfiguration = new TypeConfiguration();
11801252
final SqmFunctionRegistry functionRegistry = new SqmFunctionRegistry();
11811253
typeConfiguration.scope( new FakeMetadataBuildingContext( typeConfiguration, functionRegistry ) );
11821254
final FakeTypeContributions typeContributions = new FakeTypeContributions( typeConfiguration );
1183-
final FakeFunctionContributions functionContributions = new FakeFunctionContributions(
1255+
functionContributions = new FakeFunctionContributions(
11841256
dialect,
11851257
typeConfiguration,
11861258
functionRegistry
11871259
);
11881260
dialect.contribute( typeContributions, typeConfiguration.getServiceRegistry() );
11891261
dialect.initializeFunctionRegistry( functionContributions );
1190-
FUNCTION_REGISTRIES.put( dialect, sqmFunctionRegistry = functionContributions.functionRegistry );
1262+
for ( TypeContributor typeContributor : ServiceLoader.load( TypeContributor.class ) ) {
1263+
typeContributor.contribute( typeContributions, typeConfiguration.getServiceRegistry() );
1264+
}
1265+
for ( FunctionContributor functionContributor : ServiceLoader.load( FunctionContributor.class ) ) {
1266+
functionContributor.contributeFunctions( functionContributions );
1267+
}
1268+
FUNCTION_CONTRIBUTIONS.put( dialect, functionContributions );
11911269
}
1192-
return sqmFunctionRegistry;
1270+
return functionContributions;
11931271
}
11941272

11951273
public static class FakeTypeContributions implements TypeContributions {

hibernate-vector/src/main/java/org/hibernate/vector/AbstractOracleVectorJdbcType.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import java.sql.ResultSet;
1010
import java.sql.SQLException;
1111

12+
import org.checkerframework.checker.nullness.qual.Nullable;
1213
import org.hibernate.dialect.Dialect;
14+
import org.hibernate.metamodel.mapping.JdbcMapping;
1315
import org.hibernate.sql.ast.spi.SqlAppender;
1416
import org.hibernate.type.SqlTypes;
1517
import org.hibernate.type.descriptor.ValueBinder;
@@ -43,13 +45,13 @@ public AbstractOracleVectorJdbcType(JdbcType elementJdbcType, boolean isVectorSu
4345
this.isVectorSupported = isVectorSupported;
4446
}
4547

46-
public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect);
47-
4848
@Override
49-
public int getDefaultSqlTypeCode() {
50-
return SqlTypes.VECTOR;
49+
public @Nullable String castToPattern(JdbcMapping targetJdbcMapping) {
50+
return targetJdbcMapping.getJdbcType().isStringLike() ? "from_vector(?1 returning ?2)" : null;
5151
}
5252

53+
public abstract void appendWriteExpression(String writeExpression, SqlAppender appender, Dialect dialect);
54+
5355
@Override
5456
public <T> JdbcLiteralFormatter<T> getJdbcLiteralFormatter(JavaType<T> javaTypeDescriptor) {
5557
final JavaType<T> elementJavaType;

hibernate-vector/src/main/java/org/hibernate/vector/MariaDBFunctionContributor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public class MariaDBFunctionContributor implements FunctionContributor {
1313
@Override
1414
public void contributeFunctions(FunctionContributions functionContributions) {
1515
final Dialect dialect = functionContributions.getDialect();
16-
if ( dialect instanceof MariaDBDialect ) {
16+
if ( dialect instanceof MariaDBDialect && dialect.getVersion().isSameOrAfter( 11, 7 ) ) {
1717
final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions );
1818

1919
vectorFunctionFactory.cosineDistance( "vec_distance_cosine(?1,?2)" );

0 commit comments

Comments
 (0)