Skip to content

Commit b341ca5

Browse files
committed
fix(internal/injector): Fix named type issues discovered during integration testing
Signed-off-by: Kemal Akkoyun <[email protected]>
1 parent 01faa1c commit b341ca5

File tree

2 files changed

+71
-11
lines changed

2 files changed

+71
-11
lines changed

internal/injector/aspect/advice/struct.go

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ import (
1919

2020
type addStructField struct {
2121
Name string
22-
TypeName *typed.NamedType
22+
TypeExpr typed.Type
2323
}
2424

2525
// AddStructField adds a new synthetic field at the tail end of a struct declaration.
26-
func AddStructField(fieldName string, fieldType *typed.NamedType) *addStructField {
26+
func AddStructField(fieldName string, fieldType typed.Type) *addStructField {
2727
return &addStructField{fieldName, fieldType}
2828
}
2929

@@ -44,24 +44,28 @@ func (a *addStructField) Apply(ctx context.AdviceContext) (bool, error) {
4444

4545
typeDef.Fields.List = append(typeDef.Fields.List, &dst.Field{
4646
Names: []*dst.Ident{dst.NewIdent(a.Name)},
47-
Type: a.TypeName.AsNode(),
47+
Type: a.TypeExpr.AsNode(),
4848
})
4949

50-
if importPath := a.TypeName.ImportPath; importPath != "" {
51-
// If the type name is qualified, we may need to import the package, too.
52-
_ = ctx.AddImport(importPath, inferPkgName(importPath))
50+
if namedType, err := typed.ExtractNamedType(a.TypeExpr); err == nil {
51+
if importPath := namedType.ImportPath; importPath != "" {
52+
// If the type name is qualified, we may need to import the package, too.
53+
_ = ctx.AddImport(importPath, inferPkgName(importPath))
54+
}
5355
}
5456

5557
return true, nil
5658
}
5759

5860
func (a *addStructField) Hash(h *fingerprint.Hasher) error {
59-
return h.Named("add-struct-field", fingerprint.String(a.Name), a.TypeName)
61+
return h.Named("add-struct-field", fingerprint.String(a.Name), a.TypeExpr)
6062
}
6163

6264
func (a *addStructField) AddedImports() []string {
63-
if path := a.TypeName.ImportPath; path != "" {
64-
return []string{path}
65+
if namedType, err := typed.ExtractNamedType(a.TypeExpr); err == nil {
66+
if path := namedType.ImportPath; path != "" {
67+
return []string{path}
68+
}
6569
}
6670
return nil
6771
}
@@ -76,11 +80,12 @@ func init() {
7680
if err := yaml.NodeToValueContext(ctx, node, &spec); err != nil {
7781
return nil, err
7882
}
79-
namedType, err := typed.NewNamedType(spec.Type)
83+
// Use NewType instead of NewNamedType to preserve pointer information
84+
typeExpr, err := typed.NewType(spec.Type)
8085
if err != nil {
8186
return nil, fmt.Errorf("invalid type %q: %w", spec.Type, err)
8287
}
8388

84-
return AddStructField(spec.Name, namedType), nil
89+
return AddStructField(spec.Name, typeExpr), nil
8590
}
8691
}

internal/injector/typed/namedtype_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,58 @@ func TestNamedType_Matches(t *testing.T) {
261261
})
262262
}
263263
}
264+
265+
func TestNewNamedTypePointerHandling(t *testing.T) {
266+
testCases := []struct {
267+
name string
268+
typeStr string
269+
isPtr bool
270+
}{
271+
{
272+
name: "value type",
273+
typeStr: "string",
274+
isPtr: false,
275+
},
276+
{
277+
name: "pointer type",
278+
typeStr: "*string",
279+
isPtr: true,
280+
},
281+
{
282+
name: "qualified pointer type",
283+
typeStr: "*kafkatrace.Tracer",
284+
isPtr: true,
285+
},
286+
}
287+
288+
for _, tc := range testCases {
289+
t.Run(tc.name, func(t *testing.T) {
290+
t.Run("NewNamedType", func(t *testing.T) {
291+
namedType, err := NewNamedType(tc.typeStr)
292+
require.NoError(t, err)
293+
294+
node := namedType.AsNode()
295+
296+
_, isStarExpr := node.(*dst.StarExpr)
297+
require.False(t, isStarExpr,
298+
"NewNamedType(%s) should strip pointer info but got StarExpr", tc.typeStr)
299+
})
300+
301+
t.Run("NewType", func(t *testing.T) {
302+
typeExpr, err := NewType(tc.typeStr)
303+
require.NoError(t, err)
304+
305+
node := typeExpr.AsNode()
306+
307+
_, isStarExpr := node.(*dst.StarExpr)
308+
if tc.isPtr {
309+
require.True(t, isStarExpr,
310+
"NewType(%s) should preserve pointer info", tc.typeStr)
311+
} else {
312+
require.False(t, isStarExpr,
313+
"NewType(%s) should not be a pointer", tc.typeStr)
314+
}
315+
})
316+
})
317+
}
318+
}

0 commit comments

Comments
 (0)