Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 139 additions & 2 deletions internal/parser/enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,142 @@ import (
"fmt"
"go/ast"
"go/token"
"go/types"
"strings"

"github.com/davecgh/go-spew/spew"
"github.com/webrpc/webrpc/schema"
"golang.org/x/tools/go/packages"
)

func (p *Parser) ExtractEnumConsts(pkg *packages.Package) error {
enumMap := map[string]*schema.Type{}

// First pass: find all enum types with //gospeak:enum comment.
for _, file := range pkg.Syntax {
for _, decl := range file.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}

for _, spec := range genDecl.Specs {
typeSpec := spec.(*ast.TypeSpec)

comments := []string{}

// Look for last comment in form of //gospeak:enum
if genDecl.Doc == nil || len(genDecl.Doc.List) == 0 {
continue
}
if genDecl.Doc.List[len(genDecl.Doc.List)-1].Text != "//gospeak:enum" {
continue
}
for _, comment := range genDecl.Doc.List[:(len(genDecl.Doc.List) - 1)] {
comments = append(comments, strings.TrimSpace(strings.TrimPrefix(comment.Text, "//")))
}

enumName := typeSpec.Name.Name
enumElemType := pkg.TypesInfo.TypeOf(typeSpec.Type)
if enumElemType == nil {
continue
}

enumType := &schema.Type{
Kind: schema.TypeKind_Enum,
Name: enumName,
Type: &schema.VarType{
Expr: enumElemType.String(),
Type: schema.T_Enum,
},
Fields: []*schema.TypeField{},
Comments: comments,
}

enumImportTypeName := fmt.Sprintf("%v.%v", p.Pkg.PkgPath, enumName)

// Save for second pass
enumMap[enumImportTypeName] = enumType

// Save to schema
p.Schema.Types = append(p.Schema.Types, enumType)
p.ParsedEnumTypes[enumImportTypeName] = enumType
}
}
}

// Second pass: collect consts
for _, file := range pkg.Syntax {
for _, decl := range file.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.CONST {
continue
}

for _, spec := range genDecl.Specs {
valSpec := spec.(*ast.ValueSpec)

for _, ident := range valSpec.Names {
obj := pkg.TypesInfo.Defs[ident]
if obj == nil {
continue
}
constObj, ok := obj.(*types.Const)
if !ok {
continue
}

enumType, ok := enumMap[constObj.Type().String()]
if !ok {
continue
}

enumName := constObj.Type().String()
fmt.Println(enumName)

// Get value from trailing comment, e.g. // "some value"
var value string
if valSpec.Comment != nil {
value = strings.TrimSpace(valSpec.Comment.Text())
if strings.HasPrefix(value, `"`) && strings.HasSuffix(value, `"`) {
// Parse quoted string, handling escaped quotes
value = value[1 : len(value)-1] // Remove outer quotes
value = strings.ReplaceAll(value, `\"`, `"`) // Unescape quotes
}
}
if value == "" {
return fmt.Errorf(`Enum %v: Missing value comment, e.g. // "value"`, enumName)
}

// TODO: how can we pass a custom key value (uint8) to the webrpc enum?
enumType.Fields = append(enumType.Fields, &schema.TypeField{
Name: ident.Name,
TypeExtra: schema.TypeExtra{
Value: fmt.Sprintf("%q", value), // hmm, webrpc requires quotes for string enums values
},
})

p.Schema.Types = append(p.Schema.Types, enumType)
p.ParsedEnumTypes[fmt.Sprintf("%v.%v", p.Pkg.PkgPath, enumName)] = enumType
}
}
}
}

return nil
}

// CollectEnums collects ENUM definitions, ie.:
//
// // approved = 0
// // pending = 1
// // closed = 2
// // new = 3
// type Status gospeak.Enum[int]
//
// Deprecated: We have switche to ExtractEnumConsts instead. Left here for now to print error to users.
func (p *Parser) CollectEnums() error {

debug := spew.NewDefaultConfig()
debug.DisableMethods = true
debug.DisablePointerAddresses = true
Expand Down Expand Up @@ -103,8 +225,23 @@ func (p *Parser) CollectEnums() error {
}
}

