Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 74 additions & 4 deletions ext/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -931,10 +931,10 @@ type will cause a key collision.
Elements in the map may optionally be filtered according to a predicate
expression, where elements that satisfy the predicate are transformed.

<list>.transformMap(indexVar, valueVar, <transform>)
<list>.transformMap(indexVar, valueVar, <filter>, <transform>)
<map>.transformMap(keyVar, valueVar, <transform>)
<map>.transformMap(keyVar, valueVar, <filter>, <transform>)
<list>.transformMapEntry(indexVar, valueVar, <transform>)
<list>.transformMapEntry(indexVar, valueVar, <filter>, <transform>)
<map>.transformMapEntry(keyVar, valueVar, <transform>)
<map>.transformMapEntry(keyVar, valueVar, <filter>, <transform>)

Examples:

Expand All @@ -945,3 +945,73 @@ Examples:

{'greeting': 'aloha', 'farewell': 'aloha'}
.transformMapEntry(keyVar, valueVar, {valueVar: keyVar}) // error, duplicate key

## Regex

Regex introduces functions for regular expressions in CEL.

Note: Please ensure that the cel.OptionalTypes() is enabled when using regex
extensions. All functions use the 'regex' namespace. If you are currently
using a variable named 'regex', the functions will likely work as intended.
However, there is some chance for collision.

### Replace

The `regex.replace` function replaces all non-overlapping substrings of the
target string that match a regex pattern with a replacement string.
Optionally, you can limit the number of replacements by providing a count
argument. When the count is zero, the target string is returned unchanged; when
the count is a negative number, the function acts as replace all. Only numeric
(\N) capture group references are supported in the replacement string, with
validation for correctness. Backslash-escaped digits (\1 to \9) within the
replacement argument can be used to insert text matching the corresponding
parenthesized group in the regexp pattern. An error will be thrown for an
invalid regex or an invalid replacement string.


regex.replace(target: string, pattern: string, replacement: string) -> string
regex.replace(target: string, pattern: string, replacement: string, count: int) -> string


Examples:

regex.replace('hello world hello', 'hello', 'hi') == 'hi world hi'
regex.replace('banana', 'a', 'x', 0) == 'banana'
regex.replace('banana', 'a', 'x', 1) == 'bxnana'
regex.replace('banana', 'a', 'x', 2) == 'bxnxna'
regex.replace('banana', 'a', 'x', -12) == 'bxnxnx'
regex.replace('foo bar', '(fo)o (ba)r', '\\2 \\1') == 'ba fo'

regex.replace('test', '(.)', '$2') // Runtime Error invalid replace string
regex.replace('foo bar', '(', '$2 $1') // Runtime Error invalid regex string
regex.replace('id=123', 'id=(?P<value>\\\\d+)', 'value: \\values') // Runtime Error invalid replace string

### Extract

The `regex.extract` function returns the first match of a regex pattern as an
`optional` string. If no match is found, it returns an optional none value.
An error will be thrown for invalid regex or for multiple capture groups.

regex.extract(target: string, pattern: string) -> optional<string>

Examples:

regex.extract('hello world', 'hello(.*)') == optional.of(' world')
regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A')
regex.extract('HELLO', 'hello') == optional.none()

