Skip to content

Commit 2ae9c56

Browse files
committed
HHH-19710 Add vector support for SAP HANA Cloud
1 parent 5c2431d commit 2ae9c56

File tree

5 files changed

+273
-1
lines changed

5 files changed

+273
-1
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.boot.model.FunctionContributions;
8+
import org.hibernate.boot.model.FunctionContributor;
9+
import org.hibernate.dialect.Dialect;
10+
import org.hibernate.dialect.HANADialect;
11+
import org.hibernate.query.sqm.produce.function.StandardArgumentsValidators;
12+
import org.hibernate.query.sqm.produce.function.StandardFunctionArgumentTypeResolvers;
13+
import org.hibernate.query.sqm.produce.function.StandardFunctionReturnTypeResolvers;
14+
import org.hibernate.type.spi.TypeConfiguration;
15+
16+
import static org.hibernate.query.sqm.produce.function.FunctionParameterType.INTEGER;
17+
18+
public class HANAVectorFunctionContributor implements FunctionContributor {
19+
20+
@Override
21+
public void contributeFunctions(FunctionContributions functionContributions) {
22+
final Dialect dialect = functionContributions.getDialect();
23+
if ( dialect instanceof HANADialect hanaDialect && hanaDialect.isCloud() ) {
24+
final VectorFunctionFactory vectorFunctionFactory = new VectorFunctionFactory( functionContributions );
25+
26+
vectorFunctionFactory.cosineDistance( "cosine_similarity(?1,?2)" );
27+
vectorFunctionFactory.euclideanDistance( "l2distance(?1,?2)" );
28+
vectorFunctionFactory.euclideanSquaredDistance( "power(l2distance(?1,?2),2)" );
29+
30+
final TypeConfiguration typeConfiguration = functionContributions.getTypeConfiguration();
31+
vectorFunctionFactory.registerPatternVectorFunction(
32+
"vector_dims",
33+
"cardinality(?1)",
34+
typeConfiguration.getBasicTypeForJavaType( Integer.class ),
35+
1
36+
);
37+
vectorFunctionFactory.registerNamedVectorFunction(
38+
"l2norm",
39+
typeConfiguration.getBasicTypeForJavaType( Double.class ),
40+
1
41+
);
42+
functionContributions.getFunctionRegistry().registerAlternateKey( "vector_norm", "l2norm" );
43+
functionContributions.getFunctionRegistry().registerAlternateKey( "l2_norm", "l2norm" );
44+
45+
functionContributions.getFunctionRegistry().namedDescriptorBuilder( "subvector" )
46+
.setArgumentsValidator( StandardArgumentsValidators.composite(
47+
StandardArgumentsValidators.exactly( 3 ),
48+
VectorArgumentValidator.INSTANCE
49+
) )
50+
.setArgumentTypeResolver( StandardFunctionArgumentTypeResolvers.byArgument(
51+
VectorArgumentTypeResolver.INSTANCE,
52+
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER ),
53+
StandardFunctionArgumentTypeResolvers.invariant( typeConfiguration, INTEGER )
54+
) )
55+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) )
56+
.register();
57+
functionContributions.getFunctionRegistry().namedDescriptorBuilder( "l2normalize" )
58+
.setArgumentsValidator( VectorArgumentValidator.INSTANCE )
59+
.setArgumentTypeResolver( VectorArgumentTypeResolver.INSTANCE )
60+
.setReturnTypeResolver( StandardFunctionReturnTypeResolvers.useArgType( 1 ) )
61+
.register();
62+
functionContributions.getFunctionRegistry().registerAlternateKey( "l2_normalize", "l2normalize" );
63+
}
64+
}
65+
66+
@Override
67+
public int ordinal() {
68+
return 200;
69+
}
70+
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.checkerframework.checker.nullness.qual.Nullable;
8+
import org.hibernate.dialect.Dialect;
9+
import org.hibernate.engine.jdbc.Size;
10+
import org.hibernate.metamodel.mapping.JdbcMapping;
11+
import org.hibernate.sql.ast.spi.SqlAppender;
12+
import org.hibernate.type.descriptor.ValueExtractor;
13+
import org.hibernate.type.descriptor.WrapperOptions;
14+
import org.hibernate.type.descriptor.java.JavaType;
15+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
16+
import org.hibernate.type.descriptor.jdbc.BasicExtractor;
17+
import org.hibernate.type.descriptor.jdbc.JdbcType;
18+
import org.hibernate.type.spi.TypeConfiguration;
19+
20+
import java.sql.CallableStatement;
21+
import java.sql.ResultSet;
22+
import java.sql.SQLException;
23+
24+
import static org.hibernate.vector.internal.VectorHelper.parseFloatVector;
25+
26+
public class HANAVectorJdbcType extends ArrayJdbcType {
27+
28+
private final int sqlType;
29+
private final String typeName;
30+
31+
public HANAVectorJdbcType(JdbcType elementJdbcType, int sqlType, String typeName) {
32+
super( elementJdbcType );
33+
this.sqlType = sqlType;
34+
this.typeName = typeName;
35+
}
36+
37+
@Override
38+
public int getDefaultSqlTypeCode() {
39+
return sqlType;
40+
}
41+
42+
@Override
43+
public <T> JavaType<T> getJdbcRecommendedJavaTypeMapping(
44+
Integer precision,
45+
Integer scale,
46+
TypeConfiguration typeConfiguration) {
47+
return typeConfiguration.getJavaTypeRegistry().resolveDescriptor( float[].class );
48+
}
49+
50+
@Override
51+
public @Nullable String castToPattern(JdbcMapping targetJdbcMapping, @Nullable Size size) {
52+
final JdbcType jdbcType = targetJdbcMapping.getJdbcType();
53+
return jdbcType.isString()
54+
? jdbcType.isLob() ? "to_nclob(?1)" : "to_nvarchar(?1)"
55+
: null;
56+
}
57+
58+
@Override
59+
public void appendWriteExpression(
60+
String writeExpression,
61+
@Nullable Size size,
62+
SqlAppender appender,
63+
Dialect dialect) {
64+
appender.append( "to_" );
65+
appender.append( typeName );
66+
appender.append( '(');
67+
appender.append( writeExpression );
68+
appender.append( ')' );
69+
}
70+
71+
@Override
72+
public boolean isWriteExpressionTyped(Dialect dialect) {
73+
return true;
74+
}
75+
76+
@Override
77+
public @Nullable String castFromPattern(JdbcMapping sourceMapping, @Nullable Size size) {
78+
return sourceMapping.getJdbcType().isStringLike() ? "to_" + typeName + "(?1)" : null;
79+
}
80+
81+
@Override
82+
public <X> ValueExtractor<X> getExtractor(JavaType<X> javaTypeDescriptor) {
83+
return new BasicExtractor<>( javaTypeDescriptor, this ) {
84+
@Override
85+
protected X doExtract(ResultSet rs, int paramIndex, WrapperOptions options) throws SQLException {
86+
return javaTypeDescriptor.wrap( parseFloatVector( rs.getString( paramIndex ) ), options );
87+
}
88+
89+
@Override
90+
protected X doExtract(CallableStatement statement, int index, WrapperOptions options) throws SQLException {
91+
return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( index ) ), options );
92+
}
93+
94+
@Override
95+
protected X doExtract(CallableStatement statement, String name, WrapperOptions options) throws SQLException {
96+
return javaTypeDescriptor.wrap( parseFloatVector( statement.getString( name ) ), options );
97+
}
98+
};
99+
}
100+
101+
@Override
102+
public boolean equals(Object that) {
103+
return super.equals( that )
104+
&& that instanceof HANAVectorJdbcType vectorJdbcType
105+
&& sqlType == vectorJdbcType.sqlType;
106+
}
107+
108+
@Override
109+
public int hashCode() {
110+
return sqlType + 31 * super.hashCode();
111+
}
112+
}
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
* Copyright Red Hat Inc. and Hibernate Authors
4+
*/
5+
package org.hibernate.vector.internal;
6+
7+
import org.hibernate.boot.model.TypeContributions;
8+
import org.hibernate.boot.model.TypeContributor;
9+
import org.hibernate.dialect.Dialect;
10+
import org.hibernate.dialect.HANADialect;
11+
import org.hibernate.engine.jdbc.spi.JdbcServices;
12+
import org.hibernate.service.ServiceRegistry;
13+
import org.hibernate.type.BasicArrayType;
14+
import org.hibernate.type.BasicType;
15+
import org.hibernate.type.BasicTypeRegistry;
16+
import org.hibernate.type.SqlTypes;
17+
import org.hibernate.type.StandardBasicTypes;
18+
import org.hibernate.type.descriptor.java.spi.JavaTypeRegistry;
19+
import org.hibernate.type.descriptor.jdbc.ArrayJdbcType;
20+
import org.hibernate.type.descriptor.jdbc.spi.JdbcTypeRegistry;
21+
import org.hibernate.type.spi.TypeConfiguration;
22+
23+
public class HANAVectorTypeContributor implements TypeContributor {
24+
25+
@Override
26+
public void contribute(TypeContributions typeContributions, ServiceRegistry serviceRegistry) {
27+
final Dialect dialect = serviceRegistry.requireService( JdbcServices.class ).getDialect();
28+
if ( dialect instanceof HANADialect hanaDialect && hanaDialect.isCloud() ) {
29+
final TypeConfiguration typeConfiguration = typeContributions.getTypeConfiguration();
30+
final JavaTypeRegistry javaTypeRegistry = typeConfiguration.getJavaTypeRegistry();
31+
final JdbcTypeRegistry jdbcTypeRegistry = typeConfiguration.getJdbcTypeRegistry();
32+
final BasicTypeRegistry basicTypeRegistry = typeConfiguration.getBasicTypeRegistry();
33+
final BasicType<Float> floatBasicType = basicTypeRegistry.resolve( StandardBasicTypes.FLOAT );
34+
final ArrayJdbcType genericVectorJdbcType = new HANAVectorJdbcType(
35+
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
36+
SqlTypes.VECTOR,
37+
"real_vector"
38+
);
39+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR, genericVectorJdbcType );
40+
final ArrayJdbcType floatVectorJdbcType = new HANAVectorJdbcType(
41+
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
42+
SqlTypes.VECTOR_FLOAT32,
43+
"real_vector"
44+
);
45+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT32, floatVectorJdbcType );
46+
final ArrayJdbcType float16VectorJdbcType = new HANAVectorJdbcType(
47+
jdbcTypeRegistry.getDescriptor( SqlTypes.FLOAT ),
48+
SqlTypes.VECTOR_FLOAT16,
49+
"half_vector"
50+
);
51+
jdbcTypeRegistry.addDescriptor( SqlTypes.VECTOR_FLOAT16, float16VectorJdbcType );
52+
53+
basicTypeRegistry.register(
54+
new BasicArrayType<>(
55+
floatBasicType,
56+
genericVectorJdbcType,
57+
javaTypeRegistry.getDescriptor( float[].class )
58+
),
59+
StandardBasicTypes.VECTOR.getName()
60+
);
61+
basicTypeRegistry.register(
62+
new BasicArrayType<>(
63+
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
64+
floatVectorJdbcType,
65+
javaTypeRegistry.getDescriptor( float[].class )
66+
),
67+
StandardBasicTypes.VECTOR_FLOAT32.getName()
68+
);
69+
basicTypeRegistry.register(
70+
new BasicArrayType<>(
71+
basicTypeRegistry.resolve( StandardBasicTypes.FLOAT ),
72+
float16VectorJdbcType,
73+
javaTypeRegistry.getDescriptor( float[].class )
74+
),
75+
StandardBasicTypes.VECTOR_FLOAT16.getName()
76+
);
77+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
78+
new VectorDdlType( SqlTypes.VECTOR, "real_vector($l)", "real_vector", dialect )
79+
);
80+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
81+
new VectorDdlType( SqlTypes.VECTOR_FLOAT32, "real_vector($l)", "real_vector", dialect )
82+
);
83+
typeConfiguration.getDdlTypeRegistry().addDescriptor(
84+
new VectorDdlType( SqlTypes.VECTOR_FLOAT16, "half_vector($l)", "half_vector", dialect )
85+
);
86+
}
87+
}
88+
}

hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.FunctionContributor

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ org.hibernate.vector.internal.OracleVectorFunctionContributor
33
org.hibernate.vector.internal.MariaDBFunctionContributor
44
org.hibernate.vector.internal.MySQLFunctionContributor
55
org.hibernate.vector.internal.DB2VectorFunctionContributor
6-
org.hibernate.vector.internal.CockroachFunctionContributor
6+
org.hibernate.vector.internal.CockroachFunctionContributor
7+
org.hibernate.vector.internal.HANAVectorFunctionContributor

hibernate-vector/src/main/resources/META-INF/services/org.hibernate.boot.model.TypeContributor

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ org.hibernate.vector.internal.MariaDBTypeContributor
44
org.hibernate.vector.internal.MySQLTypeContributor
55
org.hibernate.vector.internal.DB2VectorTypeContributor
66
org.hibernate.vector.internal.CockroachTypeContributor
7+
org.hibernate.vector.internal.HANAVectorTypeContributor

0 commit comments

Comments
 (0)