Skip to content

Commit e1ea456

Browse files
authored
Refactor getMacroMatches & fix a couple of bugs (#863)
This addresses #859 and #861 and removes some duplicated logic.
1 parent d36c501 commit e1ea456

File tree

2 files changed

+147
-100
lines changed

2 files changed

+147
-100
lines changed

data/sqlutil/macros.go

Lines changed: 50 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -124,110 +124,72 @@ var DefaultMacros = Macros{
124124
"column": macroColumn,
125125
}
126126

127-
type Macro struct {
128-
Name string
129-
Args []string
127+
type macroMatch struct {
128+
full string
129+
args []string
130130
}
131131

132132
// getMacroMatches extracts macro strings with their respective arguments from the sql input given
133133
// It manually parses the string to find the closing parenthesis of the macro (because regex has no memory)
134-
func getMacroMatches(input string, name string) ([]Macro, error) {
135-
macroName := fmt.Sprintf("\\$__%s\\b", name)
136-
matchedMacros := []Macro{}
137-
rgx, err := regexp.Compile(macroName)
134+
func getMacroMatches(input string, name string) ([]macroMatch, error) {
135+
rgx, err := regexp.Compile(fmt.Sprintf(`\$__%s\b`, name))
138136

139137
if err != nil {
140138
return nil, err
141139
}
142140

143-
// get all matching macro instances
144-
matched := rgx.FindAllStringIndex(input, -1)
141+
var matches []macroMatch
142+
for _, window := range rgx.FindAllStringIndex(input, -1) {
143+
start, end := window[0], window[1]
144+
args, length := parseArgs(input[end:])
145+
if length < 0 {
146+
return nil, fmt.Errorf("failed to parse macro arguments (missing close bracket?)")
147+
}
148+
matches = append(matches, macroMatch{full: input[start : end+length], args: args})
149+
}
150+
return matches, nil
151+
}
145152

146-
if matched == nil {
147-
return nil, nil
153+
// parseArgs looks for a bracketed argument list at the beginning of argString.
154+
// If one is present, returns a list of whitespace-trimmed arguments and the
155+
// length of the string comprising the bracketed argument list.
156+
func parseArgs(argString string) ([]string, int) {
157+
if !strings.HasPrefix(argString, "(") {
158+
return nil, 0 // single empty arg for backwards compatibility
148159
}
149160

150-
for matchedIndex := 0; matchedIndex < len(matched); matchedIndex++ {
151-
var macroEnd = 0
152-
var argStart = 0
153-
// quick exit from the loop, when we encounter a closing bracket before an opening one (ie "($__macro)", where we can skip the closing one from the result)
154-
var forceBreak = false
155-
macroStart := matched[matchedIndex][0]
156-
inputCopy := input[macroStart:]
157-
cache := make([]rune, 0)
161+
var args []string
162+
depth := 0
163+
arg := []rune{}
158164

159-
// find the opening and closing arguments brackets
160-
for idx, r := range inputCopy {
161-
if len(cache) == 0 && macroEnd > 0 || forceBreak {
162-
break
165+
for i, r := range argString {
166+
switch r {
167+
case '(':
168+
depth++
169+
if depth == 1 {
170+
// don't include the outer bracket in the arg
171+
continue
163172
}
164-
switch r {
165-
case '(':
166-
cache = append(cache, r)
167-
if argStart == 0 {
168-
argStart = idx + 1
169-
}
170-
case ' ':
171-
// when we are inside an argument, we do not want to exit on space
172-
if argStart != 0 {
173-
continue
174-
}
175-
fallthrough
176-
case ')':
177-
l := len(cache)
178-
if l == 0 {
179-
macroEnd = 0
180-
forceBreak = true
181-
break
182-
}
183-
cache = cache[:l-1]
184-
macroEnd = idx + 1
185-
default:
173+
case ')':
174+
depth--
175+
if depth == 0 {
176+
// closing bracket
177+
args = append(args, strings.TrimSpace(string(arg)))
178+
return args, i + 1
179+
}
180+
case ',':
181+
if depth == 1 {
182+
// a comma at this level is separating args
183+
args = append(args, strings.TrimSpace(string(arg)))
184+
arg = []rune{}
186185
continue
187186
}
188187
}
189-
190-
// macroEnd equals to 0 means there are no parentheses, so just set it
191-
// to the end of the regex match
192-
if macroEnd == 0 {
193-
macroEnd = matched[matchedIndex][1] - macroStart
194-
}
195-
macroString := inputCopy[0:macroEnd]
196-
macroMatch := Macro{Name: macroString}
197-
198-
args := ""
199-
// if opening parenthesis was found, extract contents as arguments
200-
if argStart > 0 {
201-
args = inputCopy[argStart : macroEnd-1]
202-
}
203-
macroMatch.Args = parseArgs(args)
204-
matchedMacros = append(matchedMacros, macroMatch)
205-
}
206-
return matchedMacros, nil
207-
}
208-
209-
func parseArgs(args string) []string {
210-
argsArray := []string{}
211-
phrase := []rune{}
212-
bracketCount := 0
213-
for _, v := range args {
214-
phrase = append(phrase, v)
215-
if v == '(' {
216-
bracketCount++
217-
continue
218-
}
219-
if v == ')' {
220-
bracketCount--
221-
continue
222-
}
223-
if v == ',' && bracketCount == 0 {
224-
removeComma := phrase[:len(phrase)-1]
225-
argsArray = append(argsArray, string(removeComma))
226-
phrase = []rune{}
227-
}
188+
arg = append(arg, r)
228189
}
229-
argsArray = append(argsArray, strings.TrimSpace(string(phrase)))
230-
return argsArray
190+
// If we get here, we have seen an open bracket but not a close bracket. This
191+
// would formerly cause a panic; now it is treated as an error.
192+
return nil, -1
231193
}
232194

233195
// Interpolate returns an interpolated query string given a backend.DataQuery
@@ -243,17 +205,14 @@ func Interpolate(query *Query, macros Macros) (string, error) {
243205
if err != nil {
244206
return rawSQL, err
245207
}
246-
if len(matches) == 0 {
247-
continue
248-
}
249208

250209
for _, match := range matches {
251-
res, err := macro(query.WithSQL(rawSQL), match.Args)
210+
res, err := macro(query.WithSQL(rawSQL), match.args)
252211
if err != nil {
253212
return rawSQL, err
254213
}
255214

256-
rawSQL = strings.ReplaceAll(rawSQL, match.Name, res)
215+
rawSQL = strings.ReplaceAll(rawSQL, match.full, res)
257216
}
258217
}
259218

data/sqlutil/macros_test.go

Lines changed: 97 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"testing"
77

88
"github.com/stretchr/testify/assert"
9-
"github.com/stretchr/testify/require"
109
)
1110

1211
func staticMacro(output string) MacroFunc {
@@ -26,8 +25,9 @@ var macros = Macros{
2625
}
2726
return "bar"
2827
}),
29-
"f": staticMacro("f(1)"),
30-
"g": staticMacro("g(1)"),
28+
"f": staticMacro("f(1)"),
29+
"g": staticMacro("g(1)"),
30+
"num": staticMacro("10000"),
3131
"multiParams": argMacro(func(args []string) string {
3232
return strings.Join(append([]string{"bar"}, args...), "_")
3333
}),
@@ -39,9 +39,10 @@ func TestInterpolate(t *testing.T) {
3939
tableName := "my_table"
4040
tableColumn := "my_col"
4141
type test struct {
42-
name string
43-
input string
44-
output string
42+
name string
43+
input string
44+
output string
45+
wantErr bool
4546
}
4647
tests := []test{
4748
{
@@ -174,6 +175,26 @@ func TestInterpolate(t *testing.T) {
174175
output: "select * from foo where f(1) > g(1)",
175176
name: "don't consume args after a space (see https://github.com/grafana/sqlds/issues/82)",
176177
},
178+
{
179+
input: "select * from foo where $__num*(table.a + table.b) > 1000000",
180+
output: "select * from foo where 10000*(table.a + table.b) > 1000000",
181+
name: "don't consume args after other non-paren characters",
182+
},
183+
{
184+
input: "select * from foo where $__params(whoops",
185+
name: "error (not panic) on missing close bracket",
186+
wantErr: true,
187+
},
188+
{
189+
input: "select * from foo where $__params(FUNC(foo, bar)",
190+
name: "error on missing close bracket, nested",
191+
wantErr: true,
192+
},
193+
{
194+
input: "select * from foo where $__params(FUNC(foo, bar)) > $__timeTo(uhoh",
195+
name: "error on missing close bracket after good macros",
196+
wantErr: true,
197+
},
177198
}
178199
for i, tc := range tests {
179200
t.Run(fmt.Sprintf("[%d/%d] %s", i+1, len(tests), tc.name), func(t *testing.T) {
@@ -183,8 +204,75 @@ func TestInterpolate(t *testing.T) {
183204
Column: tableColumn,
184205
}
185206
interpolatedQuery, err := Interpolate(query, macros)
186-
require.Nil(t, err)
187-
assert.Equal(t, tc.output, interpolatedQuery)
207+
assert.Equal(t, err != nil, tc.wantErr, "wantErr != gotErr")
208+
if !tc.wantErr {
209+
assert.Equal(t, tc.output, interpolatedQuery)
210+
}
211+
})
212+
}
213+
}
214+
215+
func Test_parseArgs(t *testing.T) {
216+
var tests = []struct {
217+
name string
218+
input string
219+
wantArgs []string
220+
wantLength int
221+
}{
222+
{
223+
name: "no parens, no args",
224+
input: "foo bar",
225+
wantArgs: nil,
226+
wantLength: 0,
227+
},
228+
{
229+
name: "parens not at beginning, still no args",
230+
input: "foo(bar)",
231+
wantArgs: nil,
232+
wantLength: 0,
233+
},
234+
{
235+
name: "even just a space is enough",
236+
input: " (bar)",
237+
wantArgs: nil,
238+
wantLength: 0,
239+
},
240+
{
241+
name: "simple one-arg case",
242+
input: "(bar)",
243+
wantArgs: []string{"bar"},
244+
wantLength: 5,
245+
},
246+
{
247+
name: "multiple args, spaces",
248+
input: "(bar, baz, quux)",
249+
wantArgs: []string{"bar", "baz", "quux"},
250+
wantLength: 16,
251+
},
252+
{
253+
name: "nested parens are not parsed further",
254+
input: "(bar(some,thing))",
255+
wantArgs: []string{"bar(some,thing)"},
256+
wantLength: 17,
257+
},
258+
{
259+
name: "stuff after the closing bracket is ignored",
260+
input: "(arg1, arg2), not_an_arg, nope",
261+
wantArgs: []string{"arg1", "arg2"},
262+
wantLength: 12,
263+
},
264+
{
265+
name: "missing close bracket is an error",
266+
input: "(arg1, arg2",
267+
wantArgs: nil,
268+
wantLength: -1,
269+
},
270+
}
271+
for _, tt := range tests {
272+
t.Run(tt.name, func(t *testing.T) {
273+
args, length := parseArgs(tt.input)
274+
assert.Equalf(t, tt.wantArgs, args, "parseArgs(%v) wrong args:\nwant: %v\ngot: %v\n", tt.input, tt.wantArgs, args)
275+
assert.Equalf(t, tt.wantLength, length, "parseArgs(%v) wrong length:\nwant: %d\ngot: %d\n", tt.input, tt.wantLength, length)
188276
})
189277
}
190278
}
@@ -195,7 +283,7 @@ func TestGetMacroMatches(t *testing.T) {
195283
matches, err := getMacroMatches(fmt.Sprintf("$__%s", macroName), macroName)
196284

197285
assert.NoError(t, err)
198-
assert.Equal(t, []Macro{{fmt.Sprintf("$__%s", macroName), []string{""}}}, matches)
286+
assert.Equal(t, []macroMatch{{fmt.Sprintf("$__%s", macroName), nil}}, matches)
199287
}
200288
})
201289
t.Run("does not return matches for macro name which is substring", func(t *testing.T) {

0 commit comments

Comments
 (0)