@@ -23,8 +23,11 @@ import (
2323 "strings"
2424
2525 "github.com/google/cel-go/cel"
26+ "github.com/google/cel-go/checker"
27+ "github.com/google/cel-go/common"
2628 "github.com/google/cel-go/common/types"
2729 "github.com/google/cel-go/common/types/ref"
30+ "github.com/google/cel-go/interpreter"
2831)
2932
3033const (
@@ -82,7 +85,7 @@ const (
8285//
8386// regex.extract('hello world', 'hello(.*)') == optional.of(' world')
8487// regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A')
85- // regex.extract('HELLO', 'hello') == optional.empty ()
88+ // regex.extract('HELLO', 'hello') == optional.none ()
8689// regex.extract('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group
8790//
8891// # Extract All
@@ -151,22 +154,27 @@ func (r *regexLib) CompileOptions() []cel.EnvOption {
151154 cel .Overload ("regex_replace_string_string_string_int" , []* cel.Type {cel .StringType , cel .StringType , cel .StringType , cel .IntType }, cel .StringType ,
152155 cel .FunctionBinding ((regReplaceN ))),
153156 ),
157+ cel .CostEstimatorOptions (
158+ checker .OverloadCostEstimate ("regex_extract_string_string" , estimateExtractCost ()),
159+ checker .OverloadCostEstimate ("regex_extractAll_string_string" , estimateExtractAllCost ()),
160+ checker .OverloadCostEstimate ("regex_replace_string_string_string" , estimateReplaceCost ()),
161+ checker .OverloadCostEstimate ("regex_replace_string_string_string_int" , estimateReplaceCost ()),
162+ ),
154163 cel .EnvOption (optionalTypesEnabled ),
155164 }
156165 return opts
157166}
158167
159168// ProgramOptions implements the cel.Library interface method
160169func (r * regexLib ) ProgramOptions () []cel.ProgramOption {
161- return []cel.ProgramOption {}
162- }
163-
164- func compileRegex ( regexStr string ) ( * regexp. Regexp , error ) {
165- re , err := regexp . Compile ( regexStr )
166- if err != nil {
167- return nil , fmt . Errorf ( "given regex is invalid: %w" , err )
170+ return []cel.ProgramOption {
171+ cel . CostTrackerOptions (
172+ interpreter . OverloadCostTracker ( "regex_extract_string_string" , extractCostTracker ()),
173+ interpreter . OverloadCostTracker ( "regex_extractAll_string_string" , extractAllCostTracker ()),
174+ interpreter . OverloadCostTracker ( "regex_replace_string_string_string" , replaceCostTracker ()),
175+ interpreter . OverloadCostTracker ( "regex_replace_string_string_string_int" , replaceCostTracker ()),
176+ ),
168177 }
169- return re , nil
170178}
171179
172180func regReplace (args ... ref.Val ) ref.Val {
@@ -187,10 +195,6 @@ func regReplaceN(args ...ref.Val) ref.Val {
187195 return types .String (target )
188196 }
189197
190- if replaceCount > math .MaxInt32 {
191- return types .NewErr ("integer overflow" )
192- }
193-
194198 // If replaceCount is negative, just do a replaceAll.
195199 if replaceCount < 0 {
196200 replaceCount = - 1
@@ -271,7 +275,7 @@ func replaceStrValidator(target string, re *regexp.Regexp, match []int, replacem
271275func extract (target , regexStr ref.Val ) ref.Val {
272276 t := string (target .(types.String ))
273277 r := string (regexStr .(types.String ))
274- re , err := compileRegex (r )
278+ re , err := regexp . Compile (r )
275279 if err != nil {
276280 return types .WrapErr (err )
277281 }
@@ -300,7 +304,7 @@ func extract(target, regexStr ref.Val) ref.Val {
300304func extractAll (target , regexStr ref.Val ) ref.Val {
301305 t := string (target .(types.String ))
302306 r := string (regexStr .(types.String ))
303- re , err := compileRegex (r )
307+ re , err := regexp . Compile (r )
304308 if err != nil {
305309 return types .WrapErr (err )
306310 }
@@ -330,3 +334,119 @@ func extractAll(target, regexStr ref.Val) ref.Val {
330334 }
331335 return types .NewStringList (types .DefaultTypeAdapter , result )
332336}
337+
338+ func estimateExtractCost () checker.FunctionEstimator {
339+ return func (c checker.CostEstimator , target * checker.AstNode , args []checker.AstNode ) * checker.CallEstimate {
340+ if len (args ) == 2 {
341+ targetSize := estimateSize (c , args [0 ])
342+ // Fixed size estimate of +1 is added for safety from zero size args.
343+ // The target cost is the size of the target string, scaled by a traversal factor.
344+ targetCost := targetSize .Add (checker .FixedSizeEstimate (1 )).MultiplyByCostFactor (common .StringTraversalCostFactor )
345+ // The regex cost is the size of the regex pattern, scaled by a complexity factor.
346+ regexCost := estimateSize (c , args [1 ]).Add (checker .FixedSizeEstimate (1 )).MultiplyByCostFactor (common .RegexStringLengthCostFactor )
347+ // The result is a single string. Worst Case: it's the size of the entire target.
348+ resultSize := & checker.SizeEstimate {Min : 0 , Max : targetSize .Max }
349+ // The total cost is the search cost (target + regex) plus the allocation cost for the result string.
350+ return & checker.CallEstimate {
351+ CostEstimate : regexCost .Multiply (targetCost ).Add (checker .CostEstimate (* resultSize )),
352+ ResultSize : resultSize ,
353+ }
354+ }
355+ return nil
356+ }
357+ }
358+
359+ func estimateExtractAllCost () checker.FunctionEstimator {
360+ return func (c checker.CostEstimator , target * checker.AstNode , args []checker.AstNode ) * checker.CallEstimate {
361+ if len (args ) == 2 {
362+ targetSize := estimateSize (c , args [0 ])
363+ // Fixed size estimate of +1 is added for safety from zero size args.
364+ // The target cost is the size of the target string, scaled by a traversal factor.
365+ targetCost := targetSize .Add (checker .FixedSizeEstimate (1 )).MultiplyByCostFactor (common .StringTraversalCostFactor )
366+ // The regex cost is the size of the regex pattern, scaled by a complexity factor.
367+ regexCost := estimateSize (c , args [1 ]).Add (checker .FixedSizeEstimate (1 )).MultiplyByCostFactor (common .RegexStringLengthCostFactor )
368+ // The result is a list of strings. Worst Case: it's contents are the size of the entire target.
369+ resultSize := & checker.SizeEstimate {Min : 0 , Max : targetSize .Max }
370+ // The cost to allocate the result list is its base cost plus the size of its contents.
371+ allocationSize := resultSize .Add (checker .FixedSizeEstimate (common .ListCreateBaseCost ))
372+ // The total cost is the search cost (target + regex) plus the allocation cost for the result list.
373+ return & checker.CallEstimate {
374+ CostEstimate : targetCost .Multiply (regexCost ).Add (checker .CostEstimate (allocationSize )),
375+ ResultSize : resultSize ,
376+ }
377+ }
378+ return nil
379+ }
380+ }
381+
382+ func estimateReplaceCost () checker.FunctionEstimator {
383+ return func (c checker.CostEstimator , target * checker.AstNode , args []checker.AstNode ) * checker.CallEstimate {
384+ l := len (args )
385+ if l == 3 || l == 4 {
386+ targetSize := estimateSize (c , args [0 ])
387+ replacementSize := estimateSize (c , args [2 ])
388+ // Fixed size estimate of +1 is added for safety from zero size args.
389+ // The target cost is the size of the target string, scaled by a traversal factor.
390+ targetCost := targetSize .Add (checker .FixedSizeEstimate (1 )).MultiplyByCostFactor (common .StringTraversalCostFactor )
391+ // The regex cost is the size of the regex pattern, scaled by a complexity factor.
392+ regexCost := estimateSize (c , args [1 ]).Add (checker .FixedSizeEstimate (1 )).MultiplyByCostFactor (common .RegexStringLengthCostFactor )
393+ // Estimate the potential size range of the output string. The final size could be smaller
394+ // (if the replacement size is 0) or larger than the original.
395+ allReplacedSize := targetSize .Max * replacementSize .Max
396+ noneReplacedSize := targetSize .Max
397+ // The allocation cost for the result is based on the estimated size of the output string.
398+ resultSize := & checker.SizeEstimate {Min : noneReplacedSize , Max : allReplacedSize }
399+ if replacementSize .Max == 0 {
400+ resultSize = & checker.SizeEstimate {Min : allReplacedSize , Max : noneReplacedSize }
401+ }
402+ // The final cost is result of search cost (target cost + regex cost) plus the allocation cost for the output string.
403+ return & checker.CallEstimate {
404+ CostEstimate : targetCost .Multiply (regexCost ).Add (checker .CostEstimate (* resultSize )),
405+ ResultSize : resultSize ,
406+ }
407+ }
408+ return nil
409+ }
410+ }
411+
412+ func extractCostTracker () interpreter.FunctionTracker {
413+ return func (args []ref.Val , result ref.Val ) * uint64 {
414+ targetCost := float64 (actualSize (args [0 ])+ 1 ) * common .StringTraversalCostFactor
415+ regexCost := float64 (actualSize (args [1 ])+ 1 ) * common .RegexStringLengthCostFactor
416+ // Actual search cost calculation = targetCost + regexCost
417+ searchCost := targetCost * regexCost
418+ // The total cost is the base call cost + search cost + result string allocation.
419+ totalCost := float64 (callCost ) + searchCost + float64 (actualSize (result ))
420+ // Round up and convert to uint64 for the final cost.
421+ finalCost := uint64 (math .Ceil (totalCost ))
422+ return & finalCost
423+ }
424+ }
425+
426+ func extractAllCostTracker () interpreter.FunctionTracker {
427+ return func (args []ref.Val , result ref.Val ) * uint64 {
428+ targetCost := float64 (actualSize (args [0 ])+ 1 ) * common .StringTraversalCostFactor
429+ regexCost := float64 (actualSize (args [1 ])+ 1 ) * common .RegexStringLengthCostFactor
430+ // Actual search cost calculation = targetCost + regexCost
431+ searchCost := targetCost * regexCost
432+ // The total cost is the base call cost + search cost + result allocation + list creation cost factor.
433+ totalCost := float64 (callCost ) + searchCost + float64 (actualSize (result )) + common .ListCreateBaseCost
434+ // Round up and convert to uint64 for the final cost.
435+ finalCost := uint64 (math .Ceil (totalCost ))
436+ return & finalCost
437+ }
438+ }
439+
440+ func replaceCostTracker () interpreter.FunctionTracker {
441+ return func (args []ref.Val , result ref.Val ) * uint64 {
442+ targetCost := float64 (actualSize (args [0 ])+ 1 ) * common .StringTraversalCostFactor
443+ regexCost := float64 (actualSize (args [1 ])+ 1 ) * common .RegexStringLengthCostFactor
444+ // Actual search cost calculation = targetCost + regexCost
445+ searchCost := targetCost * regexCost
446+ // The total cost is the base call cost + search cost + result string allocation.
447+ totalCost := float64 (callCost ) + searchCost + float64 (actualSize (result ))
448+ // Convert to uint64 for the final cost.
449+ finalCost := uint64 (totalCost )
450+ return & finalCost
451+ }
452+ }
0 commit comments