Skip to content

Commit a7024e3

Browse files
committed
array arg wrapping
1 parent 2983a05 commit a7024e3

File tree

6 files changed

+47
-18
lines changed

6 files changed

+47
-18
lines changed

impl/arrays.go

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,48 @@ import (
88
"github.com/lib/pq"
99
)
1010

11-
func WrapForArray(a any) interface {
11+
type ValuerScanner interface {
1212
driver.Valuer
1313
sql.Scanner
14-
} {
14+
}
15+
16+
func WrapArray(a any) ValuerScanner {
17+
// TODO replace with own implementation
1518
return pq.Array(a)
1619
}
1720

18-
func ShouldWrapForArray(v reflect.Value) bool {
21+
func NeedsArrayWrappingForScanning(v reflect.Value) bool {
1922
t := v.Type()
2023
switch t.Kind() {
2124
case reflect.Slice:
22-
if t.Elem() == typeOfByte {
23-
return false // Byte slices are scanned as strings
24-
}
25-
return !v.Addr().Type().Implements(typeOfSQLScanner)
25+
// Byte slices are scanned as strings
26+
return t.Elem() != typeOfByte && !v.Addr().Type().Implements(typeOfSQLScanner)
2627
case reflect.Array:
2728
return !v.Addr().Type().Implements(typeOfSQLScanner)
2829
}
2930
return false
3031
}
3132

33+
func NeedsArrayWrappingForArg(arg any) bool {
34+
t := reflect.TypeOf(arg)
35+
switch t.Kind() {
36+
case reflect.Slice:
37+
// Byte slices are interpreted as strings
38+
return t.Elem() != typeOfByte && !t.Implements(typeOfDriverValuer)
39+
case reflect.Array:
40+
return !t.Implements(typeOfDriverValuer)
41+
}
42+
return false
43+
}
44+
45+
func WrapArrayArgs(args []any) {
46+
for i, arg := range args {
47+
if NeedsArrayWrappingForArg(arg) {
48+
args[i] = WrapArray(arg)
49+
}
50+
}
51+
}
52+
3253
// type ArrayScanner struct {
3354
// Dest reflect.Value
3455
// }

impl/arrays_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/domonda/go-types/nullable"
1010
)
1111

12-
func TestShouldWrapForArray(t *testing.T) {
12+
func TestNeedsArrayWrappingForScanning(t *testing.T) {
1313
tests := []struct {
1414
v reflect.Value
1515
want bool
@@ -27,8 +27,8 @@ func TestShouldWrapForArray(t *testing.T) {
2727
{v: reflect.ValueOf(new([]sql.NullString)).Elem(), want: true},
2828
}
2929
for _, tt := range tests {
30-
if got := ShouldWrapForArray(tt.v); got != tt.want {
31-
t.Errorf("shouldWrapArray() = %v, want %v", got, tt.want)
30+
if got := NeedsArrayWrappingForScanning(tt.v); got != tt.want {
31+
t.Errorf("NeedsArrayWrappingForScanning() = %v, want %v", got, tt.want)
3232
}
3333
}
3434
}

impl/foreachrow.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package impl
33
import (
44
"context"
55
"database/sql"
6+
"database/sql/driver"
67
"fmt"
78
"reflect"
89
"time"
@@ -11,12 +12,13 @@ import (
1112
)
1213

1314
var (
14-
typeOfError = reflect.TypeOf((*error)(nil)).Elem()
15-
typeOfContext = reflect.TypeOf((*context.Context)(nil)).Elem()
16-
typeOfSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
17-
typeOfTime = reflect.TypeOf(time.Time{})
18-
typeOfByte = reflect.TypeOf(byte(0))
19-
typeOfByteSlice = reflect.TypeOf((*[]byte)(nil)).Elem()
15+
typeOfError = reflect.TypeFor[error]()
16+
typeOfContext = reflect.TypeFor[context.Context]()
17+
typeOfSQLScanner = reflect.TypeFor[sql.Scanner]()
18+
typeOfDriverValuer = reflect.TypeFor[driver.Valuer]()
19+
typeOfTime = reflect.TypeFor[time.Time]()
20+
typeOfByte = reflect.TypeFor[byte]()
21+
typeOfByteSlice = reflect.TypeFor[[]byte]()
2022
)
2123

2224
// ForEachRowCallFunc will call the passed callback with scanned values or a struct for every row.

impl/reflectstruct.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ func reflectStructColumnPointers(structVal reflect.Value, namer sqldb.StructFiel
103103
// If field is a slice or array that does not implement sql.Scanner
104104
// and it's not a string scannable []byte type underneath
105105
// then wrap it with WrapForArray to make it scannable
106-
if ShouldWrapForArray(fieldValue) {
107-
pointer = WrapForArray(pointer)
106+
if NeedsArrayWrappingForScanning(fieldValue) {
107+
pointer = WrapArray(pointer)
108108
}
109109
pointers[colIndex] = pointer
110110
}

pqconn/connection.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ func (conn *connection) Now() (time.Time, error) {
111111
}
112112

113113
func (conn *connection) Exec(query string, args ...any) error {
114+
impl.WrapArrayArgs(args)
114115
_, err := conn.db.ExecContext(conn.ctx, query, args...)
115116
return wrapError(err, query, argFmt, args)
116117
}
@@ -162,6 +163,7 @@ func (conn *connection) UpsertStruct(table string, rowStruct any, ignoreColumns
162163
}
163164

164165
func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner {
166+
impl.WrapArrayArgs(args)
165167
rows, err := conn.db.QueryContext(conn.ctx, query, args...)
166168
if err != nil {
167169
err = wrapError(err, query, argFmt, args)
@@ -171,6 +173,7 @@ func (conn *connection) QueryRow(query string, args ...any) sqldb.RowScanner {
171173
}
172174

173175
func (conn *connection) QueryRows(query string, args ...any) sqldb.RowsScanner {
176+
impl.WrapArrayArgs(args)
174177
rows, err := conn.db.QueryContext(conn.ctx, query, args...)
175178
if err != nil {
176179
err = wrapError(err, query, argFmt, args)

pqconn/transaction.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ func (conn *transaction) Now() (time.Time, error) {
6868
}
6969

7070
func (conn *transaction) Exec(query string, args ...any) error {
71+
impl.WrapArrayArgs(args)
7172
_, err := conn.tx.Exec(query, args...)
7273
return impl.WrapNonNilErrorWithQuery(err, query, argFmt, args)
7374
}
@@ -117,6 +118,7 @@ func (conn *transaction) InsertStructs(table string, rowStructs any, ignoreColum
117118
}
118119

119120
func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner {
121+
impl.WrapArrayArgs(args)
120122
rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...)
121123
if err != nil {
122124
err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args)
@@ -126,6 +128,7 @@ func (conn *transaction) QueryRow(query string, args ...any) sqldb.RowScanner {
126128
}
127129

128130
func (conn *transaction) QueryRows(query string, args ...any) sqldb.RowsScanner {
131+
impl.WrapArrayArgs(args)
129132
rows, err := conn.tx.QueryContext(conn.parent.ctx, query, args...)
130133
if err != nil {
131134
err = impl.WrapNonNilErrorWithQuery(err, query, argFmt, args)

0 commit comments

Comments
 (0)