Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ import (
"database/sql"
"database/sql/driver"
"errors"
"github.com/lazada/sqle/embed"
"reflect"
"sync"
"time"
"unsafe"

"github.com/lazada/sqle/embed"
)

type ctorFunc func(unsafe.Pointer) unsafe.Pointer
Expand Down Expand Up @@ -149,7 +148,7 @@ func (m *Mapper) inspect(parent *structMap, offset uintptr, typ reflect.Type) *s
if fieldtyp = field.Type; fieldtyp.Kind() == reflect.Ptr {
fieldtyp, isptr = fieldtyp.Elem(), ptrMask
}
if fieldtyp.Kind() == reflect.Struct && !scannable(fieldtyp) {
if fieldtyp.Kind() == reflect.Struct && !scannable(fieldtyp) && !scannable(reflect.PointerTo(fieldtyp)) {
if s = m.inspect(smap, field.Offset|isptr, fieldtyp); s != nil {
smap.aliases = append(smap.aliases, s.aliases...)
smap.fields = append(smap.fields, s.fields...)
Expand Down
54 changes: 54 additions & 0 deletions mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,57 @@
// limitations under the License.

package sqle

import (
"database/sql"
"reflect"
"testing"
"time"
)

type testStruct struct {
ID int
Name string
CreatedAt time.Time
UpdatedAt *time.Time
Price *sql.NullFloat64
Price2 sql.NullFloat64
}

func TestInspect(t *testing.T) {
mapper := NewMapper("sql", nil)
typ := reflect.TypeOf(&testStruct{}).Elem()

smap := mapper.inspect(nil, 0, typ)

expectedAliases := []string{"ID", "Name", "CreatedAt", "UpdatedAt", "Price", "Price2"}
if !reflect.DeepEqual(smap.aliases, expectedAliases) {
t.Errorf("Expected aliases %v, but got %v", expectedAliases, smap.aliases)
}

expectedFieldsCount := len(expectedAliases)
if len(smap.fields) != expectedFieldsCount {
t.Errorf("Expected %d fields, but got %d", expectedFieldsCount, len(smap.fields))
}

for i, field := range smap.fields {
expectedOffset := typ.Field(i).Offset
if field.offset != expectedOffset {
t.Errorf("Expected offset %d for field %s, but got %d", expectedOffset, expectedAliases[i], field.offset)
}
}

expectedTypes := []reflect.Type{
reflect.TypeOf(int(0)),
reflect.TypeOf(""),
reflect.TypeOf(time.Time{}),
reflect.TypeOf(&time.Time{}),
reflect.TypeOf(&sql.NullFloat64{}),
reflect.TypeOf(sql.NullFloat64{}),
}
for i, field := range smap.fields {
if field.typ != expectedTypes[i] {
t.Errorf("Expected type %v for field %s, but got %v", expectedTypes[i], expectedAliases[i], field.typ)
}
}
}