Skip to content

Commit f5ea563

Browse files
committed
refactor(injector): Refactor TypeName handling and migrate to typed package
This commit refactors the handling of TypeName across the injector package, transitioning from the join package to a new typed package. The changes include: - Replacing instances of join.TypeName with typed.TypeName in various files, ensuring consistent type handling. - Updating related functions and methods to accommodate the new TypeName structure, including adjustments to import path retrieval and pointer handling. - Adding comprehensive tests for the new TypeName implementation to validate its functionality and error handling. This refactor enhances code clarity and maintainability by centralizing type name logic within the typed package. Signed-off-by: Kemal Akkoyun <[email protected]>
1 parent f7476cf commit f5ea563

File tree

13 files changed

+320
-328
lines changed

13 files changed

+320
-328
lines changed

_docs/generator/template-funcs.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@ import (
1717
"strings"
1818
"unicode"
1919

20+
"golang.org/x/tools/go/packages"
21+
2022
"github.com/DataDog/orchestrion/internal/injector/aspect/advice"
2123
"github.com/DataDog/orchestrion/internal/injector/aspect/advice/code"
2224
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
23-
"golang.org/x/tools/go/packages"
25+
"github.com/DataDog/orchestrion/internal/injector/typed"
2426
)
2527

2628
var (
@@ -65,7 +67,7 @@ func render(val any) (template.HTML, error) {
6567

6668
templateName := "doc."
6769
switch val := val.(type) {
68-
case join.Point, join.TypeName, join.FunctionOption:
70+
case join.Point, typed.TypeName, join.FunctionOption:
6971
templateName += "join"
7072
case advice.Advice:
7173
templateName += "advice"

internal/injector/aspect/advice/call.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,20 @@ import (
1414
"github.com/DataDog/orchestrion/internal/fingerprint"
1515
"github.com/DataDog/orchestrion/internal/injector/aspect/advice/code"
1616
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
17-
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
17+
"github.com/DataDog/orchestrion/internal/injector/typed"
1818
"github.com/DataDog/orchestrion/internal/yaml"
1919
"github.com/dave/dst"
2020
"github.com/goccy/go-yaml/ast"
2121
)
2222

2323
type appendArgs struct {
24-
TypeName join.TypeName
24+
TypeName typed.TypeName
2525
Templates []*code.Template
2626
}
2727

2828
// AppendArgs appends arguments of a given type to the end of a function call. All arguments must be
2929
// of the same type, as they may be appended at the tail end of a variadic call.
30-
func AppendArgs(typeName join.TypeName, templates ...*code.Template) *appendArgs {
30+
func AppendArgs(typeName typed.TypeName, templates ...*code.Template) *appendArgs {
3131
return &appendArgs{typeName, templates}
3232
}
3333

@@ -92,7 +92,7 @@ func (a *appendArgs) Apply(ctx context.AdviceContext) (bool, error) {
9292
Ellipsis: true,
9393
}
9494

95-
if importPath := a.TypeName.ImportPath(); importPath != "" {
95+
if importPath := a.TypeName.ImportPath; importPath != "" {
9696
ctx.AddImport(importPath, inferPkgName(importPath))
9797
}
9898

@@ -101,7 +101,7 @@ func (a *appendArgs) Apply(ctx context.AdviceContext) (bool, error) {
101101

102102
func (a *appendArgs) AddedImports() []string {
103103
imports := make([]string, 0, len(a.Templates)+1)
104-
if argTypeImportPath := a.TypeName.ImportPath(); argTypeImportPath != "" {
104+
if argTypeImportPath := a.TypeName.ImportPath; argTypeImportPath != "" {
105105
imports = append(imports, argTypeImportPath)
106106
}
107107
for _, t := range a.Templates {
@@ -168,7 +168,7 @@ func init() {
168168
return nil, err
169169
}
170170

171-
tn, err := join.NewTypeName(args.TypeName)
171+
tn, err := typed.NewTypeName(args.TypeName)
172172
if err != nil {
173173
return nil, err
174174
}

internal/injector/aspect/advice/call_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,31 +11,31 @@ import (
1111
"github.com/DataDog/orchestrion/internal/injector/aspect/advice"
1212
"github.com/DataDog/orchestrion/internal/injector/aspect/advice/code"
1313
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
14-
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
14+
"github.com/DataDog/orchestrion/internal/injector/typed"
1515
"github.com/stretchr/testify/assert"
1616
"github.com/stretchr/testify/require"
1717
)
1818

1919
func TestAppendArgs(t *testing.T) {
2020
t.Run("AddedImports", func(t *testing.T) {
2121
type testCase struct {
22-
argType join.TypeName
22+
argType typed.TypeName
2323
args []*code.Template
2424
expectedImports []string
2525
}
2626

2727
testCases := map[string]testCase{
2828
"imports-none": {
29-
argType: join.MustTypeName("any"),
29+
argType: typed.MustTypeName("any"),
3030
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
3131
},
3232
"imports-from-arg-type": {
33-
argType: join.MustTypeName("*net/http.Request"),
33+
argType: typed.MustTypeName("*net/http.Request"),
3434
args: []*code.Template{code.MustTemplate("true", nil, context.GoLangVersion{})},
3535
expectedImports: []string{"net/http"},
3636
},
3737
"imports-from-templates": {
38-
argType: join.MustTypeName("any"),
38+
argType: typed.MustTypeName("any"),
3939
args: []*code.Template{
4040
code.MustTemplate("imp.Value", map[string]string{"imp": "github.com/namespace/foo"}, context.GoLangVersion{}),
4141
code.MustTemplate("imp.Value", map[string]string{"imp": "github.com/namespace/bar"}, context.GoLangVersion{}),

internal/injector/aspect/advice/code/dot_function.go

Lines changed: 44 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,10 @@ package code
88
import (
99
"errors"
1010
"fmt"
11-
"go/importer"
12-
"go/types"
13-
"strings"
1411

1512
"github.com/dave/dst"
1613

1714
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
18-
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
1915
"github.com/DataDog/orchestrion/internal/injector/typed"
2016
)
2117

@@ -163,6 +159,11 @@ func (s signature) ResultThatImplements(name string) (string, error) {
163159
return "", nil
164160
}
165161

162+
// Optimization: First, check for an exact match using the helper.
163+
if index, found := typed.FindMatchingTypeName(s.Results, name); found {
164+
return fieldAt(s.Results, index, "result")
165+
} // If not found, fall through to type resolution.
166+
166167
// Resolve the interface type.
167168
iface, err := typed.ResolveInterfaceTypeByName(name)
168169
if err != nil {
@@ -193,18 +194,38 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
193194
return "", nil
194195
}
195196

197+
// Optimization: First, check for an exact match using TypeName parsing, finding the last one.
198+
lastMatchIndex := -1
199+
if tn, err := typed.NewTypeName(name); err == nil {
200+
currentIndex := 0
201+
for _, field := range s.Results.List {
202+
if tn.Matches(field.Type) {
203+
lastMatchIndex = currentIndex // Update last found index
204+
}
205+
// Increment index by the number of names in the field (or 1 if unnamed).
206+
count := len(field.Names)
207+
if count == 0 {
208+
count = 1
209+
}
210+
currentIndex += count
211+
}
212+
}
213+
// If we found a match via TypeName, return it.
214+
if lastMatchIndex != -1 {
215+
return fieldAt(s.Results, lastMatchIndex, "result")
216+
} // If parsing failed or no match, fall through to type resolution.
217+
196218
// Resolve the interface type.
197219
iface, err := typed.ResolveInterfaceTypeByName(name)
198220
if err != nil {
221+
// Propagate error if interface resolution fails
199222
return "", fmt.Errorf("resolving interface type %q: %w", name, err)
200223
}
201224

202-
// First, we need to build a map of result fields to their indices
203-
// that takes into account named and unnamed parameters.
204-
var (
205-
fieldIndices = make(map[*dst.Field]int)
206-
index = 0
207-
)
225+
// Fallback: Check using ExprImplements, iterating backward.
226+
// Need field indices map again for this path.
227+
fieldIndices := make(map[*dst.Field]int)
228+
index := 0
208229
for _, field := range s.Results.List {
209230
fieldIndices[field] = index
210231
count := len(field.Names)
@@ -214,7 +235,6 @@ func (s signature) LastResultThatImplements(name string) (string, error) {
214235
index += count
215236
}
216237

217-
// Loop backward through the results list.
218238
for i := len(s.Results.List) - 1; i >= 0; i-- {
219239
field := s.Results.List[i]
220240
if typed.ExprImplements(s.context, field.Type, iface) {
@@ -266,7 +286,7 @@ func fieldAt(fields *dst.FieldList, index int, use string) (string, error) {
266286
}
267287

268288
func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, error) {
269-
tn, err := join.NewTypeName(typeName)
289+
tn, err := typed.NewTypeName(typeName)
270290
if err != nil {
271291
return "", err
272292
}
@@ -293,118 +313,27 @@ func fieldOfType(fields *dst.FieldList, typeName string, use string) (string, er
293313
return "", nil
294314
}
295315

296-
// exprImplements checks if an expression's type implements an interface.
297-
func exprImplements(ctx context.AdviceContext, expr dst.Expr, iface *types.Interface) bool {
298-
actualType := ctx.ResolveType(expr)
299-
if actualType == nil {
300-
return false
301-
}
302-
303-
return typeImplements(actualType, iface)
304-
}
305-
306-
// typeImplements checks if a type implements an interface (including pointer receivers).
307-
func typeImplements(t types.Type, iface *types.Interface) bool {
308-
if t == nil || iface == nil {
309-
return false
310-
}
311-
312-
// Direct implementation check.
313-
if types.Implements(t, iface) {
314-
return true
315-
}
316-
317-
return false
318-
}
319-
320-
// resolveInterfaceTypeByName takes an interface name as a string and resolves it to an interface type.
321-
// It supports built-in interfaces (e.g. "error"), package qualified interfaces (e.g. "io.Reader"),
322-
// and third-party package interfaces (e.g. "example.com/pkg.Interface").
323-
func resolveInterfaceTypeByName(name string) (*types.Interface, error) {
324-
// Handle built-in types.
325-
if obj := types.Universe.Lookup(name); obj != nil {
326-
typeObj, ok := obj.(*types.TypeName)
327-
if !ok {
328-
return nil, fmt.Errorf("object %s is not a type name but a %T", name, obj)
329-
}
330-
331-
typ := typeObj.Type()
332-
if !types.IsInterface(typ) {
333-
return nil, fmt.Errorf("type %s is not an interface", name)
334-
}
335-
336-
t, ok := typ.Underlying().(*types.Interface)
337-
if !ok {
338-
return nil, fmt.Errorf("type %s is not an interface", name)
339-
}
340-
341-
return t, nil
342-
}
343-
344-
// Handle package-qualified types (e.g., "io.Writer").
345-
pkgName, typeName := splitPackageAndName(name)
346-
if pkgName == "" {
347-
return nil, fmt.Errorf("invalid type name: %s", name)
348-
}
349-
350-
// Import the package
351-
imp := importer.Default()
352-
pkg, err := imp.Import(pkgName)
353-
if err != nil {
354-
return nil, fmt.Errorf("failed to import package %q: %w", pkgName, err)
355-
}
356-
357-
// Look up the type in the package's scope
358-
obj := pkg.Scope().Lookup(typeName)
359-
if obj == nil {
360-
return nil, fmt.Errorf("type %q not found in package %q", typeName, pkgName)
361-
}
362-
363-
typeObj, ok := obj.(*types.TypeName)
364-
if !ok {
365-
return nil, fmt.Errorf("object %s is not a type name but a %T", name, obj)
366-
}
367-
368-
typ := typeObj.Type()
369-
if !types.IsInterface(typ) {
370-
return nil, fmt.Errorf("type %s is not an interface", name)
371-
}
372-
373-
t, ok := typ.Underlying().(*types.Interface)
374-
if !ok {
375-
return nil, fmt.Errorf("type %s is not an interface", name)
376-
}
377-
378-
return t, nil
379-
}
380-
381-
// splitPackageAndName splits a fully qualified type name like "io.Reader" or "example.com/pkg.Type"
382-
// into its package path and local name.
383-
// Returns ("", "error") for built-in "error".
384-
// Returns ("", "MyType") for unqualified "MyType".
385-
func splitPackageAndName(fullName string) (pkgPath string, localName string) {
386-
if !strings.Contains(fullName, ".") {
387-
// Assume built-in type (like "error") or unqualified local type.
388-
return "", fullName
389-
}
390-
lastDot := strings.LastIndex(fullName, ".")
391-
pkgPath = fullName[:lastDot]
392-
localName = fullName[lastDot+1:]
393-
return pkgPath, localName
394-
}
395-
396316
// FinalResultImplements returns whether the final result implements the provided interface type.
397317
func (s signature) FinalResultImplements(interfaceName string) (bool, error) {
398318
if s.Results == nil || len(s.Results.List) == 0 {
399319
return false, nil
400320
}
401321

402-
iface, err := resolveInterfaceTypeByName(interfaceName)
322+
lastField := s.Results.List[len(s.Results.List)-1]
323+
324+
// Optimization: First, check for an exact match using TypeName parsing.
325+
// Note: Not using FindMatchingTypeName as we only need to check the last field.
326+
if tn, err := typed.NewTypeName(interfaceName); err == nil {
327+
if tn.Matches(lastField.Type) {
328+
return true, nil
329+
}
330+
} // If parsing failed or no match, fall through to type resolution.
331+
332+
iface, err := typed.ResolveInterfaceTypeByName(interfaceName)
403333
if err != nil {
404334
return false, fmt.Errorf("resolving interface type %q: %w", interfaceName, err)
405335
}
406336

407337
// Check if the last field type implements the interface.
408-
lastField := s.Results.List[len(s.Results.List)-1]
409-
return exprImplements(s.context, lastField.Type, iface), nil
338+
return typed.ExprImplements(s.context, lastField.Type, iface), nil
410339
}

internal/injector/aspect/advice/struct.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,19 @@ import (
1111

1212
"github.com/DataDog/orchestrion/internal/fingerprint"
1313
"github.com/DataDog/orchestrion/internal/injector/aspect/context"
14-
"github.com/DataDog/orchestrion/internal/injector/aspect/join"
14+
"github.com/DataDog/orchestrion/internal/injector/typed"
1515
"github.com/DataDog/orchestrion/internal/yaml"
1616
"github.com/dave/dst"
1717
"github.com/goccy/go-yaml/ast"
1818
)
1919

2020
type addStructField struct {
2121
Name string
22-
TypeName join.TypeName
22+
TypeName typed.TypeName
2323
}
2424

2525
// AddStructField adds a new synthetic field at the tail end of a struct declaration.
26-
func AddStructField(fieldName string, fieldType join.TypeName) *addStructField {
26+
func AddStructField(fieldName string, fieldType typed.TypeName) *addStructField {
2727
return &addStructField{fieldName, fieldType}
2828
}
2929

@@ -47,7 +47,7 @@ func (a *addStructField) Apply(ctx context.AdviceContext) (bool, error) {
4747
Type: a.TypeName.AsNode(),
4848
})
4949

50-
if importPath := a.TypeName.ImportPath(); importPath != "" {
50+
if importPath := a.TypeName.ImportPath; importPath != "" {
5151
// If the type name is qualified, we may need to import the package, too.
5252
_ = ctx.AddImport(importPath, inferPkgName(importPath))
5353
}
@@ -60,7 +60,7 @@ func (a *addStructField) Hash(h *fingerprint.Hasher) error {
6060
}
6161

6262
func (a *addStructField) AddedImports() []string {
63-
if path := a.TypeName.ImportPath(); path != "" {
63+
if path := a.TypeName.ImportPath; path != "" {
6464
return []string{path}
6565
}
6666
return nil
@@ -76,7 +76,7 @@ func init() {
7676
if err := yaml.NodeToValueContext(ctx, node, &spec); err != nil {
7777
return nil, err
7878
}
79-
tn, err := join.NewTypeName(spec.Type)
79+
tn, err := typed.NewTypeName(spec.Type)
8080
if err != nil {
8181
return nil, err
8282
}

0 commit comments

Comments
 (0)