p.Schema.Types = append(p.Schema.Types, enumType)
p.ParsedEnumTypes[fmt.Sprintf("%v.%v", p.Pkg.PkgPath, enumName)] = enumType
typeName := fmt.Sprintf("%v.%v", p.Pkg.PkgPath, enumName)

return fmt.Errorf(`Obsolete ENUM definition for type %v.

Please, migrate to this new ENUM format:

//gospeak:enum
type Status uint8

const (
StatusUnknown Status = iota // "unknown"
StatusActive // "active"
)
`, typeName)

// p.Schema.Types = append(p.Schema.Types, enumType)
// p.ParsedEnumTypes[typeName] = enumType
}
}
}
Expand Down
89 changes: 89 additions & 0 deletions internal/parser/test/enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,95 @@ import (
"github.com/webrpc/webrpc/schema"
)

func TestStructFieldEnumConst(t *testing.T) {
t.Parallel()

tt := []struct {
in string
t schema.CoreType
out []*schema.TypeField
}{
{
in: `
// Some comments
//gospeak:enum
type Status uint8

const (
StatusUnknown Status = iota // "unknown"
StatusActive // "active"
StatusInactive // "inactive"
StatusArchived // "archived"
StatusDeleted // "deleted"
)
`,
t: schema.T_String,
out: []*schema.TypeField{
// TODO: how can we pass a custom value (uint8) to the webrpc?
{Name: "StatusUnknown", TypeExtra: schema.TypeExtra{Value: "unknown"}},
{Name: "StatusActive", TypeExtra: schema.TypeExtra{Value: "active"}},
{Name: "StatusInactive", TypeExtra: schema.TypeExtra{Value: "inactive"}},
{Name: "StatusArchived", TypeExtra: schema.TypeExtra{Value: "archived"}},
{Name: "StatusDeleted", TypeExtra: schema.TypeExtra{Value: "deleted"}},
},
},
}

for _, tc := range tt {
srcCode := fmt.Sprintf(`package test

import (
"context"

//"github.com/golang-cz/gospeak/enum"
)

%s

type TestStruct struct {
Status Status
}

//go:webrpc json -out=/dev/null
type TestAPI interface{
Test(ctx context.Context) (tst *TestStruct, err error)
}
`, tc.in)

p, err := testParser(srcCode)
if err != nil {
t.Fatal(fmt.Errorf("parsing: %w", err))
}

if err := p.ExtractEnumConsts(p.Pkg); err != nil {
t.Fatalf("collecting enums: %v", err)
}

want := &schema.Type{
Kind: schema.TypeKind_Enum,
Name: "Status",
Type: &schema.VarType{
Expr: "uint8", //tc.t.String()
Type: schema.T_Enum,
},
Fields: tc.out,
Comments: []string{"Some comments"},
}

var got *schema.Type
for _, schemaType := range p.Schema.Types {
if schemaType.Name == "Status" {
got = schemaType
}
}

if !cmp.Equal(want, got) {
t.Errorf("%s\n%s\n", tc.in, coloredDiff(want, got))
}

}
}

func TestStructFieldEnum(t *testing.T) {
t.Parallel()

Expand Down
4 changes: 4 additions & 0 deletions parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ func Parse(filePath string) ([]*Target, error) {
p := parser.New(pkg)
p.Schema.SchemaName = target.InterfaceName

if err := p.ExtractEnumConsts(pkg); err != nil {
return nil, fmt.Errorf("collecting enums: %w", err)
}

if err := p.CollectEnums(); err != nil {
return nil, fmt.Errorf("collecting enums: %w", err)
}
Expand Down
Loading