regex.extract('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group

### Extract All

The `regex.extractAll` function returns a `list` of all matches of a regex
pattern in a target string. If no matches are found, it returns an empty list.
An error will be thrown for invalid regex or for multiple capture groups.

regex.extractAll(target: string, pattern: string) -> list<string>

Examples:

regex.extractAll('id:123, id:456', 'id:\\d+') == ['id:123', 'id:456']
regex.extractAll('id:123, id:456', 'assa') == []

regex.extractAll('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group
8 changes: 4 additions & 4 deletions ext/comprehensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ const (
// Elements in the map may optionally be filtered according to a predicate expression, where
// elements that satisfy the predicate are transformed.
//
// <list>.transformMap(indexVar, valueVar, <transform>)
// <list>.transformMap(indexVar, valueVar, <filter>, <transform>)
// <map>.transformMap(keyVar, valueVar, <transform>)
// <map>.transformMap(keyVar, valueVar, <filter>, <transform>)
// <list>.transformMapEntry(indexVar, valueVar, <transform>)
// <list>.transformMapEntry(indexVar, valueVar, <filter>, <transform>)
// <map>.transformMapEntry(keyVar, valueVar, <transform>)
// <map>.transformMapEntry(keyVar, valueVar, <filter>, <transform>)
//
// Examples:
//
Expand Down
150 changes: 135 additions & 15 deletions ext/regex.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,11 @@ import (
"strings"

"github.com/google/cel-go/cel"
"github.com/google/cel-go/checker"
"github.com/google/cel-go/common"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/interpreter"
)

const (
Expand Down Expand Up @@ -82,7 +85,7 @@ const (
//
// regex.extract('hello world', 'hello(.*)') == optional.of(' world')
// regex.extract('item-A, item-B', 'item-(\\w+)') == optional.of('A')
// regex.extract('HELLO', 'hello') == optional.empty()
// regex.extract('HELLO', 'hello') == optional.none()
// regex.extract('testuser@testdomain', '(.*)@([^.]*)') // Runtime Error multiple capture group
//
// # Extract All
Expand Down Expand Up @@ -151,22 +154,27 @@ func (r *regexLib) CompileOptions() []cel.EnvOption {
cel.Overload("regex_replace_string_string_string_int", []*cel.Type{cel.StringType, cel.StringType, cel.StringType, cel.IntType}, cel.StringType,
cel.FunctionBinding((regReplaceN))),
),
cel.CostEstimatorOptions(
checker.OverloadCostEstimate("regex_extract_string_string", estimateExtractCost()),
checker.OverloadCostEstimate("regex_extractAll_string_string", estimateExtractAllCost()),
checker.OverloadCostEstimate("regex_replace_string_string_string", estimateReplaceCost()),
checker.OverloadCostEstimate("regex_replace_string_string_string_int", estimateReplaceCost()),
),
cel.EnvOption(optionalTypesEnabled),
}
return opts
}

// ProgramOptions implements the cel.Library interface method
func (r *regexLib) ProgramOptions() []cel.ProgramOption {
return []cel.ProgramOption{}
}

func compileRegex(regexStr string) (*regexp.Regexp, error) {
re, err := regexp.Compile(regexStr)
if err != nil {
return nil, fmt.Errorf("given regex is invalid: %w", err)
return []cel.ProgramOption{
cel.CostTrackerOptions(
interpreter.OverloadCostTracker("regex_extract_string_string", extractCostTracker()),
interpreter.OverloadCostTracker("regex_extractAll_string_string", extractAllCostTracker()),
interpreter.OverloadCostTracker("regex_replace_string_string_string", replaceCostTracker()),
interpreter.OverloadCostTracker("regex_replace_string_string_string_int", replaceCostTracker()),
),
}
return re, nil
}

func regReplace(args ...ref.Val) ref.Val {
Expand All @@ -187,10 +195,6 @@ func regReplaceN(args ...ref.Val) ref.Val {
return types.String(target)
}

if replaceCount > math.MaxInt32 {
return types.NewErr("integer overflow")
}

// If replaceCount is negative, just do a replaceAll.
if replaceCount < 0 {
replaceCount = -1
Expand Down Expand Up @@ -271,7 +275,7 @@ func replaceStrValidator(target string, re *regexp.Regexp, match []int, replacem
func extract(target, regexStr ref.Val) ref.Val {
t := string(target.(types.String))
r := string(regexStr.(types.String))
re, err := compileRegex(r)
re, err := regexp.Compile(r)
if err != nil {
return types.WrapErr(err)
}
Expand Down Expand Up @@ -300,7 +304,7 @@ func extract(target, regexStr ref.Val) ref.Val {
func extractAll(target, regexStr ref.Val) ref.Val {
t := string(target.(types.String))
r := string(regexStr.(types.String))
re, err := compileRegex(r)
re, err := regexp.Compile(r)
if err != nil {
return types.WrapErr(err)
}
Expand Down Expand Up @@ -330,3 +334,119 @@ func extractAll(target, regexStr ref.Val) ref.Val {
}
return types.NewStringList(types.DefaultTypeAdapter, result)
}

// estimateExtractCost builds the check-time cost estimator used for the
// regex.extract(target, pattern) overload.
//
// The returned estimator only produces an estimate when it is handed exactly
// two arguments; for any other arity it returns nil.
func estimateExtractCost() checker.FunctionEstimator {
	return func(c checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
		if len(args) != 2 {
			return nil
		}
		inputSize := estimateSize(c, args[0])
		// Each size is padded by one so that a zero-length argument still
		// contributes a non-zero factor to the product below.
		// Scanning the input: (input size + 1) scaled by the string traversal factor.
		scanCost := inputSize.Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.StringTraversalCostFactor)
		// Evaluating the pattern: (pattern size + 1) scaled by the regex length factor.
		patternCost := estimateSize(c, args[1]).Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
		// A single string is produced; in the worst case it is the whole input.
		extracted := &checker.SizeEstimate{Min: 0, Max: inputSize.Max}
		// Search cost is the product of the pattern and scan costs; the
		// allocation of the extracted string is added on top.
		searchCost := patternCost.Multiply(scanCost)
		return &checker.CallEstimate{
			CostEstimate: searchCost.Add(checker.CostEstimate(*extracted)),
			ResultSize:   extracted,
		}
	}
}

// estimateExtractAllCost builds the check-time cost estimator used for the
// regex.extractAll(target, pattern) overload.
//
// The returned estimator only produces an estimate when it is handed exactly
// two arguments; for any other arity it returns nil.
func estimateExtractAllCost() checker.FunctionEstimator {
	return func(c checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
		if len(args) != 2 {
			return nil
		}
		inputSize := estimateSize(c, args[0])
		// Each size is padded by one so that a zero-length argument still
		// contributes a non-zero factor to the product below.
		// Scanning the input: (input size + 1) scaled by the string traversal factor.
		scanCost := inputSize.Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.StringTraversalCostFactor)
		// Evaluating the pattern: (pattern size + 1) scaled by the regex length factor.
		patternCost := estimateSize(c, args[1]).Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
		// The result is a list of strings; in the worst case its contents add
		// up to the size of the whole input.
		matches := &checker.SizeEstimate{Min: 0, Max: inputSize.Max}
		// Allocating the list costs its fixed base cost plus the size of its contents.
		listAlloc := matches.Add(checker.FixedSizeEstimate(common.ListCreateBaseCost))
		// Search cost is the product of the scan and pattern costs; the list
		// allocation is added on top.
		searchCost := scanCost.Multiply(patternCost)
		return &checker.CallEstimate{
			CostEstimate: searchCost.Add(checker.CostEstimate(listAlloc)),
			ResultSize:   matches,
		}
	}
}

// estimateReplaceCost builds the check-time cost estimator shared by both
// regex.replace overloads (with and without the trailing count argument).
//
// The returned estimator only produces an estimate when it is handed three or
// four arguments; for any other arity it returns nil. The count argument, when
// present, does not influence the estimate.
func estimateReplaceCost() checker.FunctionEstimator {
	return func(c checker.CostEstimator, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate {
		l := len(args)
		if l != 3 && l != 4 {
			return nil
		}
		targetSize := estimateSize(c, args[0])
		replacementSize := estimateSize(c, args[2])
		// Fixed size estimate of +1 is added for safety from zero size args.
		// The target cost is the size of the target string, scaled by a traversal factor.
		targetCost := targetSize.Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.StringTraversalCostFactor)
		// The regex cost is the size of the regex pattern, scaled by a complexity factor.
		regexCost := estimateSize(c, args[1]).Add(checker.FixedSizeEstimate(1)).MultiplyByCostFactor(common.RegexStringLengthCostFactor)
		// Estimate the potential size range of the output string. The final size could be
		// smaller (if the replacement size is 0) or larger than the original.
		noneReplacedSize := targetSize.Max
		allReplacedSize := targetSize.Max * replacementSize.Max
		// Size estimates can be very large (presumably up to math.MaxUint64 when a
		// size is unknown — verify against estimateSize), so the product above can
		// wrap around and become a tiny upper bound. Detect the wrap by dividing
		// back and saturate so the worst case is never under-estimated and
		// Min <= Max is preserved.
		if replacementSize.Max != 0 && allReplacedSize/replacementSize.Max != targetSize.Max {
			allReplacedSize = math.MaxUint64
		}
		// The allocation cost for the result is based on the estimated size of the output string.
		// NOTE(review): Min is taken from targetSize.Max, which looks high for a lower
		// bound; kept as-is to leave existing estimates unchanged — confirm intent.
		resultSize := &checker.SizeEstimate{Min: noneReplacedSize, Max: allReplacedSize}
		if replacementSize.Max == 0 {
			// With an empty replacement the product is 0, so the bounds swap:
			// the output can only stay the same size or shrink.
			resultSize = &checker.SizeEstimate{Min: allReplacedSize, Max: noneReplacedSize}
		}
		// The final cost is the search cost (target cost multiplied by regex cost)
		// plus the allocation cost for the output string.
		return &checker.CallEstimate{
			CostEstimate: targetCost.Multiply(regexCost).Add(checker.CostEstimate(*resultSize)),
			ResultSize:   resultSize,
		}
	}
}

// extractCostTracker builds the runtime cost tracker used for the
// regex.extract(target, pattern) overload.
//
// The tracker reads args[0] as the target and args[1] as the pattern and
// reports the call cost, rounded up to a whole number.
func extractCostTracker() interpreter.FunctionTracker {
	return func(args []ref.Val, result ref.Val) *uint64 {
		// Both sizes are padded by one before scaling.
		scanCost := float64(actualSize(args[0])+1) * common.StringTraversalCostFactor
		patternCost := float64(actualSize(args[1])+1) * common.RegexStringLengthCostFactor
		// The search cost is the product of the scan and pattern costs.
		searchCost := scanCost * patternCost
		// Total = base call cost + search cost + size of the returned value.
		total := float64(callCost) + searchCost + float64(actualSize(result))
		// Fractional costs are rounded up before conversion.
		cost := uint64(math.Ceil(total))
		return &cost
	}
}

// extractAllCostTracker builds the runtime cost tracker used for the
// regex.extractAll(target, pattern) overload.
//
// The tracker reads args[0] as the target and args[1] as the pattern and
// reports the call cost, rounded up to a whole number.
func extractAllCostTracker() interpreter.FunctionTracker {
	return func(args []ref.Val, result ref.Val) *uint64 {
		// Both sizes are padded by one before scaling.
		scanCost := float64(actualSize(args[0])+1) * common.StringTraversalCostFactor
		patternCost := float64(actualSize(args[1])+1) * common.RegexStringLengthCostFactor
		// The search cost is the product of the scan and pattern costs.
		searchCost := scanCost * patternCost
		// Total = base call cost + search cost + size of the returned value
		// + the fixed base cost of creating the result list.
		total := float64(callCost) + searchCost + float64(actualSize(result)) + common.ListCreateBaseCost
		// Fractional costs are rounded up before conversion.
		cost := uint64(math.Ceil(total))
		return &cost
	}
}

// replaceCostTracker builds the runtime cost tracker shared by both
// regex.replace overloads.
//
// The tracker reads args[0] as the target and args[1] as the pattern; the
// replacement and optional count arguments are not consulted.
func replaceCostTracker() interpreter.FunctionTracker {
	return func(args []ref.Val, result ref.Val) *uint64 {
		// Both sizes are padded by one before scaling.
		scanCost := float64(actualSize(args[0])+1) * common.StringTraversalCostFactor
		patternCost := float64(actualSize(args[1])+1) * common.RegexStringLengthCostFactor
		// The search cost is the product of the scan and pattern costs.
		searchCost := scanCost * patternCost
		// Total = base call cost + search cost + size of the returned value.
		total := float64(callCost) + searchCost + float64(actualSize(result))
		// The conversion truncates any fractional part.
		// NOTE(review): extractCostTracker and extractAllCostTracker round up with
		// math.Ceil instead — confirm whether truncation here is intentional.
		cost := uint64(total)
		return &cost
	}
}
Loading