@@ -8,14 +8,10 @@ package code
88import (
99 "errors"
1010 "fmt"
11- "go/importer"
12- "go/types"
13- "strings"
1411
1512 "github.com/dave/dst"
1613
1714 "github.com/DataDog/orchestrion/internal/injector/aspect/context"
18- "github.com/DataDog/orchestrion/internal/injector/aspect/join"
1915 "github.com/DataDog/orchestrion/internal/injector/typed"
2016)
2117
@@ -163,6 +159,11 @@ func (s signature) ResultThatImplements(name string) (string, error) {
163159 return "" , nil
164160 }
165161
162+ // Optimization: First, check for an exact match using the helper.
163+ if index , found := typed .FindMatchingTypeName (s .Results , name ); found {
164+ return fieldAt (s .Results , index , "result" )
165+ } // If not found, fall through to type resolution.
166+
166167 // Resolve the interface type.
167168 iface , err := typed .ResolveInterfaceTypeByName (name )
168169 if err != nil {
@@ -193,18 +194,38 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
193194 return "" , nil
194195 }
195196
197+ // Optimization: First, check for an exact match using TypeName parsing, finding the last one.
198+ lastMatchIndex := - 1
199+ if tn , err := typed .NewTypeName (name ); err == nil {
200+ currentIndex := 0
201+ for _ , field := range s .Results .List {
202+ if tn .Matches (field .Type ) {
203+ lastMatchIndex = currentIndex // Update last found index
204+ }
205+ // Increment index by the number of names in the field (or 1 if unnamed).
206+ count := len (field .Names )
207+ if count == 0 {
208+ count = 1
209+ }
210+ currentIndex += count
211+ }
212+ }
213+ // If we found a match via TypeName, return it.
214+ if lastMatchIndex != - 1 {
215+ return fieldAt (s .Results , lastMatchIndex , "result" )
216+ } // If parsing failed or no match, fall through to type resolution.
217+
196218 // Resolve the interface type.
197219 iface , err := typed .ResolveInterfaceTypeByName (name )
198220 if err != nil {
221+ // Propagate error if interface resolution fails
199222 return "" , fmt .Errorf ("resolving interface type %q: %w" , name , err )
200223 }
201224
202- // First, we need to build a map of result fields to their indices
203- // that takes into account named and unnamed parameters.
204- var (
205- fieldIndices = make (map [* dst.Field ]int )
206- index = 0
207- )
225+ // Fallback: Check using ExprImplements, iterating backward.
226+ // Need field indices map again for this path.
227+ fieldIndices := make (map [* dst.Field ]int )
228+ index := 0
208229 for _ , field := range s .Results .List {
209230 fieldIndices [field ] = index
210231 count := len (field .Names )
@@ -214,7 +235,6 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
214235 index += count
215236 }
216237
217- // Loop backward through the results list.
218238 for i := len (s .Results .List ) - 1 ; i >= 0 ; i -- {
219239 field := s .Results .List [i ]
220240 if typed .ExprImplements (s .context , field .Type , iface ) {
@@ -266,7 +286,7 @@ func fieldAt(fields *dst.FieldList, index int, use string) (string, error) {
266286}
267287
268288func fieldOfType (fields * dst.FieldList , typeName string , use string ) (string , error ) {
269- tn , err := join .NewTypeName (typeName )
289+ tn , err := typed .NewTypeName (typeName )
270290 if err != nil {
271291 return "" , err
272292 }
@@ -293,118 +313,27 @@ func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, er
293313 return "" , nil
294314}
295315
296- // exprImplements checks if an expression's type implements an interface.
297- func exprImplements (ctx context.AdviceContext , expr dst.Expr , iface * types.Interface ) bool {
298- actualType := ctx .ResolveType (expr )
299- if actualType == nil {
300- return false
301- }
302-
303- return typeImplements (actualType , iface )
304- }
305-
306- // typeImplements checks if a type implements an interface (including pointer receivers).
307- func typeImplements (t types.Type , iface * types.Interface ) bool {
308- if t == nil || iface == nil {
309- return false
310- }
311-
312- // Direct implementation check.
313- if types .Implements (t , iface ) {
314- return true
315- }
316-
317- return false
318- }
319-
320- // resolveInterfaceTypeByName takes an interface name as a string and resolves it to an interface type.
321- // It supports built-in interfaces (e.g. "error"), package qualified interfaces (e.g. "io.Reader"),
322- // and third-party package interfaces (e.g. "example.com/pkg.Interface").
323- func resolveInterfaceTypeByName (name string ) (* types.Interface , error ) {
324- // Handle built-in types.
325- if obj := types .Universe .Lookup (name ); obj != nil {
326- typeObj , ok := obj .(* types.TypeName )
327- if ! ok {
328- return nil , fmt .Errorf ("object %s is not a type name but a %T" , name , obj )
329- }
330-
331- typ := typeObj .Type ()
332- if ! types .IsInterface (typ ) {
333- return nil , fmt .Errorf ("type %s is not an interface" , name )
334- }
335-
336- t , ok := typ .Underlying ().(* types.Interface )
337- if ! ok {
338- return nil , fmt .Errorf ("type %s is not an interface" , name )
339- }
340-
341- return t , nil
342- }
343-
344- // Handle package-qualified types (e.g., "io.Writer").
345- pkgName , typeName := splitPackageAndName (name )
346- if pkgName == "" {
347- return nil , fmt .Errorf ("invalid type name: %s" , name )
348- }
349-
350- // Import the package
351- imp := importer .Default ()
352- pkg , err := imp .Import (pkgName )
353- if err != nil {
354- return nil , fmt .Errorf ("failed to import package %q: %w" , pkgName , err )
355- }
356-
357- // Look up the type in the package's scope
358- obj := pkg .Scope ().Lookup (typeName )
359- if obj == nil {
360- return nil , fmt .Errorf ("type %q not found in package %q" , typeName , pkgName )
361- }
362-
363- typeObj , ok := obj .(* types.TypeName )
364- if ! ok {
365- return nil , fmt .Errorf ("object %s is not a type name but a %T" , name , obj )
366- }
367-
368- typ := typeObj .Type ()
369- if ! types .IsInterface (typ ) {
370- return nil , fmt .Errorf ("type %s is not an interface" , name )
371- }
372-
373- t , ok := typ .Underlying ().(* types.Interface )
374- if ! ok {
375- return nil , fmt .Errorf ("type %s is not an interface" , name )
376- }
377-
378- return t , nil
379- }
380-
381- // splitPackageAndName splits a fully qualified type name like "io.Reader" or "example.com/pkg.Type"
382- // into its package path and local name.
383- // Returns ("", "error") for built-in "error".
384- // Returns ("", "MyType") for unqualified "MyType".
385- func splitPackageAndName (fullName string ) (pkgPath string , localName string ) {
386- if ! strings .Contains (fullName , "." ) {
387- // Assume built-in type (like "error") or unqualified local type.
388- return "" , fullName
389- }
390- lastDot := strings .LastIndex (fullName , "." )
391- pkgPath = fullName [:lastDot ]
392- localName = fullName [lastDot + 1 :]
393- return pkgPath , localName
394- }
395-
396316// FinalResultImplements returns whether the final result implements the provided interface type.
397317func (s signature ) FinalResultImplements (interfaceName string ) (bool , error ) {
398318 if s .Results == nil || len (s .Results .List ) == 0 {
399319 return false , nil
400320 }
401321
402- iface , err := resolveInterfaceTypeByName (interfaceName )
322+ lastField := s .Results .List [len (s .Results .List )- 1 ]
323+
324+ // Optimization: First, check for an exact match using TypeName parsing.
325+ // Note: Not using FindMatchingTypeName as we only need to check the last field.
326+ if tn , err := typed .NewTypeName (interfaceName ); err == nil {
327+ if tn .Matches (lastField .Type ) {
328+ return true , nil
329+ }
330+ } // If parsing failed or no match, fall through to type resolution.
331+
332+ iface , err := typed .ResolveInterfaceTypeByName (interfaceName )
403333 if err != nil {
404334 return false , fmt .Errorf ("resolving interface type %q: %w" , interfaceName , err )
405335 }
406336
407337 // Check if the last field type implements the interface.
408- lastField := s .Results .List [len (s .Results .List )- 1 ]
409- return exprImplements (s .context , lastField .Type , iface ), nil
338+ return typed .ExprImplements (s .context , lastField .Type , iface ), nil
410339}
0 commit comments