Skip to content

Commit 9b0064f

Browse files
authored
Merge pull request #199 from gofiber/codex/2025-08-24-18-47-23
2 parents c75b11e + 3f99046 commit 9b0064f

File tree

3 files changed

+144
-37
lines changed

3 files changed

+144
-37
lines changed

cmd/internal/migrations/v3/middleware_locals.go

Lines changed: 94 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package v3
33
import (
44
"fmt"
55
"regexp"
6+
"strings"
67

78
semver "github.com/Masterminds/semver/v3"
89
"github.com/spf13/cobra"
@@ -11,8 +12,9 @@ import (
1112
)
1213

1314
type ctxRepl struct {
14-
pkg string
15-
replFmt string
15+
pkg string
16+
replFmt string
17+
isDefault bool
1618
}
1719

1820
func parseMiddlewareImports(content string, reImport *regexp.Regexp) map[string]string {
@@ -29,29 +31,7 @@ func parseMiddlewareImports(content string, reImport *regexp.Regexp) map[string]
2931
}
3032

3133
func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Version) error {
32-
ctxMap := map[string][]ctxRepl{
33-
"requestid": {
34-
{pkg: "requestid", replFmt: "requestid.FromContext(%s)"},
35-
},
36-
"csrf": {
37-
{pkg: "csrf", replFmt: "csrf.TokenFromContext(%s)"},
38-
},
39-
"csrf_handler": {
40-
{pkg: "csrf", replFmt: "csrf.HandlerFromContext(%s)"},
41-
},
42-
"session": {
43-
{pkg: "session", replFmt: "session.FromContext(%s)"},
44-
},
45-
"username": {
46-
{pkg: "basicauth", replFmt: "basicauth.UsernameFromContext(%s)"},
47-
},
48-
"password": {
49-
{pkg: "basicauth", replFmt: "basicauth.PasswordFromContext(%s)"},
50-
},
51-
"token": {
52-
{pkg: "keyauth", replFmt: "keyauth.TokenFromContext(%s)"},
53-
},
54-
}
34+
ctxMap := map[string][]ctxRepl{}
5535

5636
extractors := []struct {
5737
pkg string
@@ -76,10 +56,16 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio
7656
if e.pkg != pkg {
7757
continue
7858
}
79-
re := regexp.MustCompile(alias + `\.Config{[^}]*` + e.field + `:\s*"([^"]+)"`)
80-
matches := re.FindAllStringSubmatch(content, -1)
59+
reCfg := regexp.MustCompile(regexp.QuoteMeta(alias) + `\.Config{`)
60+
matches := reCfg.FindAllStringIndex(content, -1)
8161
for _, m := range matches {
82-
ctxMap[m[1]] = append(ctxMap[m[1]], ctxRepl{pkg: e.pkg, replFmt: e.replFmt})
62+
start := m[0]
63+
end := extractBlock(content, m[1], '{', '}')
64+
cfg := content[start:end]
65+
reField := regexp.MustCompile(e.field + `:\s*"([^"]+)"`)
66+
for _, fm := range reField.FindAllStringSubmatch(cfg, -1) {
67+
ctxMap[fm[1]] = append(ctxMap[fm[1]], ctxRepl{pkg: e.pkg, replFmt: e.replFmt})
68+
}
8369
}
8470
}
8571
}
@@ -90,6 +76,30 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio
9076
return fmt.Errorf("failed to gather middleware locals: %w", err)
9177
}
9278

