Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Supported `sql.Null*` from `database/sql` as query params in `toValue` func

## v3.118.0
* Added support for nullable `Date32`, `Datetime64`, `Timestamp64`, and `Interval64` types in the `optional` parameter builder
* Added method `query.WithIssuesHandler` to get query issues
Expand Down
83 changes: 83 additions & 0 deletions internal/bind/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/url"
"reflect"
"sort"
"strings"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -47,6 +48,84 @@ func asUUID(v any) (value.Value, bool) {
return nil, false
}

func asSQLNull(v any) (value.Value, bool) {
switch x := v.(type) {
case sql.NullBool:
return wrapWithNulls(x.Valid, value.BoolValue(x.Bool), types.Bool), true
case sql.NullFloat64:
return wrapWithNulls(x.Valid, value.DoubleValue(x.Float64), types.Double), true
case sql.NullInt16:
return wrapWithNulls(x.Valid, value.Int16Value(x.Int16), types.Int16), true
case sql.NullInt32:
return wrapWithNulls(x.Valid, value.Int32Value(x.Int32), types.Int32), true
case sql.NullInt64:
return wrapWithNulls(x.Valid, value.Int64Value(x.Int64), types.Int64), true
case sql.NullString:
return wrapWithNulls(x.Valid, value.TextValue(x.String), types.Text), true
case sql.NullTime:
return wrapWithNulls(x.Valid, value.TimestampValueFromTime(x.Time), types.Timestamp), true
}

return asSQLNullGeneric(v)
}

func wrapWithNulls(valid bool, val value.Value, t types.Type) value.Value {
if valid {
return value.OptionalValue(val)
}

return value.NullValue(t)
}

func asSQLNullGeneric(v any) (value.Value, bool) {
if v == nil {
return nil, false
}

rv := reflect.ValueOf(v)
rt := rv.Type()

if rv.Kind() != reflect.Struct {
return nil, false
}

vField := rv.FieldByName("V")
validField := rv.FieldByName("Valid")

if !vField.IsValid() || !validField.IsValid() {
return nil, false
}

if validField.Kind() != reflect.Bool {
return nil, false
}

if !strings.HasPrefix(rt.String(), "sql.Null[") {
return nil, false
}

valid := validField.Bool()
if !valid {
nullType, err := toType(vField.Interface())
if err != nil {
return value.NullValue(types.Text), true
}

return value.NullValue(nullType), true
}

return asSQLNullValue(vField.Interface())
}

func asSQLNullValue(v any) (value.Value, bool) {
val, err := toValue(v)
if err != nil {
return nil, false
}

return value.OptionalValue(val), true
}

