Skip to content

Commit 58504af

Browse files
authored
Correct sql.Null* support (#1903)
1 parent 1063504 commit 58504af

File tree

3 files changed

+233
-0
lines changed

3 files changed

+233
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Supported `sql.Null*` from `database/sql` as query params in `toValue` func
2+
13
## v3.118.0
24
* Added support for nullable `Date32`, `Datetime64`, `Timestamp64`, and `Interval64` types in the `optional` parameter builder
35
* Added method `query.WithIssuesHandler` to get query issues

internal/bind/params.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"net/url"
1010
"reflect"
1111
"sort"
12+
"strings"
1213
"time"
1314

1415
"github.com/google/uuid"
@@ -47,6 +48,84 @@ func asUUID(v any) (value.Value, bool) {
4748
return nil, false
4849
}
4950

51+
func asSQLNull(v any) (value.Value, bool) {
52+
switch x := v.(type) {
53+
case sql.NullBool:
54+
return wrapWithNulls(x.Valid, value.BoolValue(x.Bool), types.Bool), true
55+
case sql.NullFloat64:
56+
return wrapWithNulls(x.Valid, value.DoubleValue(x.Float64), types.Double), true
57+
case sql.NullInt16:
58+
return wrapWithNulls(x.Valid, value.Int16Value(x.Int16), types.Int16), true
59+
case sql.NullInt32:
60+
return wrapWithNulls(x.Valid, value.Int32Value(x.Int32), types.Int32), true
61+
case sql.NullInt64:
62+
return wrapWithNulls(x.Valid, value.Int64Value(x.Int64), types.Int64), true
63+
case sql.NullString:
64+
return wrapWithNulls(x.Valid, value.TextValue(x.String), types.Text), true
65+
case sql.NullTime:
66+
return wrapWithNulls(x.Valid, value.TimestampValueFromTime(x.Time), types.Timestamp), true
67+
}
68+
69+
return asSQLNullGeneric(v)
70+
}
71+
72+
func wrapWithNulls(valid bool, val value.Value, t types.Type) value.Value {
73+
if valid {
74+
return value.OptionalValue(val)
75+
}
76+
77+
return value.NullValue(t)
78+
}
79+
80+
func asSQLNullGeneric(v any) (value.Value, bool) {
81+
if v == nil {
82+
return nil, false
83+
}
84+
85+
rv := reflect.ValueOf(v)
86+
rt := rv.Type()
87+
88+
if rv.Kind() != reflect.Struct {
89+
return nil, false
90+
}
91+
92+
vField := rv.FieldByName("V")
93+
validField := rv.FieldByName("Valid")
94+
95+
if !vField.IsValid() || !validField.IsValid() {
96+
return nil, false
97+
}
98+
99+
if validField.Kind() != reflect.Bool {
100+
return nil, false
101+
}
102+
103+
if !strings.HasPrefix(rt.String(), "sql.Null[") {
104+
return nil, false
105+
}
106+
107+
valid := validField.Bool()
108+
if !valid {
109+
nullType, err := toType(vField.Interface())
110+
if err != nil {
111+
return value.NullValue(types.Text), true
112+
}
113+
114+
return value.NullValue(nullType), true
115+
}
116+
117+
return asSQLNullValue(vField.Interface())
118+
}
119+
120+
func asSQLNullValue(v any) (value.Value, bool) {
121+
val, err := toValue(v)
122+
if err != nil {
123+
return nil, false
124+
}
125+
126+
return value.OptionalValue(val), true
127+
}
128+
50129
func toType(v any) (_ types.Type, err error) { //nolint:funlen
51130
switch x := v.(type) {
52131
case bool:
@@ -163,6 +242,10 @@ func toValue(v any) (_ value.Value, err error) {
163242
return x, nil
164243
}
165244

245+
if nullValue, ok := asSQLNull(v); ok {
246+
return nullValue, nil
247+
}
248+
166249
if valuer, ok := v.(driver.Valuer); ok {
167250
v, err = valuer.Value()
168251
if err != nil {

internal/bind/params_test.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,3 +998,151 @@ func BenchmarkAsUUIDUsingReflect(b *testing.B) {
998998
require.Equal(b, expUUIDValue, v)
999999
}
10001000
}
1001+
1002+
func TestSQLNullTypes(t *testing.T) {
1003+
tests := []struct {
1004+
name string
1005+
src any
1006+
expected value.Value
1007+
}{
1008+
{
1009+
name: "NullBool valid",
1010+
src: sql.NullBool{Bool: true, Valid: true},
1011+
expected: value.OptionalValue(value.BoolValue(true)),
1012+
},
1013+
{
1014+
name: "NullBool invalid",
1015+
src: sql.NullBool{Bool: false, Valid: false},
1016+
expected: value.NullValue(types.Bool),
1017+
},
1018+
{
1019+
name: "NullFloat64 valid",
1020+
src: sql.NullFloat64{Float64: 3.14, Valid: true},
1021+
expected: value.OptionalValue(value.DoubleValue(3.14)),
1022+
},
1023+
{
1024+
name: "NullFloat64 invalid",
1025+
src: sql.NullFloat64{Float64: 0, Valid: false},
1026+
expected: value.NullValue(types.Double),
1027+
},
1028+
{
1029+
name: "NullInt16 valid",
1030+
src: sql.NullInt16{Int16: 42, Valid: true},
1031+
expected: value.OptionalValue(value.Int16Value(42)),
1032+
},
1033+
{
1034+
name: "NullInt16 invalid",
1035+
src: sql.NullInt16{Int16: 0, Valid: false},
1036+
expected: value.NullValue(types.Int16),
1037+
},
1038+
{
1039+
name: "NullInt32 valid",
1040+
src: sql.NullInt32{Int32: 42, Valid: true},
1041+
expected: value.OptionalValue(value.Int32Value(42)),
1042+
},
1043+
{
1044+
name: "NullInt32 invalid",
1045+
src: sql.NullInt32{Int32: 0, Valid: false},
1046+
expected: value.NullValue(types.Int32),
1047+
},
1048+
{
1049+
name: "NullInt64 valid",
1050+
src: sql.NullInt64{Int64: 42, Valid: true},
1051+
expected: value.OptionalValue(value.Int64Value(42)),
1052+
},
1053+
{
1054+
name: "NullInt64 invalid",
1055+
src: sql.NullInt64{Int64: 0, Valid: false},
1056+
expected: value.NullValue(types.Int64),
1057+
},
1058+
{
1059+
name: "NullString valid",
1060+
src: sql.NullString{String: "hello", Valid: true},
1061+
expected: value.OptionalValue(value.TextValue("hello")),
1062+
},
1063+
{
1064+
name: "NullString invalid",
1065+
src: sql.NullString{String: "", Valid: false},
1066+
expected: value.NullValue(types.Text),
1067+
},
1068+
{
1069+
name: "NullTime valid",
1070+
src: sql.NullTime{Time: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), Valid: true},
1071+
expected: value.OptionalValue(value.TimestampValueFromTime(time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC))),
1072+
},
1073+
{
1074+
name: "NullTime invalid",
1075+
src: sql.NullTime{Time: time.Time{}, Valid: false},
1076+
expected: value.NullValue(types.Timestamp),
1077+
},
1078+
{
1079+
name: "sql.Null[string] valid",
1080+
src: sql.Null[string]{V: "hello", Valid: true},
1081+
expected: value.OptionalValue(value.TextValue("hello")),
1082+
},
1083+
{
1084+
name: "sql.Null[string] invalid",
1085+
src: sql.Null[string]{V: "", Valid: false},
1086+
expected: value.NullValue(types.Text),
1087+
},
1088+
{
1089+
name: "sql.Null[int64] valid",
1090+
src: sql.Null[int64]{V: 42, Valid: true},
1091+
expected: value.OptionalValue(value.Int64Value(42)),
1092+
},
1093+
{
1094+
name: "sql.Null[int64] invalid",
1095+
src: sql.Null[int64]{V: 0, Valid: false},
1096+
expected: value.NullValue(types.Int64),
1097+
},
1098+
{
1099+
name: "sql.Null[bool] valid",
1100+
src: sql.Null[bool]{V: true, Valid: true},
1101+
expected: value.OptionalValue(value.BoolValue(true)),
1102+
},
1103+
{
1104+
name: "sql.Null[bool] invalid",
1105+
src: sql.Null[bool]{V: false, Valid: false},
1106+
expected: value.NullValue(types.Bool),
1107+
},
1108+
{
1109+
name: "sql.Null[float64] valid",
1110+
src: sql.Null[float64]{V: 3.14, Valid: true},
1111+
expected: value.OptionalValue(value.DoubleValue(3.14)),
1112+
},
1113+
{
1114+
name: "sql.Null[float64] invalid",
1115+
src: sql.Null[float64]{V: 0, Valid: false},
1116+
expected: value.NullValue(types.Double),
1117+
},
1118+
{
1119+
name: "sql.Null[time.Time] valid",
1120+
src: sql.Null[time.Time]{V: time.Date(2024, 2, 3, 4, 5, 6, 7, time.UTC), Valid: true},
1121+
expected: value.OptionalValue(value.TimestampValueFromTime(time.Date(2024, 2, 3, 4, 5, 6, 7, time.UTC))),
1122+
},
1123+
{
1124+
name: "sql.Null[time.Time] invalid",
1125+
src: sql.Null[time.Time]{V: time.Time{}, Valid: false},
1126+
expected: value.NullValue(types.Timestamp),
1127+
},
1128+
{
1129+
name: "sql.Null[[]byte] valid",
1130+
src: sql.Null[[]byte]{V: []byte("abc"), Valid: true},
1131+
expected: value.OptionalValue(value.BytesValue([]byte("abc"))),
1132+
},
1133+
{
1134+
name: "sql.Null[[]byte] invalid",
1135+
src: sql.Null[[]byte]{V: nil, Valid: false},
1136+
expected: value.NullValue(types.Bytes),
1137+
},
1138+
}
1139+
1140+
for _, tt := range tests {
1141+
t.Run(tt.name, func(t *testing.T) {
1142+
result, err := toValue(tt.src)
1143+
require.NoError(t, err)
1144+
require.Equal(t, tt.expected.Type(), result.Type())
1145+
require.Equal(t, tt.expected.Yql(), result.Yql())
1146+
})
1147+
}
1148+
}

0 commit comments

Comments
 (0)