79+
defaults := map[string][]ctxRepl{
80+
"requestid": {{pkg: "requestid", replFmt: "requestid.FromContext(%s)", isDefault: true}},
81+
"csrf": {{pkg: "csrf", replFmt: "csrf.TokenFromContext(%s)", isDefault: true}},
82+
"csrf_handler": {{pkg: "csrf", replFmt: "csrf.HandlerFromContext(%s)", isDefault: true}},
83+
"session": {{pkg: "session", replFmt: "session.FromContext(%s)", isDefault: true}},
84+
"username": {{pkg: "basicauth", replFmt: "basicauth.UsernameFromContext(%s)", isDefault: true}},
85+
"password": {{pkg: "basicauth", replFmt: "basicauth.PasswordFromContext(%s)", isDefault: true}},
86+
"token": {{pkg: "keyauth", replFmt: "keyauth.TokenFromContext(%s)", isDefault: true}},
87+
}
88+
for key, repls := range defaults {
89+
for _, r := range repls {
90+
exists := false
91+
for _, existing := range ctxMap[key] {
92+
if existing.pkg == r.pkg {
93+
exists = true
94+
break
95+
}
96+
}
97+
if !exists {
98+
ctxMap[key] = append(ctxMap[key], r)
99+
}
100+
}
101+
}
102+
93103
// second pass: perform replacements and clean up
94104
changed, err := internal.ChangeFileContent(cwd, func(content string) string {
95105
imports := parseMiddlewareImports(content, reImport)
@@ -99,17 +109,45 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio
99109
sub := reLocals.FindStringSubmatch(s)
100110
ctx := sub[1]
101111
key := sub[2]
102-
if repls, ok := ctxMap[key]; ok {
103-
if len(repls) == 1 {
104-
return fmt.Sprintf(repls[0].replFmt, ctx)
112+
repls, ok := ctxMap[key]
113+
if !ok {
114+
return s
115+
}
116+
117+
var custom, defs []ctxRepl
118+
for _, r := range repls {
119+
if r.isDefault {
120+
defs = append(defs, r)
121+
} else {
122+
custom = append(custom, r)
105123
}
106-
for _, r := range repls {
124+
}
125+
126+
choose := func(r ctxRepl) string { return fmt.Sprintf(r.replFmt, ctx) }
127+
128+
if len(custom) == 1 {
129+
return choose(custom[0])
130+
}
131+
if len(custom) > 1 {
132+
for _, r := range custom {
107133
for _, pkg := range imports {
108134
if pkg == r.pkg {
109-
return fmt.Sprintf(r.replFmt, ctx)
135+
return choose(r)
110136
}
111137
}
112138
}
139+
return s
140+
}
141+
142+
if len(defs) == 1 {
143+
return choose(defs[0])
144+
}
145+
for _, r := range defs {
146+
for _, pkg := range imports {
147+
if pkg == r.pkg {
148+
return choose(r)
149+
}
150+
}
113151
}
114152
return s
115153
})
@@ -121,13 +159,32 @@ func MigrateMiddlewareLocals(cmd *cobra.Command, cwd string, _, _ *semver.Versio
121159
content = reComma.ReplaceAllString(content, "$1, $2 := $3, true")
122160

123161
for alias := range imports {
124-
reCfg := regexp.MustCompile(alias + `\.Config{[^}]*}`)
125-
content = reCfg.ReplaceAllStringFunc(content, func(cfg string) string {
162+
reCfg := regexp.MustCompile(regexp.QuoteMeta(alias) + `\.Config{`)
163+
matches := reCfg.FindAllStringIndex(content, -1)
164+
if len(matches) == 0 {
165+
continue
166+
}
167+
var b strings.Builder
168+
last := 0
169+
for _, m := range matches {
170+
if _, err := b.WriteString(content[last:m[0]]); err != nil {
171+
return content
172+
}
173+
start := m[0]
174+
end := extractBlock(content, m[1], '{', '}')
175+
cfg := content[start:end]
126176
cfg = removeConfigField(cfg, "ContextKey")
127177
cfg = removeConfigField(cfg, "ContextUsername")
128178
cfg = removeConfigField(cfg, "ContextPassword")
129-
return cfg
130-
})
179+
if _, err := b.WriteString(cfg); err != nil {
180+
return content
181+
}
182+
last = end
183+
}
184+
if _, err := b.WriteString(content[last:]); err != nil {
185+
return content
186+
}
187+
content = b.String()
131188
}
132189

133190
return content

cmd/internal/migrations/v3/middleware_locals_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,52 @@ func handler(c fiber.Ctx) error {
221221
assert.NotContains(t, cfgContent, "ContextKey")
222222
}
223223

224+
func Test_MigrateMiddlewareLocals_CustomContextKeyWithFunc(t *testing.T) {
225+
t.Parallel()
226+
227+
dir, err := os.MkdirTemp("", "mcustomctxfunc")
228+
require.NoError(t, err)
229+
defer func() { require.NoError(t, os.RemoveAll(dir)) }()
230+
231+
cfg := `package main
232+
import (
233+
"github.com/gofiber/fiber/v2"
234+
"github.com/gofiber/fiber/v2/middleware/csrf"
235+
)
236+
237+
var _ = csrf.New(csrf.Config{
238+
Next: func(c *fiber.Ctx) bool { return csrfActivated },
239+
ContextKey: "token",
240+
})`
241+
cfgPath := filepath.Join(dir, "config.go")
242+
require.NoError(t, os.WriteFile(cfgPath, []byte(cfg), 0o600))
243+
244+
handler := `package main
245+
import (
246+
"github.com/gofiber/fiber/v2"
247+
"github.com/gofiber/fiber/v2/middleware/keyauth"
248+
)
249+
250+
func handler(c fiber.Ctx) error {
251+
_ = keyauth.New(keyauth.Config{})
252+
token := c.Locals("token").(string)
253+
_ = token
254+
return nil
255+
}`
256+
handlerPath := filepath.Join(dir, "handler.go")
257+
require.NoError(t, os.WriteFile(handlerPath, []byte(handler), 0o600))
258+
259+
var buf bytes.Buffer
260+
cmd := newCmd(&buf)
261+
require.NoError(t, v3.MigrateMiddlewareLocals(cmd, dir, nil, nil))
262+
263+
content := readFile(t, handlerPath)
264+
assert.Contains(t, content, `token := csrf.TokenFromContext(c)`)
265+
assert.NotContains(t, content, `keyauth.TokenFromContext`)
266+
cfgContent := readFile(t, cfgPath)
267+
assert.NotContains(t, cfgContent, "ContextKey")
268+
}
269+
224270
func Test_MigrateMiddlewareLocals_SameContextKeyDifferentPackages(t *testing.T) {
225271
t.Parallel()
226272

cmd/migrate_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ import (
7979
"github.com/gofiber/fiber/v2"
8080
"github.com/gofiber/fiber/v2/middleware/monitor"
8181
"github.com/gofiber/fiber/v2/middleware/csrf"
82+
"github.com/gofiber/fiber/v2/middleware/keyauth"
8283
)
8384
8485
func handler(c *fiber.Ctx) error {
@@ -103,6 +104,7 @@ func main() {
103104
Network: "tcp",
104105
})
105106
app.Use(csrf.New(csrf.Config{ContextKey: "token"}))
107+
app.Use(keyauth.New(keyauth.Config{}))
106108
app.Static("/", "./public")
107109
app.Add(fiber.MethodGet, "/foo", handler)
108110
app.Mount("/api", app)
@@ -128,6 +130,7 @@ func main() {
128130
at := assert.New(t)
129131
at.Contains(content, "github.com/gofiber/fiber/v3")
130132
at.Contains(content, "github.com/gofiber/contrib/monitor")
133+
at.Contains(content, "github.com/gofiber/fiber/v3/middleware/keyauth")
131134
at.NotContains(content, "*fiber.Ctx")
132135
at.Contains(content, "fiber.Ctx")
133136
at.Contains(content, ".Bind().Body(&v)")
@@ -136,6 +139,7 @@ func main() {
136139
at.Contains(content, ".Redirect().Back()")
137140
at.Contains(content, "fiber.Params[int](c, \"id\"")
138141
at.Contains(content, "csrf.TokenFromContext(c)")
142+
at.NotContains(content, "keyauth.TokenFromContext")
139143
at.NotContains(content, "ContextKey")
140144
at.Contains(content, ".Use(\"/api\", app)")
141145
at.Contains(content, ".Listen(")

0 commit comments

Comments
 (0)