diff --git a/CHANGELOG.md b/CHANGELOG.md index 02126b7fe..bb0357e37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Supported `sql.Null*` from `database/sql` as query params in `toValue` func + ## v3.118.0 * Added support for nullable `Date32`, `Datetime64`, `Timestamp64`, and `Interval64` types in the `optional` parameter builder * Added method `query.WithIssuesHandler` to get query issues diff --git a/internal/bind/params.go b/internal/bind/params.go index 3b4c13de6..257066d4b 100644 --- a/internal/bind/params.go +++ b/internal/bind/params.go @@ -9,6 +9,7 @@ import ( "net/url" "reflect" "sort" + "strings" "time" "github.com/google/uuid" @@ -47,6 +48,84 @@ func asUUID(v any) (value.Value, bool) { return nil, false } +func asSQLNull(v any) (value.Value, bool) { + switch x := v.(type) { + case sql.NullBool: + return wrapWithNulls(x.Valid, value.BoolValue(x.Bool), types.Bool), true + case sql.NullFloat64: + return wrapWithNulls(x.Valid, value.DoubleValue(x.Float64), types.Double), true + case sql.NullInt16: + return wrapWithNulls(x.Valid, value.Int16Value(x.Int16), types.Int16), true + case sql.NullInt32: + return wrapWithNulls(x.Valid, value.Int32Value(x.Int32), types.Int32), true + case sql.NullInt64: + return wrapWithNulls(x.Valid, value.Int64Value(x.Int64), types.Int64), true + case sql.NullString: + return wrapWithNulls(x.Valid, value.TextValue(x.String), types.Text), true + case sql.NullTime: + return wrapWithNulls(x.Valid, value.TimestampValueFromTime(x.Time), types.Timestamp), true + } + + return asSQLNullGeneric(v) +} + +func wrapWithNulls(valid bool, val value.Value, t types.Type) value.Value { + if valid { + return value.OptionalValue(val) + } + + return value.NullValue(t) +} + +func asSQLNullGeneric(v any) (value.Value, bool) { + if v == nil { + return nil, false + } + + rv := reflect.ValueOf(v) + rt := rv.Type() + + if rv.Kind() != reflect.Struct { + return nil, false + } + + vField := rv.FieldByName("V") + validField := rv.FieldByName("Valid") + + if !vField.IsValid() || !validField.IsValid() { + return nil, false + } + + if validField.Kind() != reflect.Bool { + return nil, false + } + + if !strings.HasPrefix(rt.String(), "sql.Null[") { + return nil, false + } + + valid := validField.Bool() + if !valid { + nullType, err := toType(vField.Interface()) + if err != nil { + return value.NullValue(types.Text), true + } + + return value.NullValue(nullType), true + } + + return asSQLNullValue(vField.Interface()) +} + +func asSQLNullValue(v any) (value.Value, bool) { + val, err := toValue(v) + if err != nil { + return nil, false + } + + return value.OptionalValue(val), true +} + func toType(v any) (_ types.Type, err error) { //nolint:funlen switch x := v.(type) { case bool: @@ -163,6 +242,10 @@ func toValue(v any) (_ value.Value, err error) { return x, nil } + if nullValue, ok := asSQLNull(v); ok { + return nullValue, nil + } + if valuer, ok := v.(driver.Valuer); ok { v, err = valuer.Value() if err != nil { diff --git a/internal/bind/params_test.go b/internal/bind/params_test.go index 54351df24..a4744e2b8 100644 --- a/internal/bind/params_test.go +++ b/internal/bind/params_test.go @@ -998,3 +998,151 @@ func BenchmarkAsUUIDUsingReflect(b *testing.B) { require.Equal(b, expUUIDValue, v) } } + +func TestSQLNullTypes(t *testing.T) { + tests := []struct { + name string + src any + expected value.Value + }{ + { + name: "NullBool valid", + src: sql.NullBool{Bool: true, Valid: true}, + expected: value.OptionalValue(value.BoolValue(true)), + }, + { + name: "NullBool invalid", + src: sql.NullBool{Bool: false, Valid: false}, + expected: value.NullValue(types.Bool), + }, + { + name: "NullFloat64 valid", + src: sql.NullFloat64{Float64: 3.14, Valid: true}, + expected: value.OptionalValue(value.DoubleValue(3.14)), + }, + { + name: "NullFloat64 invalid", + src: sql.NullFloat64{Float64: 0, Valid: false}, + expected: value.NullValue(types.Double), + }, + { + name: "NullInt16 valid", + src: sql.NullInt16{Int16: 42, Valid: true}, + expected: value.OptionalValue(value.Int16Value(42)), + }, + { + name: "NullInt16 invalid", + src: sql.NullInt16{Int16: 0, Valid: false}, + expected: value.NullValue(types.Int16), + }, + { + name: "NullInt32 valid", + src: sql.NullInt32{Int32: 42, Valid: true}, + expected: value.OptionalValue(value.Int32Value(42)), + }, + { + name: "NullInt32 invalid", + src: sql.NullInt32{Int32: 0, Valid: false}, + expected: value.NullValue(types.Int32), + }, + { + name: "NullInt64 valid", + src: sql.NullInt64{Int64: 42, Valid: true}, + expected: value.OptionalValue(value.Int64Value(42)), + }, + { + name: "NullInt64 invalid", + src: sql.NullInt64{Int64: 0, Valid: false}, + expected: value.NullValue(types.Int64), + }, + { + name: "NullString valid", + src: sql.NullString{String: "hello", Valid: true}, + expected: value.OptionalValue(value.TextValue("hello")), + }, + { + name: "NullString invalid", + src: sql.NullString{String: "", Valid: false}, + expected: value.NullValue(types.Text), + }, + { + name: "NullTime valid", + src: sql.NullTime{Time: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), Valid: true}, + expected: value.OptionalValue(value.TimestampValueFromTime(time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC))), + }, + { + name: "NullTime invalid", + src: sql.NullTime{Time: time.Time{}, Valid: false}, + expected: value.NullValue(types.Timestamp), + }, + { + name: "sql.Null[string] valid", + src: sql.Null[string]{V: "hello", Valid: true}, + expected: value.OptionalValue(value.TextValue("hello")), + }, + { + name: "sql.Null[string] invalid", + src: sql.Null[string]{V: "", Valid: false}, + expected: value.NullValue(types.Text), + }, + { + name: "sql.Null[int64] valid", + src: sql.Null[int64]{V: 42, Valid: true}, + expected: value.OptionalValue(value.Int64Value(42)), + }, + { + name: "sql.Null[int64] invalid", + src: sql.Null[int64]{V: 0, Valid: false}, + expected: value.NullValue(types.Int64), + }, + { + name: "sql.Null[bool] valid", + src: sql.Null[bool]{V: true, Valid: true}, + expected: value.OptionalValue(value.BoolValue(true)), + }, + { + name: "sql.Null[bool] invalid", + src: sql.Null[bool]{V: false, Valid: false}, + expected: value.NullValue(types.Bool), + }, + { + name: "sql.Null[float64] valid", + src: sql.Null[float64]{V: 3.14, Valid: true}, + expected: value.OptionalValue(value.DoubleValue(3.14)), + }, + { + name: "sql.Null[float64] invalid", + src: sql.Null[float64]{V: 0, Valid: false}, + expected: value.NullValue(types.Double), + }, + { + name: "sql.Null[time.Time] valid", + src: sql.Null[time.Time]{V: time.Date(2024, 2, 3, 4, 5, 6, 7, time.UTC), Valid: true}, + expected: value.OptionalValue(value.TimestampValueFromTime(time.Date(2024, 2, 3, 4, 5, 6, 7, time.UTC))), + }, + { + name: "sql.Null[time.Time] invalid", + src: sql.Null[time.Time]{V: time.Time{}, Valid: false}, + expected: value.NullValue(types.Timestamp), + }, + { + name: "sql.Null[[]byte] valid", + src: sql.Null[[]byte]{V: []byte("abc"), Valid: true}, + expected: value.OptionalValue(value.BytesValue([]byte("abc"))), + }, + { + name: "sql.Null[[]byte] invalid", + src: sql.Null[[]byte]{V: nil, Valid: false}, + expected: value.NullValue(types.Bytes), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := toValue(tt.src) + require.NoError(t, err) + require.Equal(t, tt.expected.Type(), result.Type()) + require.Equal(t, tt.expected.Yql(), result.Yql()) + }) + } +}