1
+ /*
2
+ The MIT License (MIT)
3
+
4
+ Copyright (c) 2015,2016 Ernesto Jiménez
5
+ Copyright (c) 2019 Leigh McCulloch
6
+ Copyright (c) 2020 Matt Gorzka
7
+ Copyright (c) 2024 Simon Schulte
8
+ Copyright (c) 2023,2025 Olivier Mengué
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+ */
28
+
1
29
// This program reads all assertion functions from the assert package and
2
30
// automatically generates the corresponding requires and forwarded assertions
3
31
@@ -22,8 +50,6 @@ import (
22
50
"regexp"
23
51
"strings"
24
52
"text/template"
25
-
26
- "github.com/stretchr/testify/_codegen/internal/imports"
27
53
)
28
54
29
55
var (
@@ -42,17 +68,17 @@ func main() {
42
68
log .Fatal (err )
43
69
}
44
70
45
- importer , funcs , err := analyzeCode (scope , docs )
71
+ imports , funcs , err := analyzeCode (scope , docs )
46
72
if err != nil {
47
73
log .Fatal (err )
48
74
}
49
75
50
- if err := generateCode (importer , funcs ); err != nil {
76
+ if err := generateCode (imports , funcs ); err != nil {
51
77
log .Fatal (err )
52
78
}
53
79
}
54
80
55
- func generateCode (importer imports. Importer , funcs []testFunc ) error {
81
+ func generateCode (imports * imports , funcs []testFunc ) error {
56
82
buff := bytes .NewBuffer (nil )
57
83
58
84
tmplHead , tmplFunc , err := parseTemplates ()
@@ -66,7 +92,7 @@ func generateCode(importer imports.Importer, funcs []testFunc) error {
66
92
Imports map [string ]string
67
93
}{
68
94
* outputPkg ,
69
- importer . Imports () ,
95
+ imports . imports ,
70
96
}); err != nil {
71
97
return err
72
98
}
@@ -128,10 +154,13 @@ func outputFile() (*os.File, error) {
128
154
129
155
// analyzeCode takes the types scope and the docs and returns the import
130
156
// information and information about all the assertion functions.
131
- func analyzeCode (scope * types.Scope , docs * doc.Package ) (imports. Importer , []testFunc , error ) {
157
+ func analyzeCode (scope * types.Scope , docs * doc.Package ) (* imports , []testFunc , error ) {
132
158
testingT := scope .Lookup ("TestingT" ).Type ().Underlying ().(* types.Interface )
133
159
134
- importer := imports .New (* outputPkg )
160
+ importer := & imports {
161
+ currentPkg : * outputPkg ,
162
+ imports : map [string ]string {},
163
+ }
135
164
var funcs []testFunc
136
165
// Go through all the top level functions
137
166
for _ , fdocs := range docs .Funcs {
@@ -166,11 +195,43 @@ func analyzeCode(scope *types.Scope, docs *doc.Package) (imports.Importer, []tes
166
195
}
167
196
168
197
funcs = append (funcs , testFunc {* outputPkg , fdocs , fn })
169
- importer .AddImportsFrom (sig .Params ())
198
+ importer .addImportsFrom (sig .Params ())
170
199
}
171
200
return importer , funcs , nil
172
201
}
173
202
203
+ // imports collects a map of imported packages for a source file.
204
+ //
205
+ // This code has been copied from package github.com/ernesto-jimenez/gogen/imports
206
+ type imports struct {
207
+ currentPkg string
208
+ imports map [string ]string
209
+ }
210
+
211
+ func (imp * imports ) addImportsFrom (t types.Type ) {
212
+ switch el := t .(type ) {
213
+ case * types.Basic :
214
+ case * types.Slice :
215
+ imp .addImportsFrom (el .Elem ())
216
+ case * types.Pointer :
217
+ imp .addImportsFrom (el .Elem ())
218
+ case * types.Named :
219
+ pkg := el .Obj ().Pkg ()
220
+ if pkg == nil {
221
+ return
222
+ }
223
+ if pkg .Name () == imp .currentPkg {
224
+ return
225
+ }
226
+ imp .imports [pkg .Path ()] = pkg .Name ()
227
+ case * types.Tuple :
228
+ for i := 0 ; i < el .Len (); i ++ {
229
+ imp .addImportsFrom (el .At (i ).Type ())
230
+ }
231
+ default :
232
+ }
233
+ }
234
+
174
235
// parsePackageSource returns the types scope and the package documentation from the package
175
236
func parsePackageSource (pkg string ) (* types.Scope , * doc.Package , error ) {
176
237
pd , err := build .Import (pkg , "." , 0 )
0 commit comments