Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Supported `sql.Null*` from `database/sql` as query params in `toValue` func

## v3.117.1
* Fixed scan a column of type `Decimal(precision,scale)` into a struct field of type `types.Decimal{}` using `ScanStruct()`
* Fixed race in integration test `TestTopicWriterLogMessagesWithoutData`
Expand Down
102 changes: 102 additions & 0 deletions internal/bind/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"net/url"
"reflect"
"sort"
"strings"
"time"

"github.com/google/uuid"
Expand Down Expand Up @@ -47,6 +48,103 @@
return nil, false
}

func asSQLNull(v any) (value.Value, bool) {
switch x := v.(type) {
case sql.NullBool:
if x.Valid {
return value.OptionalValue(value.BoolValue(x.Bool)), true
}

Check failure on line 57 in internal/bind/params.go

View workflow job for this annotation

GitHub Actions / golangci-lint

File is not properly formatted (gci)
return value.NullValue(types.Bool), true
case sql.NullFloat64:
if x.Valid {
return value.OptionalValue(value.DoubleValue(x.Float64)), true
}

return value.NullValue(types.Double), true
case sql.NullInt16:
if x.Valid {
return value.OptionalValue(value.Int16Value(x.Int16)), true
}

return value.NullValue(types.Int16), true
case sql.NullInt32:
if x.Valid {
return value.OptionalValue(value.Int32Value(x.Int32)), true
}

return value.NullValue(types.Int32), true
case sql.NullInt64:
if x.Valid {
return value.OptionalValue(value.Int64Value(x.Int64)), true
}

return value.NullValue(types.Int64), true
case sql.NullString:
if x.Valid {
return value.OptionalValue(value.TextValue(x.String)), true
}

return value.NullValue(types.Text), true
case sql.NullTime:
if x.Valid {
return value.OptionalValue(value.TimestampValueFromTime(x.Time)), true
}

return value.NullValue(types.Timestamp), true
}

return asSQLNullGeneric(v)
}

func asSQLNullGeneric(v any) (value.Value, bool) {
if v == nil {
return nil, false
}

rv := reflect.ValueOf(v)
rt := rv.Type()

if rv.Kind() != reflect.Struct {
return nil, false
}

vField := rv.FieldByName("V")
validField := rv.FieldByName("Valid")

if !vField.IsValid() || !validField.IsValid() {
return nil, false
}

if validField.Kind() != reflect.Bool {
return nil, false
}

if !strings.HasPrefix(rt.String(), "sql.Null[") {
return nil, false
}

valid := validField.Bool()
if !valid {
nullType, err := toType(vField.Interface())
if err != nil {
return value.NullValue(types.Text), true
}
return value.NullValue(nullType), true

Check failure on line 133 in internal/bind/params.go

View workflow job for this annotation

GitHub Actions / golangci-lint

return with no blank line before (nlreturn)
}

return asSQLNullValue(vField.Interface())
}

func asSQLNullValue(v any) (value.Value, bool) {
val, err := toValue(v)
if err != nil {
return nil, false
}

return value.OptionalValue(val), true
}

func toType(v any) (_ types.Type, err error) { //nolint:funlen
switch x := v.(type) {
case bool:
Expand Down Expand Up @@ -163,6 +261,10 @@
return x, nil
}

if nullValue, ok := asSQLNull(v); ok {
return nullValue, nil
}

if valuer, ok := v.(driver.Valuer); ok {
v, err = valuer.Value()
if err != nil {
Expand Down
108 changes: 108 additions & 0 deletions internal/bind/params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -998,3 +998,111 @@ func BenchmarkAsUUIDUsingReflect(b *testing.B) {
require.Equal(b, expUUIDValue, v)
}
}

func TestSQLNullTypes(t *testing.T) {
tests := []struct {
name string
src any
expected value.Value
}{
{
name: "NullBool valid",
src: sql.NullBool{Bool: true, Valid: true},
expected: value.OptionalValue(value.BoolValue(true)),
},
{
name: "NullBool invalid",
src: sql.NullBool{Bool: false, Valid: false},
expected: value.NullValue(types.Bool),
},
{
name: "NullFloat64 valid",
src: sql.NullFloat64{Float64: 3.14, Valid: true},
expected: value.OptionalValue(value.DoubleValue(3.14)),
},
{
name: "NullFloat64 invalid",
src: sql.NullFloat64{Float64: 0, Valid: false},
expected: value.NullValue(types.Double),
},
{
name: "NullInt16 valid",
src: sql.NullInt16{Int16: 42, Valid: true},
expected: value.OptionalValue(value.Int16Value(42)),
},
{
name: "NullInt16 invalid",
src: sql.NullInt16{Int16: 0, Valid: false},
expected: value.NullValue(types.Int16),
},
{
name: "NullInt32 valid",
src: sql.NullInt32{Int32: 42, Valid: true},
expected: value.OptionalValue(value.Int32Value(42)),
},
{
name: "NullInt32 invalid",
src: sql.NullInt32{Int32: 0, Valid: false},
expected: value.NullValue(types.Int32),
},
{
name: "NullInt64 valid",
src: sql.NullInt64{Int64: 42, Valid: true},
expected: value.OptionalValue(value.Int64Value(42)),
},
{
name: "NullInt64 invalid",
src: sql.NullInt64{Int64: 0, Valid: false},
expected: value.NullValue(types.Int64),
},
{
name: "NullString valid",
src: sql.NullString{String: "hello", Valid: true},
expected: value.OptionalValue(value.TextValue("hello")),
},
{
name: "NullString invalid",
src: sql.NullString{String: "", Valid: false},
expected: value.NullValue(types.Text),
},
{
name: "NullTime valid",
src: sql.NullTime{Time: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), Valid: true},
expected: value.OptionalValue(value.TimestampValueFromTime(time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC))),
},
{
name: "NullTime invalid",
src: sql.NullTime{Time: time.Time{}, Valid: false},
expected: value.NullValue(types.Timestamp),
},
{
name: "sql.Null[string] valid",
src: sql.Null[string]{V: "hello", Valid: true},
expected: value.OptionalValue(value.TextValue("hello")),
},
{
name: "sql.Null[string] invalid",
src: sql.Null[string]{V: "", Valid: false},
expected: value.NullValue(types.Text),
},
{
name: "sql.Null[int64] valid",
src: sql.Null[int64]{V: 42, Valid: true},
expected: value.OptionalValue(value.Int64Value(42)),
},
{
name: "sql.Null[int64] invalid",
src: sql.Null[int64]{V: 0, Valid: false},
expected: value.NullValue(types.Int64),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := toValue(tt.src)
require.NoError(t, err)
require.Equal(t, tt.expected.Type(), result.Type())
require.Equal(t, tt.expected.Yql(), result.Yql())
})
}
}
Loading