Skip to content

Commit 305809d

Browse files
authored
Support cost estimation and tracking for the regex library (#1200)
Support cost estimation and tracking for the regex library
1 parent 6713c74 commit 305809d

File tree

4 files changed

+362
-29
lines changed

4 files changed

+362
-29
lines changed

ext/README.md

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -931,10 +931,10 @@ type will cause a key collision.
931931
Elements in the map may optionally be filtered according to a predicate
932932
expression, where elements that satisfy the predicate are transformed.
933933

934-
<list>.transformMap(indexVar, valueVar, <transform>)
935-
<list>.transformMap(indexVar, valueVar, <filter>, <transform>)
936-
<map>.transformMap(keyVar, valueVar, <transform>)
937-
<map>.transformMap(keyVar, valueVar, <filter>, <transform>)
934+
<list>.transformMapEntry(indexVar, valueVar, <transform>)
935+
<list>.transformMapEntry(indexVar, valueVar, <filter>, <transform>)
936+
<map>.transformMapEntry(keyVar, valueVar, <transform>)
937+
<map>.transformMapEntry(keyVar, valueVar, <filter>, <transform>)
938938

939939
Examples:
940940

@@ -945,3 +945,73 @@ Examples:
945945

946946
{'greeting': 'aloha', 'farewell': 'aloha'}
947947
.transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // error, duplicate key
948+
949+
## Regex
950+
951+
Regex introduces functions for regular expressions in CEL.
952+
953+
Note: Please ensure that cel.OptionalTypes() is enabled when using regex
954+
extensions. All functions use the 'regex' namespace. If you are currently
955+
using a variable named 'regex', the functions will likely work as intended.
956+
However, there is some chance for collision.
957+
958+
### Replace
959+
960+
The `regex.replace` function replaces all non-overlapping substrings matching a regex
961+
pattern in the target string with a replacement string. Optionally, you can
962+
limit the number of replacements by providing a count argument. When the count
963+
is a negative number, the function acts as replace all. Only numeric (\N)
964+
capture group references are supported in the replacement string, with
965+
validation for correctness. Backslash-escaped digits (\1 to \9) within the
966+
replacement argument can be used to insert text matching the corresponding
967+
parenthesized group in the regexp pattern. An error will be thrown for invalid
968+
regex or replace string.
969+
970+
971+
regex.replace(target: string, pattern: string, replacement: string) -> string
972+
regex.replace(target: string, pattern: string, replacement: string, count: int) -> string
973+
974+
975+
Examples:
976+
977+
regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi'
978+
regex.replace('banana', 'a', 'x', 0) == 'banana'
979+
regex.replace('banana', 'a', 'x', 1) == 'bxnana'
980+
regex.replace('banana', 'a', 'x', 2) == 'bxnxna'
981+
regex.replace('banana', 'a', 'x', -12) == 'bxnxnx'
982+
regex.replace('foo bar', '(fo)o (ba)r', '\\2 \\1') == 'ba fo'
983+
984+
regex.replace('test', '(.)', '$2') // Runtime Error invalid replace string
985+
regex.replace('foo bar', '(', '$2 $1') // Runtime Error invalid regex string
986+
regex.replace('id=123', 'id=(?P<value>\\d+)', 'value: \\values') // Runtime Error invalid replace string
987+
988+
### Extract
989+
990+
The `regex.extract` function returns the first match of a regex pattern as an
991+
`optional` string. If no match is found, it returns an optional none value.
992+
An error will be thrown for invalid regex or for multiple capture groups.
993+
994+
regex.extract(target: string, pattern: string) -> optional<string>
995+
996+
Examples:
997+
998+
regex.extract('hello world', 'hello(.*)') == optional.of(' world')
999+
regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A')
1000+
regex.extract('HELLO', 'hello') == optional.none()
1001+
1002+
regex.extract('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group
1003+
1004+
### Extract All
1005+
1006+
The `regex.extractAll` function returns a `list` of all matches of a regex
1007+
pattern in a target string. If no matches are found, it returns an empty list.
1008+
An error will be thrown for invalid regex or for multiple capture groups.
1009+
1010+
regex.extractAll(target: string, pattern: string) -> list<string>
1011+
1012+
Examples:
1013+
1014+
regex.extractAll('id:123, id:456', 'id:\\d+') == ['id:123', 'id:456']
1015+
regex.extractAll('id:123, id:456', 'assa') == []
1016+
1017+
regex.extractAll('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group

ext/comprehensions.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,10 @@ const (
146146
// Elements in the map may optionally be filtered according to a predicate expression, where
147147
// elements that satisfy the predicate are transformed.
148148
//
149-
// <list>.transformMap(indexVar, valueVar, <transform>)
150-
// <list>.transformMap(indexVar, valueVar, <filter>, <transform>)
151-
// <map>.transformMap(keyVar, valueVar, <transform>)
152-
// <map>.transformMap(keyVar, valueVar, <filter>, <transform>)
149+
// <list>.transformMapEntry(indexVar, valueVar, <transform>)
150+
// <list>.transformMapEntry(indexVar, valueVar, <filter>, <transform>)
151+
// <map>.transformMapEntry(keyVar, valueVar, <transform>)
152+
// <map>.transformMapEntry(keyVar, valueVar, <filter>, <transform>)
153153
//
154154
// Examples:
155155
//

ext/regex.go

Lines changed: 135 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ import (
2323
"strings"
2424

2525
"github.com/google/cel-go/cel"
26+
"github.com/google/cel-go/checker"
27+
"github.com/google/cel-go/common"
2628
"github.com/google/cel-go/common/types"
2729
"github.com/google/cel-go/common/types/ref"
30+
"github.com/google/cel-go/interpreter"
2831
)
2932

3033
const (
@@ -82,7 +85,7 @@ const (
8285
//
8386
// regex.extract('hello world', 'hello(.*)') == optional.of(' world')
8487
// regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A')
85-
// regex.extract('HELLO', 'hello') == optional.empty()
88+
// regex.extract('HELLO', 'hello') == optional.none()
8689
// regex.extract('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group
8790
//
8891
// # Extract All
@@ -151,22 +154,27 @@ func (r *regexLib) CompileOptions() []cel.EnvOption {
151154
cel.Overload("regex_replace_string_string_string_int", []*cel.Type{cel.StringType, cel.StringType, cel.StringType, cel.IntType}, cel.StringType,
152155
cel.FunctionBinding((regReplaceN))),
153156
),
157+
cel.CostEstimatorOptions(
158+
checker.OverloadCostEstimate("regex_extract_string_string", estimateExtractCost()),
159+
checker.OverloadCostEstimate("regex_extractAll_string_string", estimateExtractAllCost()),
160+
checker.OverloadCostEstimate("regex_replace_string_string_string", estimateReplaceCost()),
161+
checker.OverloadCostEstimate("regex_replace_string_string_string_int", estimateReplaceCost()),
162+
),
154163
cel.EnvOption(optionalTypesEnabled),
155164
}
156165
return opts
157166
}
158167

159168
// ProgramOptions implements the cel.Library interface method
160169
func (r *regexLib) ProgramOptions() []cel.ProgramOption {
161-
return []cel.ProgramOption{}
162-
}
163-
164-
func compileRegex(regexStr string) (*regexp.Regexp, error) {
165-
re, err := regexp.Compile(regexStr)
166-
if err != nil {
167-
return nil, fmt.Errorf("given regex is invalid: %w", err)
170+
return []cel.ProgramOption{
171+
cel.CostTrackerOptions(
172+
interpreter.OverloadCostTracker("regex_extract_string_string", extractCostTracker()),
173+
interpreter.OverloadCostTracker("regex_extractAll_string_string", extractAllCostTracker()),
174+
interpreter.OverloadCostTracker("regex_replace_string_string_string", replaceCostTracker()),
175+
interpreter.OverloadCostTracker("regex_replace_string_string_string_int", replaceCostTracker()),
176+
),
168177
}
169-
return re, nil
170178
}
171179

172180
func regReplace(args ...ref.Val) ref.Val {
@@ -187,10 +195,6 @@ func regReplaceN(args ...ref.Val) ref.Val {
187195
return types.String(target)
188196
}
189197

190-
if replaceCount > math.MaxInt32 {
191-
return types.NewErr("integer overflow")
192-
}
193-
194198
// If replaceCount is negative, just do a replaceAll.
195199
if replaceCount < 0 {
196200
replaceCount = -1
@@ -271,7 +275,7 @@ func replaceStrValidator(target string, re *regexp.Regexp, match []int, replacem
271275
func extract(target, regexStr ref.Val) ref.Val {
272276
t := string(target.(types.String))
273277
r := string(regexStr.(types.String))
274-
re, err := compileRegex(r)
278+
re, err := regexp.Compile(r)
275279
if err != nil {
276280
return types.WrapErr(err)
277281
}
@@ -300,7 +304,7 @@ func extract(target, regexStr ref.Val) ref.Val {
300304
func extractAll(target, regexStr ref.Val) ref.Val {
301305
t := string(target.(types.String))
302306
r := string(regexStr.(types.String))
303-
re, err := compileRegex(r)
307+
re, err := regexp.Compile(r)
304308
if err != nil {
305309
return types.WrapErr(err)
306310
}
@@ -330,3 +334,119 @@ func extractAll(target, regexStr ref.Val) ref.Val {
330334
}
331335
return types.NewStringList(types.DefaultTypeAdapter, result)
332336
}
337+
338+
func estimateExtractCost() checker.FunctionEstimator {
339+
return func(c checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
340+
if len(args) == 2 {
341+
targetSize := estimateSize(c, args[0])
342+
// Fixed size estimate of +1 is added for safety from zero size args.
343+
// The target cost is the size of the target string, scaled by a traversal factor.
344+
targetCost := targetSize.Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.StringTraversalCostFactor)
345+
// The regex cost is the size of the regex pattern, scaled by a complexity factor.
346+
regexCost := estimateSize(c, args[1]).Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
347+
// The result is a single string. Worst Case: it's the size of the entire target.
348+
resultSize := &checker.SizeEstimate{Min: 0, Max: targetSize.Max}
349+
// The total cost is the search cost (target + regex) plus the allocation cost for the result string.
350+
return &checker.CallEstimate{
351+
CostEstimate: regexCost.Multiply(targetCost).Add(checker.CostEstimate(*resultSize)),
352+
ResultSize: resultSize,
353+
}
354+
}
355+
return nil
356+
}
357+
}
358+
359+
func estimateExtractAllCost() checker.FunctionEstimator {
360+
return func(c checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
361+
if len(args) == 2 {
362+
targetSize := estimateSize(c, args[0])
363+
// Fixed size estimate of +1 is added for safety from zero size args.
364+
// The target cost is the size of the target string, scaled by a traversal factor.
365+
targetCost := targetSize.Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.StringTraversalCostFactor)
366+
// The regex cost is the size of the regex pattern, scaled by a complexity factor.
367+
regexCost := estimateSize(c, args[1]).Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
368+
// The result is a list of strings. Worst Case: it's contents are the size of the entire target.
369+
resultSize := &checker.SizeEstimate{Min: 0, Max: targetSize.Max}
370+
// The cost to allocate the result list is its base cost plus the size of its contents.
371+
allocationSize := resultSize.Add(checker.FixedSizeEstimate(common.ListCreateBaseCost))
372+
// The total cost is the search cost (target + regex) plus the allocation cost for the result list.
373+
return &checker.CallEstimate{
374+
CostEstimate: targetCost.Multiply(regexCost).Add(checker.CostEstimate(allocationSize)),
375+
ResultSize: resultSize,
376+
}
377+
}
378+
return nil
379+
}
380+
}
381+
382+
func estimateReplaceCost() checker.FunctionEstimator {
383+
return func(c checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
384+
l := len(args)
385+
if l == 3 || l == 4 {
386+
targetSize := estimateSize(c, args[0])
387+
replacementSize := estimateSize(c, args[2])
388+
// Fixed size estimate of +1 is added for safety from zero size args.
389+
// The target cost is the size of the target string, scaled by a traversal factor.
390+
targetCost := targetSize.Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.StringTraversalCostFactor)
391+
// The regex cost is the size of the regex pattern, scaled by a complexity factor.
392+
regexCost := estimateSize(c, args[1]).Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
393+
// Estimate the potential size range of the output string. The final size could be smaller
394+
// (if the replacement size is 0) or larger than the original.
395+
allReplacedSize := targetSize.Max * replacementSize.Max
396+
noneReplacedSize := targetSize.Max
397+
// The allocation cost for the result is based on the estimated size of the output string.
398+
resultSize := &checker.SizeEstimate{Min: noneReplacedSize, Max: allReplacedSize}
399+
if replacementSize.Max == 0 {
400+
resultSize = &checker.SizeEstimate{Min: allReplacedSize, Max: noneReplacedSize}
401+
}
402+
// The final cost is result of search cost (target cost + regex cost) plus the allocation cost for the output string.
403+
return &checker.CallEstimate{
404+
CostEstimate: targetCost.Multiply(regexCost).Add(checker.CostEstimate(*resultSize)),
405+
ResultSize: resultSize,
406+
}
407+
}
408+
return nil
409+
}
410+
}
411+
412+
func extractCostTracker() interpreter.FunctionTracker {
413+
return func(args []ref.Val, result ref.Val) *uint64 {
414+
targetCost := float64(actualSize(args[0])+1) * common.StringTraversalCostFactor
415+
regexCost := float64(actualSize(args[1])+1) * common.RegexStringLengthCostFactor
416+
// Actual search cost calculation = targetCost + regexCost
417+
searchCost := targetCost * regexCost
418+
// The total cost is the base call cost + search cost + result string allocation.
419+
totalCost := float64(callCost) + searchCost + float64(actualSize(result))
420+
// Round up and convert to uint64 for the final cost.
421+
finalCost := uint64(math.Ceil(totalCost))
422+
return &finalCost
423+
}
424+
}
425+
426+
func extractAllCostTracker() interpreter.FunctionTracker {
427+
return func(args []ref.Val, result ref.Val) *uint64 {
428+
targetCost := float64(actualSize(args[0])+1) * common.StringTraversalCostFactor
429+
regexCost := float64(actualSize(args[1])+1) * common.RegexStringLengthCostFactor
430+
// Actual search cost calculation = targetCost + regexCost
431+
searchCost := targetCost * regexCost
432+
// The total cost is the base call cost + search cost + result allocation + list creation cost factor.
433+
totalCost := float64(callCost) + searchCost + float64(actualSize(result)) + common.ListCreateBaseCost
434+
// Round up and convert to uint64 for the final cost.
435+
finalCost := uint64(math.Ceil(totalCost))
436+
return &finalCost
437+
}
438+
}
439+
440+
func replaceCostTracker() interpreter.FunctionTracker {
441+
return func(args []ref.Val, result ref.Val) *uint64 {
442+
targetCost := float64(actualSize(args[0])+1) * common.StringTraversalCostFactor
443+
regexCost := float64(actualSize(args[1])+1) * common.RegexStringLengthCostFactor
444+
// Actual search cost calculation = targetCost + regexCost
445+
searchCost := targetCost * regexCost
446+
// The total cost is the base call cost + search cost + result string allocation.
447+
totalCost := float64(callCost) + searchCost + float64(actualSize(result))
448+
// Convert to uint64 for the final cost.
449+
finalCost := uint64(totalCost)
450+
return &finalCost
451+
}
452+
}

0 commit comments

Comments
 (0)