diff --git a/mapper.go b/mapper.go index cf06103..12d3d02 100644 --- a/mapper.go +++ b/mapper.go @@ -18,12 +18,11 @@ import ( "database/sql" "database/sql/driver" "errors" + "github.com/lazada/sqle/embed" "reflect" "sync" "time" "unsafe" - - "github.com/lazada/sqle/embed" ) type ctorFunc func(unsafe.Pointer) unsafe.Pointer @@ -149,7 +148,7 @@ func (m *Mapper) inspect(parent *structMap, offset uintptr, typ reflect.Type) *s if fieldtyp = field.Type; fieldtyp.Kind() == reflect.Ptr { fieldtyp, isptr = fieldtyp.Elem(), ptrMask } - if fieldtyp.Kind() == reflect.Struct && !scannable(fieldtyp) { + if fieldtyp.Kind() == reflect.Struct && !scannable(fieldtyp) && !scannable(reflect.PointerTo(fieldtyp)) { if s = m.inspect(smap, field.Offset|isptr, fieldtyp); s != nil { smap.aliases = append(smap.aliases, s.aliases...) smap.fields = append(smap.fields, s.fields...) diff --git a/mapper_test.go b/mapper_test.go index 92257bf..428dbd8 100644 --- a/mapper_test.go +++ b/mapper_test.go @@ -13,3 +13,57 @@ // limitations under the License. package sqle + +import ( + "database/sql" + "reflect" + "testing" + "time" +) + +type testStruct struct { + ID int + Name string + CreatedAt time.Time + UpdatedAt *time.Time + Price *sql.NullFloat64 + Price2 sql.NullFloat64 +} + +func TestInspect(t *testing.T) { + mapper := NewMapper("sql", nil) + typ := reflect.TypeOf(&testStruct{}).Elem() + + smap := mapper.inspect(nil, 0, typ) + + expectedAliases := []string{"ID", "Name", "CreatedAt", "UpdatedAt", "Price", "Price2"} + if !reflect.DeepEqual(smap.aliases, expectedAliases) { + t.Errorf("Expected aliases %v, but got %v", expectedAliases, smap.aliases) + } + + expectedFieldsCount := len(expectedAliases) + if len(smap.fields) != expectedFieldsCount { + t.Errorf("Expected %d fields, but got %d", expectedFieldsCount, len(smap.fields)) + } + + for i, field := range smap.fields { + expectedOffset := typ.Field(i).Offset + if field.offset != expectedOffset { + t.Errorf("Expected offset %d for field %s, but got %d", expectedOffset, expectedAliases[i], field.offset) + } + } + + expectedTypes := []reflect.Type{ + reflect.TypeOf(int(0)), + reflect.TypeOf(""), + reflect.TypeOf(time.Time{}), + reflect.TypeOf(&time.Time{}), + reflect.TypeOf(&sql.NullFloat64{}), + reflect.TypeOf(sql.NullFloat64{}), + } + for i, field := range smap.fields { + if field.typ != expectedTypes[i] { + t.Errorf("Expected type %v for field %s, but got %v", expectedTypes[i], expectedAliases[i], field.typ) + } + } +}