func toType(v any) (_ types.Type, err error) { //nolint:funlen
switch x := v.(type) {
case bool:
Expand Down Expand Up @@ -163,6 +242,10 @@ func toValue(v any) (_ value.Value, err error) {
return x, nil
}

if nullValue, ok := asSQLNull(v); ok {
return nullValue, nil
}

if valuer, ok := v.(driver.Valuer); ok {
v, err = valuer.Value()
if err != nil {
Expand Down
148 changes: 148 additions & 0 deletions internal/bind/params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -998,3 +998,151 @@ func BenchmarkAsUUIDUsingReflect(b *testing.B) {
require.Equal(b, expUUIDValue, v)
}
}

func TestSQLNullTypes(t *testing.T) {
tests := []struct {
name string
src any
expected value.Value
}{
{
name: "NullBool valid",
src: sql.NullBool{Bool: true, Valid: true},
expected: value.OptionalValue(value.BoolValue(true)),
},
{
name: "NullBool invalid",
src: sql.NullBool{Bool: false, Valid: false},
expected: value.NullValue(types.Bool),
},
{
name: "NullFloat64 valid",
src: sql.NullFloat64{Float64: 3.14, Valid: true},
expected: value.OptionalValue(value.DoubleValue(3.14)),
},
{
name: "NullFloat64 invalid",
src: sql.NullFloat64{Float64: 0, Valid: false},
expected: value.NullValue(types.Double),
},
{
name: "NullInt16 valid",
src: sql.NullInt16{Int16: 42, Valid: true},
expected: value.OptionalValue(value.Int16Value(42)),
},
{
name: "NullInt16 invalid",
src: sql.NullInt16{Int16: 0, Valid: false},
expected: value.NullValue(types.Int16),
},
{
name: "NullInt32 valid",
src: sql.NullInt32{Int32: 42, Valid: true},
expected: value.OptionalValue(value.Int32Value(42)),
},
{
name: "NullInt32 invalid",
src: sql.NullInt32{Int32: 0, Valid: false},
expected: value.NullValue(types.Int32),
},
{
name: "NullInt64 valid",
src: sql.NullInt64{Int64: 42, Valid: true},
expected: value.OptionalValue(value.Int64Value(42)),
},
{
name: "NullInt64 invalid",
src: sql.NullInt64{Int64: 0, Valid: false},
expected: value.NullValue(types.Int64),
},
{
name: "NullString valid",
src: sql.NullString{String: "hello", Valid: true},
expected: value.OptionalValue(value.TextValue("hello")),
},
{
name: "NullString invalid",
src: sql.NullString{String: "", Valid: false},
expected: value.NullValue(types.Text),
},
{
name: "NullTime valid",
src: sql.NullTime{Time: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), Valid: true},
expected: value.OptionalValue(value.TimestampValueFromTime(time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC))),
},
{
name: "NullTime invalid",
src: sql.NullTime{Time: time.Time{}, Valid: false},
expected: value.NullValue(types.Timestamp),
},
{
name: "sql.Null[string] valid",
src: sql.Null[string]{V: "hello", Valid: true},
expected: value.OptionalValue(value.TextValue("hello")),
},
{
name: "sql.Null[string] invalid",
src: sql.Null[string]{V: "", Valid: false},
expected: value.NullValue(types.Text),
},
{
name: "sql.Null[int64] valid",
src: sql.Null[int64]{V: 42, Valid: true},
expected: value.OptionalValue(value.Int64Value(42)),
},
{
name: "sql.Null[int64] invalid",
src: sql.Null[int64]{V: 0, Valid: false},
expected: value.NullValue(types.Int64),
},
{
name: "sql.Null[bool] valid",
src: sql.Null[bool]{V: true, Valid: true},
expected: value.OptionalValue(value.BoolValue(true)),
},
{
name: "sql.Null[bool] invalid",
src: sql.Null[bool]{V: false, Valid: false},
expected: value.NullValue(types.Bool),
},
{
name: "sql.Null[float64] valid",
src: sql.Null[float64]{V: 3.14, Valid: true},
expected: value.OptionalValue(value.DoubleValue(3.14)),
},
{
name: "sql.Null[float64] invalid",
src: sql.Null[float64]{V: 0, Valid: false},
expected: value.NullValue(types.Double),
},
{
name: "sql.Null[time.Time] valid",
src: sql.Null[time.Time]{V: time.Date(2024, 2, 3, 4, 5, 6, 7, time.UTC), Valid: true},
expected: value.OptionalValue(value.TimestampValueFromTime(time.Date(2024, 2, 3, 4, 5, 6, 7, time.UTC))),
},
{
name: "sql.Null[time.Time] invalid",
src: sql.Null[time.Time]{V: time.Time{}, Valid: false},
expected: value.NullValue(types.Timestamp),
},
{
name: "sql.Null[[]byte] valid",
src: sql.Null[[]byte]{V: []byte("abc"), Valid: true},
expected: value.OptionalValue(value.BytesValue([]byte("abc"))),
},
{
name: "sql.Null[[]byte] invalid",
src: sql.Null[[]byte]{V: nil, Valid: false},
expected: value.NullValue(types.Bytes),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := toValue(tt.src)
require.NoError(t, err)
require.Equal(t, tt.expected.Type(), result.Type())
require.Equal(t, tt.expected.Yql(), result.Yql())
})
}
}
Loading