Skip to content

Commit 0f0aca9

Browse files
committed
feat(injector/typed): add interface-based type system with parser
- Introduce Type interface for polymorphic type handling - Add support for pointer, slice, array, and map types - Implement recursive descent parser for type expressions - Rename TypeName to NamedType for clarity - Add helper functions NewNamedType and MustNamedType - Update all usage sites to use new type system
1 parent 8157bfb commit 0f0aca9

25 files changed

+2636
-582
lines changed

internal/injector/aspect/advice/call.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ import (
2121
)
2222

2323
type appendArgs struct {
24-
TypeName typed.TypeName
24+
TypeName *typed.NamedType
2525
Templates []*code.Template
2626
}
2727

2828
// AppendArgs appends arguments of a given type to the end of a function call. All arguments must be
2929
// of the same type, as they may be appended at the tail end of a variadic call.
30-
func AppendArgs(typeName typed.TypeName, templates ...*code.Template) *appendArgs {
30+
func AppendArgs(typeName *typed.NamedType, templates ...*code.Template) *appendArgs {
3131
return &appendArgs{typeName, templates}
3232
}
3333

@@ -168,12 +168,12 @@ func init() {
168168
return nil, err
169169
}
170170

171-
tn, err := typed.NewTypeName(args.TypeName)
171+
namedType, err := typed.NewNamedType(args.TypeName)
172172
if err != nil {
173-
return nil, err
173+
return nil, fmt.Errorf("invalid type %q: %w", args.TypeName, err)
174174
}
175175

176-
return AppendArgs(tn, args.Values...), nil
176+
return AppendArgs(namedType, args.Values...), nil
177177
}
178178
unmarshalers["replace-function"] = func(ctx gocontext.Context, node ast.Node) (Advice, error) {
179179
var (

internal/injector/aspect/advice/call_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
func TestAppendArgs(t *testing.T) {
2020
t.Run("AddedImports", func(t *testing.T) {
2121
type testCase struct {
22-
argType typed.TypeName
22+
argType *typed.NamedType
2323
args []*code.Template
2424
expectedImports []string
2525
}
@@ -30,7 +30,7 @@ func TestAppendArgs(t *testing.T) {
3030
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
3131
},
3232
"imports-from-arg-type": {
33-
argType: typed.MustTypeName("*net/http.Request"),
33+
argType: typed.MustNamedType("*net/http.Request"),
3434
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
3535
expectedImports: []string{"net/http"},
3636
},

internal/injector/aspect/advice/code/dot_function.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
176176

177177
// Optimization: First, check for an exact match using TypeName parsing, finding the last one.
178178
lastMatchIndex := -1
179-
if tn, err := typed.NewTypeName(name); err == nil {
179+
if tn, err := typed.NewType(name); err == nil {
180180
currentIndex := 0
181181
for _, field := range s.Results.List {
182182
if tn.Matches(field.Type) {
@@ -274,7 +274,7 @@ func fieldAt(fields *dst.FieldList, index int, use string) (string, error) {
274274
}
275275

276276
func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, error) {
277-
tn, err := typed.NewTypeName(typeName)
277+
t, err := typed.NewType(typeName)
278278
if err != nil {
279279
return "", err
280280
}
@@ -286,7 +286,7 @@ func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, er
286286

287287
index := 0
288288
for _, field := range fields.List {
289-
if tn.Matches(field.Type) {
289+
if t.Matches(field.Type) {
290290
return fieldAt(fields, index, use)
291291
}
292292

@@ -310,9 +310,9 @@ func (s signature) FinalResultImplements(interfaceName string) (bool, error) {
310310
lastField := s.Results.List[len(s.Results.List)-1]
311311

312312
// Optimization: First, check for an exact match using TypeName parsing.
313-
// Note: Not using FindMatchingTypeName as we only need to check the last field.
314-
if tn, err := typed.NewTypeName(interfaceName); err == nil {
315-
if tn.Matches(lastField.Type) {
313+
// Note: Not using FindMatchingType as we only need to check the last field.
314+
if t, err := typed.NewType(interfaceName); err == nil {
315+
if t.Matches(lastField.Type) {
316316
return true, nil
317317
}
318318
} // If parsing failed or no match, fall through to type resolution.
@@ -334,7 +334,7 @@ func findImplementingField(ctx context.AdviceContext, fields *dst.FieldList, int
334334
}
335335

336336
// 1. Check for exact type name match first.
337-
if index, found := typed.FindMatchingTypeName(fields, interfaceName); found {
337+
if index, found := typed.FindMatchingType(fields, interfaceName); found {
338338
return fieldAt(fields, index, fieldKind)
339339
}
340340

internal/injector/aspect/advice/struct.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ import (
1919

2020
type addStructField struct {
2121
Name string
22-
TypeName typed.TypeName
22+
TypeName *typed.NamedType
2323
}
2424

2525
// AddStructField adds a new synthetic field at the tail end of a struct declaration.
26-
func AddStructField(fieldName string, fieldType typed.TypeName) *addStructField {
26+
func AddStructField(fieldName string, fieldType *typed.NamedType) *addStructField {
2727
return &addStructField{fieldName, fieldType}
2828
}
2929

@@ -76,11 +76,11 @@ func init() {
7676
if err := yaml.NodeToValueContext(ctx, node, &spec); err != nil {
7777
return nil, err
7878
}
79-
tn, err := typed.NewTypeName(spec.Type)
79+
namedType, err := typed.NewNamedType(spec.Type)
8080
if err != nil {
81-
return nil, err
81+
return nil, fmt.Errorf("invalid type %q: %w", spec.Type, err)
8282
}
8383

84-
return AddStructField(spec.Name, tn), nil
84+
return AddStructField(spec.Name, namedType), nil
8585
}
8686
}

internal/injector/aspect/join/declaration.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ func (i *declarationOf) Hash(h *fingerprint.Hasher) error {
7373
}
7474

7575
type valueDeclaration struct {
76-
TypeName typed.TypeName
76+
TypeName *typed.NamedType
7777
}
7878

79-
func ValueDeclaration(typeName typed.TypeName) *valueDeclaration {
79+
func ValueDeclaration(typeName *typed.NamedType) *valueDeclaration {
8080
return &valueDeclaration{typeName}
8181
}
8282

@@ -141,9 +141,9 @@ func init() {
141141
return nil, err
142142
}
143143

144-
tn, err := typed.NewTypeName(typeName)
144+
tn, err := typed.NewNamedType(typeName)
145145
if err != nil {
146-
return nil, err
146+
return nil, fmt.Errorf("invalid type %q: %w", typeName, err)
147147
}
148148

149149
return ValueDeclaration(tn), nil

internal/injector/aspect/join/function.go

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -145,13 +145,13 @@ func (fo functionName) Hash(h *fingerprint.Hasher) error {
145145
}
146146

147147
type signature struct {
148-
Arguments []typed.TypeName
149-
Results []typed.TypeName
148+
Arguments []*typed.NamedType
149+
Results []*typed.NamedType
150150
}
151151

152152
// Signature matches function declarations based on their arguments and return
153153
// value types.
154-
func Signature(args []typed.TypeName, ret []typed.TypeName) FunctionOption {
154+
func Signature(args []*typed.NamedType, ret []*typed.NamedType) FunctionOption {
155155
return &signature{Arguments: args, Results: ret}
156156
}
157157

@@ -225,8 +225,8 @@ func (fo *signature) evaluate(info functionInformation) bool {
225225
func (fo *signature) Hash(h *fingerprint.Hasher) error {
226226
return h.Named(
227227
"signature",
228-
fingerprint.List[typed.TypeName](fo.Arguments),
229-
fingerprint.List[typed.TypeName](fo.Results),
228+
fingerprint.List[*typed.NamedType](fo.Arguments),
229+
fingerprint.List[*typed.NamedType](fo.Results),
230230
)
231231
}
232232

@@ -236,15 +236,15 @@ type signatureContains struct {
236236

237237
// SignatureContains matches function declarations based on their arguments and
238238
// return value types in any order and does not require all arguments or return values to be present.
239-
func SignatureContains(args []typed.TypeName, ret []typed.TypeName) FunctionOption {
239+
func SignatureContains(args []*typed.NamedType, ret []*typed.NamedType) FunctionOption {
240240
return &signatureContains{signature{Arguments: args, Results: ret}}
241241
}
242242

243243
func (fo *signatureContains) Hash(h *fingerprint.Hasher) error {
244244
return h.Named(
245245
"signature-contains",
246-
fingerprint.List[typed.TypeName](fo.Arguments),
247-
fingerprint.List[typed.TypeName](fo.Results),
246+
fingerprint.List[*typed.NamedType](fo.Arguments),
247+
fingerprint.List[*typed.NamedType](fo.Results),
248248
)
249249
}
250250

@@ -262,7 +262,7 @@ func (fo *signatureContains) evaluate(info functionInformation) bool {
262262

263263
// containsAnyType checks if any of the expected types match any of the actual types in the field list.
264264
// Returns false if either slice is empty or nil.
265-
func containsAnyType(expectedTypes []typed.TypeName, fieldList *dst.FieldList) bool {
265+
func containsAnyType(expectedTypes []*typed.NamedType, fieldList *dst.FieldList) bool {
266266
// Quick return if either side is empty.
267267
if len(expectedTypes) == 0 || fieldList == nil || len(fieldList.List) == 0 {
268268
return false
@@ -281,10 +281,10 @@ func containsAnyType(expectedTypes []typed.TypeName, fieldList *dst.FieldList) b
281281
}
282282

283283
type receiver struct {
284-
TypeName typed.TypeName
284+
TypeName *typed.NamedType
285285
}
286286

287-
func Receiver(typeName typed.TypeName) FunctionOption {
287+
func Receiver(typeName *typed.NamedType) FunctionOption {
288288
return &receiver{typeName}
289289
}
290290

@@ -393,18 +393,18 @@ func (_ *resultImplements) fileMayMatch(_ *may.FileContext) may.MatchType {
393393

394394
// evaluateFieldListImplements checks if any field in the list matches the interfaceName,
395395
// either by exact type name or by interface implementation.
396-
func evaluateFieldListImplements(fields *dst.FieldList, interfaceName string, info functionInformation) bool {
396+
func evaluateFieldListImplements(typeResolver typeResolver, fields *dst.FieldList, interfaceName string) bool {
397397
if fields == nil || len(fields.List) == 0 {
398398
return false
399399
}
400400

401401
// Optimization: First, check for an exact match using the helper.
402-
if _, found := typed.FindMatchingTypeName(fields, interfaceName); found {
402+
if _, found := typed.FindMatchingType(fields, interfaceName); found {
403403
return true // Found direct match
404404
}
405405

406406
// If no exact match, check implementation (requires type resolver).
407-
if info.typeResolver == nil {
407+
if typeResolver == nil {
408408
return false // Cannot check implementation without resolver.
409409
}
410410

@@ -414,7 +414,7 @@ func evaluateFieldListImplements(fields *dst.FieldList, interfaceName string, in
414414
}
415415

416416
for _, field := range fields.List {
417-
if typed.ExprImplements(info.typeResolver, field.Type, targetInterface) {
417+
if typed.ExprImplements(typeResolver, field.Type, targetInterface) {
418418
return true // Found an implementing type.
419419
}
420420
}
@@ -423,7 +423,7 @@ func evaluateFieldListImplements(fields *dst.FieldList, interfaceName string, in
423423
}
424424

425425
func (fo *resultImplements) evaluate(info functionInformation) bool {
426-
return evaluateFieldListImplements(info.Type.Results, fo.InterfaceName, info)
426+
return evaluateFieldListImplements(info.typeResolver, info.Type.Results, fo.InterfaceName)
427427
}
428428

429429
func (fo *resultImplements) Hash(h *fingerprint.Hasher) error {
@@ -467,10 +467,10 @@ func (fo *finalResultImplements) evaluate(info functionInformation) bool {
467467
return false
468468
}
469469

470-
// Optimization: First, check for an exact match using TypeName parsing.
471-
if tn, err := typed.NewTypeName(fo.InterfaceName); err == nil {
470+
// Optimization: First, check for an exact match using Type parsing.
471+
if t, err := typed.NewType(fo.InterfaceName); err == nil {
472472
lastField := info.Type.Results.List[len(info.Type.Results.List)-1]
473-
if tn.Matches(lastField.Type) {
473+
if t.Matches(lastField.Type) {
474474
return true // Found direct match
475475
}
476476
} // If parsing failed or no match, fall through to type resolution.
@@ -508,11 +508,9 @@ func ArgumentImplements(interfaceName string) FunctionOption {
508508
return &argumentImplements{InterfaceName: interfaceName}
509509
}
510510

511-
func (fo *argumentImplements) impliesImported() []string {
512-
pkgPath, _ := typed.SplitPackageAndName(fo.InterfaceName)
513-
if pkgPath != "" {
514-
return []string{pkgPath}
515-
}
511+
func (*argumentImplements) impliesImported() []string {
512+
// A type can implement an interface without importing the interface's package
513+
// due to Go's structural typing system.
516514
return nil
517515
}
518516

@@ -530,7 +528,7 @@ func (_ *argumentImplements) fileMayMatch(_ *may.FileContext) may.MatchType {
530528
}
531529

532530
func (fo *argumentImplements) evaluate(info functionInformation) bool {
533-
return evaluateFieldListImplements(info.Type.Params, fo.InterfaceName, info)
531+
return evaluateFieldListImplements(info.typeResolver, info.Type.Params, fo.InterfaceName)
534532
}
535533

536534
func (fo *argumentImplements) Hash(h *fingerprint.Hasher) error {
@@ -590,9 +588,9 @@ func (o *unmarshalFuncDeclOption) UnmarshalYAML(ctx gocontext.Context, node ast.
590588
if err := yaml.NodeToValueContext(ctx, mapping.Values[0].Value, &arg); err != nil {
591589
return err
592590
}
593-
tn, err := typed.NewTypeName(arg)
591+
tn, err := typed.NewNamedType(arg)
594592
if err != nil {
595-
return err
593+
return fmt.Errorf("invalid receiver type %q: %w", arg, err)
596594
}
597595
o.FunctionOption = Receiver(tn)
598596
case "signature", "signature-contains":
@@ -614,25 +612,27 @@ func (o *unmarshalFuncDeclOption) UnmarshalYAML(ctx gocontext.Context, node ast.
614612
return fmt.Errorf("unexpected keys: %s", strings.Join(keys, ", "))
615613
}
616614

617-
var args []typed.TypeName
615+
var args []*typed.NamedType
618616
if len(sig.Args) > 0 {
619-
args = make([]typed.TypeName, len(sig.Args))
617+
args = make([]*typed.NamedType, len(sig.Args))
620618
for i, a := range sig.Args {
621-
var err error
622-
if args[i], err = typed.NewTypeName(a); err != nil {
623-
return err
619+
tn, err := typed.NewNamedType(a)
620+
if err != nil {
621+
return fmt.Errorf("invalid argument type %q at position %d: %w", a, i, err)
624622
}
623+
args[i] = tn
625624
}
626625
}
627626

628-
var ret []typed.TypeName
627+
var ret []*typed.NamedType
629628
if len(sig.Ret) > 0 {
630-
ret = make([]typed.TypeName, len(sig.Ret))
629+
ret = make([]*typed.NamedType, len(sig.Ret))
631630
for i, r := range sig.Ret {
632-
var err error
633-
if ret[i], err = typed.NewTypeName(r); err != nil {
634-
return err
631+
tn, err := typed.NewNamedType(r)
632+
if err != nil {
633+
return fmt.Errorf("invalid return type %q at position %d: %w", r, i, err)
635634
}
635+
ret[i] = tn
636636
}
637637
}
638638

0 commit comments

Comments
 (0)