Skip to content

Commit 8157bfb

Browse files
authored
feat(injector): Add ArgumentThatImplements method for argument interface matching (#617)
This commit introduces the `ArgumentThatImplements` method to the function signature, allowing the retrieval of the first argument that implements a specified interface type. The implementation includes a new helper function, `findImplementingField`, which checks for both exact type matches and interface implementations. Additionally, the `ResultThatImplements` method has been refactored to utilize the new helper for consistency. The changes enhance the injector's capability to handle argument interface checks, improving the overall functionality and maintainability of the code. Configuration YAML files and test cases have been added to validate the new functionality across various scenarios, ensuring robust testing of the argument matching feature. Signed-off-by: Kemal Akkoyun <[email protected]> Signed-off-by: Kemal Akkoyun <[email protected]>
1 parent ec50330 commit 8157bfb

File tree

6 files changed

+638
-53
lines changed

6 files changed

+638
-53
lines changed

internal/injector/aspect/advice/code/dot_function.go

Lines changed: 64 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package code
88
import (
99
"errors"
1010
"fmt"
11+
"go/types"
1112

1213
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
1314
"github.com/DataDog/orchestrion/internal/injector/typed"
@@ -28,6 +29,9 @@ type (
2829
// ArgumentOfType returns the name of the first argument in this function that has the provided
2930
// type, or an empty string if none is found.
3031
ArgumentOfType(string) (string, error)
32+
// ArgumentThatImplements returns the name of the first argument in this function that implements
33+
// the provided interface type, or an empty string if none is found.
34+
ArgumentThatImplements(string) (string, error)
3135

3236
// Result returns the name of the return value at the given index in this function's type,
3337
// returning an error if the index is out of bounds.
@@ -111,6 +115,10 @@ func (noFunc) ArgumentOfType(string) (string, error) {
111115
return "", errNoFunction
112116
}
113117

118+
func (noFunc) ArgumentThatImplements(string) (string, error) {
119+
return "", errNoFunction
120+
}
121+
114122
func (noFunc) Result(int) (string, error) {
115123
return "", errNoFunction
116124
}
@@ -144,6 +152,10 @@ func (s signature) ArgumentOfType(name string) (string, error) {
144152
return fieldOfType(s.Params, name, "argument")
145153
}
146154

155+
func (s signature) ArgumentThatImplements(interfaceName string) (string, error) {
156+
return findImplementingField(s.context, s.Params, interfaceName, "argument")
157+
}
158+
147159
func (s signature) Result(index int) (name string, err error) {
148160
return fieldAt(s.Results, index, "result")
149161
}
@@ -152,39 +164,8 @@ func (s signature) ResultOfType(name string) (string, error) {
152164
return fieldOfType(s.Results, name, "result")
153165
}
154166

155-
func (s signature) ResultThatImplements(name string) (string, error) {
156-
// Return blank if there are no results.
157-
if s.Results == nil {
158-
return "", nil
159-
}
160-
161-
// Optimization: First, check for an exact match using the helper.
162-
if index, found := typed.FindMatchingTypeName(s.Results, name); found {
163-
return fieldAt(s.Results, index, "result")
164-
} // If not found, fall through to type resolution.
165-
166-
// Resolve the interface type.
167-
iface, err := typed.ResolveInterfaceTypeByName(name)
168-
if err != nil {
169-
return "", fmt.Errorf("resolving interface type %q: %w", name, err)
170-
}
171-
172-
// Check each result.
173-
index := 0
174-
for _, field := range s.Results.List {
175-
if typed.ExprImplements(s.context, field.Type, iface) {
176-
return fieldAt(s.Results, index, "result")
177-
}
178-
179-
count := len(field.Names)
180-
if count == 0 {
181-
count = 1
182-
}
183-
index += count
184-
}
185-
186-
// Not found.
187-
return "", nil
167+
func (s signature) ResultThatImplements(interfaceName string) (string, error) {
168+
return findImplementingField(s.context, s.Results, interfaceName, "result")
188169
}
189170

190171
func (s signature) LastResultThatImplements(name string) (string, error) {
@@ -277,11 +258,19 @@ func fieldAt(fields *dst.FieldList, index int, use string) (string, error) {
277258
}
278259
}
279260

280-
if idx < index {
281-
return "", fmt.Errorf("index out of bounds: %d (only %d items)", index, idx+1)
261+
if idx <= index { // Use <= to catch index being exactly the number of items
262+
return "", fmt.Errorf("index out of bounds: %d (only %d items)", index, idx)
282263
}
283264

284-
return name, nil
265+
// If anonymous, we should have assigned the synthetic name earlier.
266+
// If named, it should have been returned immediately.
267+
// If we reach here and it was anonymous, we return the generated name.
268+
if anonymous {
269+
return name, nil
270+
}
271+
272+
// This path should ideally not be reached for named parameters if logic is correct.
273+
return "", fmt.Errorf("fieldAt: failed to find field at index %d", index)
285274
}
286275

287276
func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, error) {
@@ -336,3 +325,42 @@ func (s signature) FinalResultImplements(interfaceName string) (bool, error) {
336325
// Check if the last field type implements the interface.
337326
return typed.ExprImplements(s.context, lastField.Type, iface), nil
338327
}
328+
329+
// findImplementingField is a helper to find the first field in a list that matches
330+
// an interface, either by exact type name or by implementation.
331+
func findImplementingField(ctx context.AdviceContext, fields *dst.FieldList, interfaceName string, fieldKind string) (string, error) {
332+
if fields == nil {
333+
return "", nil // No fields, no match.
334+
}
335+
336+
// 1. Check for exact type name match first.
337+
if index, found := typed.FindMatchingTypeName(fields, interfaceName); found {
338+
return fieldAt(fields, index, fieldKind)
339+
}
340+
341+
// 2. If no exact name match, check for interface implementation.
342+
iface, err := typed.ResolveInterfaceTypeByName(interfaceName)
343+
if err != nil {
344+
// Invalid interface name, cannot proceed with implementation check.
345+
return "", fmt.Errorf("resolving interface type %q: %w", interfaceName, err)
346+
}
347+
348+
// Iterate through fields to check for implementation.
349+
currentIndex := 0
350+
for _, field := range fields.List {
351+
actualType := ctx.ResolveType(field.Type)
352+
if actualType != nil && types.Implements(actualType, iface) {
353+
return fieldAt(fields, currentIndex, fieldKind)
354+
}
355+
356+
// Increment index based on field names.
357+
count := len(field.Names)
358+
if count == 0 {
359+
count = 1
360+
}
361+
currentIndex += count
362+
}
363+
364+
// 3. No match found.
365+
return "", nil
366+
}

internal/injector/aspect/join/function.go

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -391,39 +391,39 @@ func (_ *resultImplements) fileMayMatch(_ *may.FileContext) may.MatchType {
391391
return may.Unknown
392392
}
393393

394-
func (fo *resultImplements) evaluate(info functionInformation) bool {
395-
if info.Type.Results == nil || len(info.Type.Results.List) == 0 {
396-
// No return values, no match.
394+
// evaluateFieldListImplements checks if any field in the list matches the interfaceName,
395+
// either by exact type name or by interface implementation.
396+
func evaluateFieldListImplements(fields *dst.FieldList, interfaceName string, info functionInformation) bool {
397+
if fields == nil || len(fields.List) == 0 {
397398
return false
398399
}
399400

400401
// Optimization: First, check for an exact match using the helper.
401-
if _, found := typed.FindMatchingTypeName(info.Type.Results, fo.InterfaceName); found {
402+
if _, found := typed.FindMatchingTypeName(fields, interfaceName); found {
402403
return true // Found direct match
403-
} // If not found, fall through to type resolution.
404+
}
404405

405-
// Ensure the type resolver is available.
406+
// If no exact match, check implementation (requires type resolver).
406407
if info.typeResolver == nil {
407-
return false
408+
return false // Cannot check implementation without resolver.
408409
}
409410

410-
// Resolve the target interface name (e.g., "io.Reader", "error") to a types.Interface.
411-
targetInterface, err := typed.ResolveInterfaceTypeByName(fo.InterfaceName)
411+
targetInterface, err := typed.ResolveInterfaceTypeByName(interfaceName)
412412
if err != nil {
413-
// If the interface name is invalid or cannot be resolved, we cannot match.
414-
return false
413+
return false // Invalid interface name.
415414
}
416415

417-
for _, field := range info.Type.Results.List {
418-
// For each return type (dst.Expr), resolve it to types.Type using the provided resolver
419-
// and check if it implements the target interface.
416+
for _, field := range fields.List {
420417
if typed.ExprImplements(info.typeResolver, field.Type, targetInterface) {
421-
return true // Found at least one implementing type, match!
418+
return true // Found an implementing type.
422419
}
423420
}
424421

425-
// No return type matched.
426-
return false
422+
return false // No match found.
423+
}
424+
425+
func (fo *resultImplements) evaluate(info functionInformation) bool {
426+
return evaluateFieldListImplements(info.Type.Results, fo.InterfaceName, info)
427427
}
428428

429429
func (fo *resultImplements) Hash(h *fingerprint.Hasher) error {
@@ -496,6 +496,47 @@ func (fo *finalResultImplements) Hash(h *fingerprint.Hasher) error {
496496
return h.Named("final-result-implements", fingerprint.String(fo.InterfaceName))
497497
}
498498

499+
// argumentImplements matches functions where at least one argument's type
500+
// implements the specified interface.
501+
type argumentImplements struct {
502+
InterfaceName string
503+
}
504+
505+
// ArgumentImplements creates a FunctionOption that matches functions where at least one
506+
// argument implements the named interface.
507+
func ArgumentImplements(interfaceName string) FunctionOption {
508+
return &argumentImplements{InterfaceName: interfaceName}
509+
}
510+
511+
func (fo *argumentImplements) impliesImported() []string {
512+
pkgPath, _ := typed.SplitPackageAndName(fo.InterfaceName)
513+
if pkgPath != "" {
514+
return []string{pkgPath}
515+
}
516+
return nil
517+
}
518+
519+
func (_ *argumentImplements) packageMayMatch(_ *may.PackageContext) may.MatchType {
520+
// Cannot reliably determine possibility of match based on package imports
521+
// due to structural typing. A type can implement an interface without
522+
// importing the interface's package.
523+
return may.Unknown
524+
}
525+
526+
func (_ *argumentImplements) fileMayMatch(_ *may.FileContext) may.MatchType {
527+
// Cannot reliably determine possibility of match based on file contents
528+
// due to structural typing and type aliases.
529+
return may.Unknown
530+
}
531+
532+
func (fo *argumentImplements) evaluate(info functionInformation) bool {
533+
return evaluateFieldListImplements(info.Type.Params, fo.InterfaceName, info)
534+
}
535+
536+
func (fo *argumentImplements) Hash(h *fingerprint.Hasher) error {
537+
return h.Named("argument-implements", fingerprint.String(fo.InterfaceName))
538+
}
539+
499540
func init() {
500541
unmarshalers["function-body"] = func(ctx gocontext.Context, node ast.Node) (Point, error) {
501542
up, err := FromYAML(ctx, node)
@@ -621,6 +662,15 @@ func (o *unmarshalFuncDeclOption) UnmarshalYAML(ctx gocontext.Context, node ast.
621662
}
622663
// NOTE: Validation happens later during type resolution.
623664
o.FunctionOption = FinalResultImplements(ifaceName)
665+
case "argument-implements":
666+
var ifaceName string
667+
if err := yaml.NodeToValueContext(ctx, mapping.Values[0].Value, &ifaceName); err != nil {
668+
return err
669+
}
670+
if ifaceName == "" {
671+
return fmt.Errorf("line %d: 'argument-implements' cannot be empty", node.GetToken().Position.Line)
672+
}
673+
o.FunctionOption = ArgumentImplements(ifaceName)
624674
default:
625675
return fmt.Errorf("unknown FuncDeclOption name: %q", key)
626676
}

0 commit comments

Comments
 (0)