Skip to content

Commit b30a1ec

Browse files
committed
feat(injector/typed): add interface-based type system with parser
- Introduce Type interface for polymorphic type handling - Add support for pointer, slice, array, and map types - Implement recursive descent parser for type expressions - Rename TypeName to NamedType for clarity - Add helper functions NewNamedType and MustNamedType - Update all usage sites to use new type system
1 parent d575c0b commit b30a1ec

25 files changed

+2626
-572
lines changed

internal/injector/aspect/advice/call.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ import (
2121
)
2222

2323
type appendArgs struct {
24-
TypeName typed.TypeName
24+
TypeName *typed.NamedType
2525
Templates []*code.Template
2626
}
2727

2828
// AppendArgs appends arguments of a given type to the end of a function call. All arguments must be
2929
// of the same type, as they may be appended at the tail end of a variadic call.
30-
func AppendArgs(typeName typed.TypeName, templates ...*code.Template) *appendArgs {
30+
func AppendArgs(typeName *typed.NamedType, templates ...*code.Template) *appendArgs {
3131
return &appendArgs{typeName, templates}
3232
}
3333

@@ -168,12 +168,12 @@ func init() {
168168
return nil, err
169169
}
170170

171-
tn, err := typed.NewTypeName(args.TypeName)
171+
namedType, err := typed.NewNamedType(args.TypeName)
172172
if err != nil {
173-
return nil, err
173+
return nil, fmt.Errorf("invalid type %q: %w", args.TypeName, err)
174174
}
175175

176-
return AppendArgs(tn, args.Values...), nil
176+
return AppendArgs(namedType, args.Values...), nil
177177
}
178178
unmarshalers["replace-function"] = func(ctx gocontext.Context, node ast.Node) (Advice, error) {
179179
var (

internal/injector/aspect/advice/call_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
func TestAppendArgs(t *testing.T) {
2020
t.Run("AddedImports", func(t *testing.T) {
2121
type testCase struct {
22-
argType typed.TypeName
22+
argType *typed.NamedType
2323
args []*code.Template
2424
expectedImports []string
2525
}
@@ -30,7 +30,7 @@ func TestAppendArgs(t *testing.T) {
3030
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
3131
},
3232
"imports-from-arg-type": {
33-
argType: typed.MustTypeName("*net/http.Request"),
33+
argType: typed.MustNamedType("*net/http.Request"),
3434
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
3535
expectedImports: []string{"net/http"},
3636
},

internal/injector/aspect/advice/code/dot_function.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func (s signature) ResultThatImplements(name string) (string, error) {
159159
}
160160

161161
// Optimization: First, check for an exact match using the helper.
162-
if index, found := typed.FindMatchingTypeName(s.Results, name); found {
162+
if index, found := typed.FindMatchingType(s.Results, name); found {
163163
return fieldAt(s.Results, index, "result")
164164
} // If not found, fall through to type resolution.
165165

@@ -195,7 +195,7 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
195195

196196
// Optimization: First, check for an exact match using TypeName parsing, finding the last one.
197197
lastMatchIndex := -1
198-
if tn, err := typed.NewTypeName(name); err == nil {
198+
if tn, err := typed.NewType(name); err == nil {
199199
currentIndex := 0
200200
for _, field := range s.Results.List {
201201
if tn.Matches(field.Type) {
@@ -285,7 +285,7 @@ func fieldAt(fields *dst.FieldList, index int, use string) (string, error) {
285285
}
286286

287287
func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, error) {
288-
tn, err := typed.NewTypeName(typeName)
288+
t, err := typed.NewType(typeName)
289289
if err != nil {
290290
return "", err
291291
}
@@ -297,7 +297,7 @@ func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, er
297297

298298
index := 0
299299
for _, field := range fields.List {
300-
if tn.Matches(field.Type) {
300+
if t.Matches(field.Type) {
301301
return fieldAt(fields, index, use)
302302
}
303303

@@ -321,9 +321,9 @@ func (s signature) FinalResultImplements(interfaceName string) (bool, error) {
321321
lastField := s.Results.List[len(s.Results.List)-1]
322322

323323
// Optimization: First, check for an exact match using TypeName parsing.
324-
// Note: Not using FindMatchingTypeName as we only need to check the last field.
325-
if tn, err := typed.NewTypeName(interfaceName); err == nil {
326-
if tn.Matches(lastField.Type) {
324+
// Note: Not using FindMatchingType as we only need to check the last field.
325+
if t, err := typed.NewType(interfaceName); err == nil {
326+
if t.Matches(lastField.Type) {
327327
return true, nil
328328
}
329329
} // If parsing failed or no match, fall through to type resolution.

internal/injector/aspect/advice/struct.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ import (
1919

2020
type addStructField struct {
2121
Name string
22-
TypeName typed.TypeName
22+
TypeName *typed.NamedType
2323
}
2424

2525
// AddStructField adds a new synthetic field at the tail end of a struct declaration.
26-
func AddStructField(fieldName string, fieldType typed.TypeName) *addStructField {
26+
func AddStructField(fieldName string, fieldType *typed.NamedType) *addStructField {
2727
return &addStructField{fieldName, fieldType}
2828
}
2929

@@ -76,11 +76,11 @@ func init() {
7676
if err := yaml.NodeToValueContext(ctx, node, &spec); err != nil {
7777
return nil, err
7878
}
79-
tn, err := typed.NewTypeName(spec.Type)
79+
namedType, err := typed.NewNamedType(spec.Type)
8080
if err != nil {
81-
return nil, err
81+
return nil, fmt.Errorf("invalid type %q: %w", spec.Type, err)
8282
}
8383

84-
return AddStructField(spec.Name, tn), nil
84+
return AddStructField(spec.Name, namedType), nil
8585
}
8686
}

internal/injector/aspect/join/declaration.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ func (i *declarationOf) Hash(h *fingerprint.Hasher) error {
7373
}
7474

7575
type valueDeclaration struct {
76-
TypeName typed.TypeName
76+
TypeName *typed.NamedType
7777
}
7878

79-
func ValueDeclaration(typeName typed.TypeName) *valueDeclaration {
79+
func ValueDeclaration(typeName *typed.NamedType) *valueDeclaration {
8080
return &valueDeclaration{typeName}
8181
}
8282

@@ -141,9 +141,9 @@ func init() {
141141
return nil, err
142142
}
143143

144-
tn, err := typed.NewTypeName(typeName)
144+
tn, err := typed.NewNamedType(typeName)
145145
if err != nil {
146-
return nil, err
146+
return nil, fmt.Errorf("invalid type %q: %w", typeName, err)
147147
}
148148

149149
return ValueDeclaration(tn), nil

internal/injector/aspect/join/function.go

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,13 @@ func (fo functionName) Hash(h *fingerprint.Hasher) error {
145145
}
146146

147147
type signature struct {
148-
Arguments []typed.TypeName
149-
Results []typed.TypeName
148+
Arguments []*typed.NamedType
149+
Results []*typed.NamedType
150150
}
151151

152152
// Signature matches function declarations based on their arguments and return
153153
// value types.
154-
func Signature(args []typed.TypeName, ret []typed.TypeName) FunctionOption {
154+
func Signature(args []*typed.NamedType, ret []*typed.NamedType) FunctionOption {
155155
return &signature{Arguments: args, Results: ret}
156156
}
157157

@@ -225,8 +225,8 @@ func (fo *signature) evaluate(info functionInformation) bool {
225225
func (fo *signature) Hash(h *fingerprint.Hasher) error {
226226
return h.Named(
227227
"signature",
228-
fingerprint.List[typed.TypeName](fo.Arguments),
229-
fingerprint.List[typed.TypeName](fo.Results),
228+
fingerprint.List[*typed.NamedType](fo.Arguments),
229+
fingerprint.List[*typed.NamedType](fo.Results),
230230
)
231231
}
232232

@@ -236,15 +236,15 @@ type signatureContains struct {
236236

237237
// SignatureContains matches function declarations based on their arguments and
238238
// return value types in any order and does not require all arguments or return values to be present.
239-
func SignatureContains(args []typed.TypeName, ret []typed.TypeName) FunctionOption {
239+
func SignatureContains(args []*typed.NamedType, ret []*typed.NamedType) FunctionOption {
240240
return &signatureContains{signature{Arguments: args, Results: ret}}
241241
}
242242

243243
func (fo *signatureContains) Hash(h *fingerprint.Hasher) error {
244244
return h.Named(
245245
"signature-contains",
246-
fingerprint.List[typed.TypeName](fo.Arguments),
247-
fingerprint.List[typed.TypeName](fo.Results),
246+
fingerprint.List[*typed.NamedType](fo.Arguments),
247+
fingerprint.List[*typed.NamedType](fo.Results),
248248
)
249249
}
250250

@@ -262,7 +262,7 @@ func (fo *signatureContains) evaluate(info functionInformation) bool {
262262

263263
// containsAnyType checks if any of the expected types match any of the actual types in the field list.
264264
// Returns false if either slice is empty or nil.
265-
func containsAnyType(expectedTypes []typed.TypeName, fieldList *dst.FieldList) bool {
265+
func containsAnyType(expectedTypes []*typed.NamedType, fieldList *dst.FieldList) bool {
266266
// Quick return if either side is empty.
267267
if len(expectedTypes) == 0 || fieldList == nil || len(fieldList.List) == 0 {
268268
return false
@@ -281,10 +281,10 @@ func containsAnyType(expectedTypes []typed.TypeName, fieldList *dst.FieldList) b
281281
}
282282

283283
type receiver struct {
284-
TypeName typed.TypeName
284+
TypeName *typed.NamedType
285285
}
286286

287-
func Receiver(typeName typed.TypeName) FunctionOption {
287+
func Receiver(typeName *typed.NamedType) FunctionOption {
288288
return &receiver{typeName}
289289
}
290290

@@ -398,7 +398,7 @@ func (fo *resultImplements) evaluate(info functionInformation) bool {
398398
}
399399

400400
// Optimization: First, check for an exact match using the helper.
401-
if _, found := typed.FindMatchingTypeName(info.Type.Results, fo.InterfaceName); found {
401+
if _, found := typed.FindMatchingType(info.Type.Results, fo.InterfaceName); found {
402402
return true // Found direct match
403403
} // If not found, fall through to type resolution.
404404

@@ -467,10 +467,10 @@ func (fo *finalResultImplements) evaluate(info functionInformation) bool {
467467
return false
468468
}
469469

470-
// Optimization: First, check for an exact match using TypeName parsing.
471-
if tn, err := typed.NewTypeName(fo.InterfaceName); err == nil {
470+
// Optimization: First, check for an exact match using Type parsing.
471+
if t, err := typed.NewType(fo.InterfaceName); err == nil {
472472
lastField := info.Type.Results.List[len(info.Type.Results.List)-1]
473-
if tn.Matches(lastField.Type) {
473+
if t.Matches(lastField.Type) {
474474
return true // Found direct match
475475
}
476476
} // If parsing failed or no match, fall through to type resolution.
@@ -549,9 +549,9 @@ func (o *unmarshalFuncDeclOption) UnmarshalYAML(ctx gocontext.Context, node ast.
549549
if err := yaml.NodeToValueContext(ctx, mapping.Values[0].Value, &arg); err != nil {
550550
return err
551551
}
552-
tn, err := typed.NewTypeName(arg)
552+
tn, err := typed.NewNamedType(arg)
553553
if err != nil {
554-
return err
554+
return fmt.Errorf("invalid receiver type %q: %w", arg, err)
555555
}
556556
o.FunctionOption = Receiver(tn)
557557
case "signature", "signature-contains":
@@ -573,25 +573,27 @@ func (o *unmarshalFuncDeclOption) UnmarshalYAML(ctx gocontext.Context, node ast.
573573
return fmt.Errorf("unexpected keys: %s", strings.Join(keys, ", "))
574574
}
575575

576-
var args []typed.TypeName
576+
var args []*typed.NamedType
577577
if len(sig.Args) > 0 {
578-
args = make([]typed.TypeName, len(sig.Args))
578+
args = make([]*typed.NamedType, len(sig.Args))
579579
for i, a := range sig.Args {
580-
var err error
581-
if args[i], err = typed.NewTypeName(a); err != nil {
582-
return err
580+
tn, err := typed.NewNamedType(a)
581+
if err != nil {
582+
return fmt.Errorf("invalid argument type %q at position %d: %w", a, i, err)
583583
}
584+
args[i] = tn
584585
}
585586
}
586587

587-
var ret []typed.TypeName
588+
var ret []*typed.NamedType
588589
if len(sig.Ret) > 0 {
589-
ret = make([]typed.TypeName, len(sig.Ret))
590+
ret = make([]*typed.NamedType, len(sig.Ret))
590591
for i, r := range sig.Ret {
591-
var err error
592-
if ret[i], err = typed.NewTypeName(r); err != nil {
593-
return err
592+
tn, err := typed.NewNamedType(r)
593+
if err != nil {
594+
return fmt.Errorf("invalid return type %q at position %d: %w", r, i, err)
594595
}
596+
ret[i] = tn
595597
}
596598
}
597599

0 commit comments

Comments
 (0)