Skip to content

Commit 0d891fd

Browse files
committed
Additional handling of vector types.
1 parent 9bd98b0 commit 0d891fd

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

enginetest/server_engine.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package enginetest
1616

1717
import (
1818
gosql "database/sql"
19+
"encoding/binary"
1920
"encoding/json"
2021
"errors"
2122
"fmt"
@@ -397,6 +398,11 @@ func convertValue(ctx *sql.Context, sch sql.Schema, row sql.Row) sql.Row {
397398
row[i] = r
398399
}
399400
}
401+
case query.Type_VECTOR:
402+
r := row[i].([]byte)
403+
dimensions := len(r) / 4
404+
row[i] = make([]float32, dimensions)
405+
binary.Decode(r, binary.LittleEndian, row[i])
400406
case query.Type_TIME:
401407
if row[i] != nil {
402408
r, _, err := types.TimespanType_{}.Convert(ctx, string(row[i].([]byte)))
@@ -584,7 +590,7 @@ func emptyValuePointerForType(t sql.Type) (any, error) {
584590
case query.Type_FLOAT32, query.Type_FLOAT64:
585591
var f gosql.NullFloat64
586592
return &f, nil
587-
case query.Type_JSON, query.Type_BLOB, query.Type_TIME, query.Type_GEOMETRY:
593+
case query.Type_JSON, query.Type_BLOB, query.Type_TIME, query.Type_GEOMETRY, query.Type_VECTOR:
588594
var f []byte
589595
return &f, nil
590596
case query.Type_NULL_TYPE:
@@ -677,6 +683,9 @@ func convertGoSqlType(columnType *gosql.ColumnType) (sql.Type, error) {
677683
return types.Null, nil
678684
case "geometry":
679685
return types.GeometryType{}, nil
686+
case "vector":
687+
length, _ := columnType.Length()
688+
return types.VectorType{Dimensions: int(length)}, nil
680689
default:
681690
return nil, fmt.Errorf("unhandled type %s", columnType.DatabaseTypeName())
682691
}

sql/types/vector.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ var vectorValueType = reflect.TypeOf([]float32{})
3333
// VectorType represents the VECTOR(N) type.
3434
// It stores a fixed-length array of N floating point numbers.
3535
type VectorType struct {
36+
// The number of floats in the vector.
37+
// If Dimensions is 0, then the type can hold a variable number of floats, but this is only used
38+
// as the return type of some functions, and in values sent over the wire.
3639
Dimensions int
3740
}
3841

@@ -99,15 +102,15 @@ func (t VectorType) Convert(ctx context.Context, v interface{}) (interface{}, sq
99102
}
100103
return result, sql.InRange, nil
101104
case []float32:
102-
if len(val) != t.Dimensions {
105+
if t.Dimensions != 0 && len(val) != t.Dimensions {
103106
return nil, sql.OutOfRange, fmt.Errorf("VECTOR dimension mismatch: expected %d, got %d", t.Dimensions, len(val))
104107
}
105108
return val, sql.InRange, nil
106109
case []interface{}:
107-
if len(val) != t.Dimensions {
110+
if t.Dimensions != 0 && len(val) != t.Dimensions {
108111
return nil, sql.OutOfRange, fmt.Errorf("VECTOR dimension mismatch: expected %d, got %d", t.Dimensions, len(val))
109112
}
110-
result := make([]float32, t.Dimensions)
113+
result := make([]float32, len(val))
111114
for i, elem := range val {
112115
switch e := elem.(type) {
113116
case float64:
@@ -182,7 +185,7 @@ func (t VectorType) String() string {
182185

183186
// Type implements Type interface.
184187
func (t VectorType) Type() query.Type {
185-
return sqltypes.TypeJSON
188+
return sqltypes.Vector
186189
}
187190

188191
// ValueType implements Type interface.

0 commit comments

Comments
 (0)