diff --git a/cel/cel_test.go b/cel/cel_test.go index cbe89c3e7..14dd2fb37 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -39,6 +39,7 @@ import ( "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/interpreter" + "github.com/google/cel-go/interpreter/functions" "github.com/google/cel-go/parser" "github.com/google/cel-go/test" @@ -1169,6 +1170,215 @@ func TestContextEvalUnknowns(t *testing.T) { } } +func TestEvalLateBinding(t *testing.T) { + + // functions statically bound to function call nodes + // during parsing and program planning. + + f1_int := func(_ ...ref.Val) ref.Val { + return types.Int(10) + } + + f1_int_int := func(arg ref.Val) ref.Val { + return arg.(traits.Multiplier).Multiply(types.Int(2)) + } + + f1 := func() EnvOption { + return Function("f1", + Overload("f1_int", []*Type{}, types.IntType, FunctionBinding(f1_int)), + Overload("f1_int_int", []*Type{types.IntType}, types.IntType, UnaryBinding(f1_int_int)), + ) + } + + // functions supplied during evaluation to override the + // logic implemented in the static bindings. + + f1_int_override := func() *functions.Overload { + return &functions.Overload{ + Operator: "f1_int", + Function: func(_ ...ref.Val) ref.Val { + + return types.Int(0) + }, + NonStrict: false, + OperandTrait: 0, + } + } + + f1_int_int_override := func() *functions.Overload { + return &functions.Overload{ + Operator: "f1_int_int", + Unary: func(arg ref.Val) ref.Val { + return arg.(traits.Adder).Add(types.Int(100)) + }, + NonStrict: false, + OperandTrait: 0, + } + } + + activation := func(t *testing.T, vars map[string]any, ovls ...*functions.Overload) Activation { + t.Helper() + act, err := NewActivation(vars) + if err == nil { + act, err = interpreter.NewLateBindActivation(act, ovls...) + } + if err != nil { + t.Fatalf("pre-condition failed: could not create activation (cause: %v)", err) + } + return act + } + + // expectValue generates an expectation function that checks that the outcome of the + // evaluation has generated no error and has returned the value originally passed as + // argument. The comparision is performed by invoking Equal on the expected value and + // passing the outcome of the evaluation as argument. + expectValue := func(expected ref.Val) func(t *testing.T, actual ref.Val, _ *EvalDetails, err error) { + + return func(t *testing.T, actual ref.Val, _ *EvalDetails, err error) { + + if err != nil { + t.Errorf("unexpected error (cause: %v)", err) + } + + if expected.Equal(actual) != types.True { + t.Errorf("unexpected value (got: %v, want: %v)", actual, expected) + } + } + } + + // expectError generates an expectation function that checks whether the outcome of the + // execution of the test (program generation, and evaluation) has generated an error and + // that error contains a predefined message. + expectError := func(errMsg string) func(t *testing.T, _ ref.Val, _ *EvalDetails, err error) { + + return func(t *testing.T, _ ref.Val, _ *EvalDetails, err error) { + + if err == nil { + t.Fatal("expected error, but error is nil") + } + + if !strings.Contains(err.Error(), errMsg) { + t.Errorf("the evaluation error does not contain expected message (got: %s, want: %s)", err.Error(), errMsg) + } + } + } + + testCases := []struct { + name string + env *Env + expression string + parseOnly bool + opts []ProgramOption + activation Activation + expect func(t *testing.T, out ref.Val, details *EvalDetails, err error) + }{ + { + name: "OK_Happy_Path_No_Overrides", + env: testEnv(t, f1()), + expression: `f1() + f1(10)`, + parseOnly: false, + opts: []ProgramOption{EvalOptions(OptLateBindCalls)}, + activation: NoVars(), + expect: expectValue(types.Int(10 + 10*2)), + }, { + name: "OK_Happy_Path_With_Overrides", + env: testEnv(t, + Variable("a", types.IntType), + f1(), + ), + expression: `f1() + f1(a)`, + parseOnly: false, + opts: []ProgramOption{EvalOptions(OptLateBindCalls)}, + activation: activation(t, + map[string]any{ + "a": 15, + }, + f1_int_override(), + ), + expect: expectValue(types.Int(0 + 15*2)), + }, { + name: "OK_Happy_Path_With_Overrides_Explicit_Program_Option", + env: testEnv(t, + Variable("a", types.IntType), + f1(), + ), + expression: "f1() + f1(a)", + parseOnly: false, + opts: []ProgramOption{LateBindOptions()}, + activation: activation(t, + map[string]any{ + "a": 10, + }, + f1_int_override(), + f1_int_int_override(), + ), + expect: expectValue(types.Int(0 + 100 + 10)), + }, { + name: "ERROR_Invalid_Overloads", + env: testEnv(t, f1()), + expression: "f1() + f1(10)", + parseOnly: false, + opts: []ProgramOption{LateBindOptions()}, + activation: activation(t, map[string]any{}, + f1_int_override(), + &functions.Overload{ + Operator: "f1_int_int", + Binary: func(lhs ref.Val, rhs ref.Val) ref.Val { + return types.Int(50) + }, + NonStrict: false, + OperandTrait: 0, + }, + ), + expect: expectError( + interpreter.OverloadSignatureError( + "", + "f1_int_int", + "binary{ func(ref.Val, ref.Val) ref.Val }", + "unary{ func(ref.Val) ref.Val }", + ).Error(), + ), + }, { + name: "ERROR_Unchecked_AST", + env: testEnv(t, f1()), + expression: "f1 + f1(20)", + parseOnly: true, + opts: []ProgramOption{EvalOptions(OptLateBindCalls)}, + activation: activation(t, map[string]any{}), + expect: expectError(interpreter.UncheckedAstError().Error()), + }, + } + + for _, testCase := range testCases { + + t.Run(testCase.name, func(t *testing.T) { + + var ast *Ast + var issues *Issues + + if testCase.parseOnly == true { + ast, issues = testCase.env.Parse(testCase.expression) + } else { + ast, issues = testCase.env.Compile(testCase.expression) + } + + err := issues.Err() + if err != nil { + t.Fatalf("pre-condition failed could not parse/compile expression (cause: %v)", err) + } + + prg, err := testCase.env.Program(ast, testCase.opts...) + if err != nil { + testCase.expect(t, nil, nil, err) + } else { + + out, details, err := prg.Eval(testCase.activation) + testCase.expect(t, out, details, err) + } + }) + } +} + func BenchmarkContextEval(b *testing.B) { env := testEnv(b, Variable("items", ListType(IntType)), diff --git a/cel/options.go b/cel/options.go index fee67323c..4cf6e9e7b 100644 --- a/cel/options.go +++ b/cel/options.go @@ -653,6 +653,12 @@ const ( // // Deprecated: use ext.StringsValidateFormatCalls() as this option is now a no-op. OptCheckStringFormat EvalOption = 1 << iota + + // OptLateBindCalls enables overriding the binding of function overloads at evaluation time. + // + // This option works in concert with the a specific implementation of the activation that is + // wraps dispatcher, otherwise it defaults to standard behaviour. + OptLateBindCalls EvalOption = 1 << iota ) // EvalOptions sets one or more evaluation options which may affect the evaluation or Result. @@ -665,6 +671,18 @@ func EvalOptions(opts ...EvalOption) ProgramOption { } } +// LateBindOptions sets one of more LateBindCallOption and automatically +// add the OptLateBindCalls to the evaluation options, to enable the late +// binding behaviour. +func LateBindOptions(opts ...interpreter.LateBindCallOption) ProgramOption { + return func(p *prog) (*prog, error) { + p.lateBindOptions = append(p.lateBindOptions, opts...) + p.evalOpts |= OptLateBindCalls + return p, nil + } + +} + // InterruptCheckFrequency configures the number of iterations within a comprehension to evaluate // before checking whether the function evaluation has been interrupted. func InterruptCheckFrequency(checkFrequency uint) ProgramOption { diff --git a/cel/program.go b/cel/program.go index 24f41a4a7..065d9e4dd 100644 --- a/cel/program.go +++ b/cel/program.go @@ -163,6 +163,8 @@ type prog struct { callCostEstimator interpreter.ActualCostEstimator costOptions []interpreter.CostTrackerOption costLimit *uint64 + + lateBindOptions []interpreter.LateBindCallOption } // newProgram creates a program instance with an environment, an ast, and an optional list of @@ -176,10 +178,11 @@ func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) { // Ensure the default attribute factory is set after the adapter and provider are // configured. p := &prog{ - Env: e, - plannerOptions: []interpreter.PlannerOption{}, - dispatcher: disp, - costOptions: []interpreter.CostTrackerOption{}, + Env: e, + plannerOptions: []interpreter.PlannerOption{}, + dispatcher: disp, + costOptions: []interpreter.CostTrackerOption{}, + lateBindOptions: []interpreter.LateBindCallOption{}, } // Configure the program via the ProgramOption values. @@ -264,6 +267,20 @@ func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) { plannerOptions = append(plannerOptions, observers...) } } + + // add behaviour for latebinding calls. + if p.evalOpts&OptLateBindCalls != 0 { + + // we need to ensure that the AST is checked otherwise + // we won't be able to resolve overloaded functions by + // overload identifiers + if !a.IsChecked() { + return nil, interpreter.UncheckedAstError() + } + + plannerOptions = append(plannerOptions, interpreter.LateBindCalls(p.lateBindOptions...)) + } + return p.initInterpretable(a, plannerOptions) } @@ -309,6 +326,18 @@ func (p *prog) Eval(input any) (out ref.Val, det *EvalDetails, err error) { if p.defaultVars != nil { vars = interpreter.NewHierarchicalActivation(p.defaultVars, vars) } + + // before executing the evaluation if the late bind option was + // configured we ensure that the activation does have compatible + // overloads with the one maintained in the dispatcher. + if p.evalOpts&OptLateBindCalls != 0 { + + err := interpreter.ValidateOverloads(p.dispatcher, vars) + if err != nil { + return nil, nil, err + } + } + if p.observable != nil { det = &EvalDetails{} out = p.observable.ObserveEval(vars, func(observed any) { diff --git a/checker/checker.go b/checker/checker.go index a9e04fc26..0057c16cc 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -152,7 +152,7 @@ func (c *checker) checkOptSelect(e ast.Expr) { } c.errors.notAnOptionalFieldSelectionCall(e.ID(), c.location(e), fmt.Sprintf( - "incorrect signature.%s argument count: %d%s", t, len(call.Args()))) + "incorrect signature.%s argument count: %d", t, len(call.Args()))) return } diff --git a/interpreter/activation.go b/interpreter/activation.go index dd40619ee..1262fbdc0 100644 --- a/interpreter/activation.go +++ b/interpreter/activation.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" + "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/types/ref" ) @@ -190,3 +191,143 @@ func AsPartialActivation(vars Activation) (PartialActivation, bool) { } return nil, false } + +// NewLateBindActivation creates an activation that wraps the given activation and +// exposes the given function overloads to the evaluation. If the list of overloads +// has duplicates or the given activation is nil, it will return an error. +func NewLateBindActivation(activation Activation, overloads ...*functions.Overload) (LateBindActivation, error) { + + dispatcher := NewDispatcher() + err := dispatcher.Add(overloads...) + if err != nil { + return nil, err + } + + if activation == nil { + return nil, errors.New("cannot create a late bind activation with a nil activation") + } + + return &lateBindActivation{ + vars: activation, + dispatcher: dispatcher, + }, nil +} + +// LateBindActivation provides an interface that defines +// the contract for exposing function overloads during +// the evaluation. +// +// This interface enables the integration of external +// implementations of the late bind behaviour, without +// limiting the design to a given concrete type. +type LateBindActivation interface { + Activation + // ResolveOverload resolves the function overload that is + // mapped to overloadId. Implementations of this function + // are expected to recursively navigate the activation tree + // by respecting the parent-child relationships to find the + // first overload definition that is mapped to overloadId. + ResolveOverload(overloadId string) *functions.Overload + // ResolveOverloads returns a Dispatcher implementation that maintains all + // the overload functions that are defined starting from the instance of the + // concrete type implementing this method. The list is guaranteed to be + // unique (i.e. with no duplicates). Should duplicates be found, only the + // first occurrence of the overload is added to the list, thus ensuring + // that the correct behaviour is being implemented. + ResolveOverloads() Dispatcher +} + +// lateBindActivation is an Activation implementation +// that carries a dispatcher which can be used to +// supply overrides for function overloads during +// evaluation. +type lateBindActivation struct { + vars Activation + dispatcher Dispatcher +} + +// ResolveName implemments Activation.ResolveName(string). The +// method defers the name resolution to the activation instance +// that is wrapped. +func (activation *lateBindActivation) ResolveName(name string) (any, bool) { + return activation.vars.ResolveName(name) +} + +// Parent implements Activation.Parent() and returns the +// activation that is wrapped by this struct. +func (activation *lateBindActivation) Parent() Activation { + return activation.vars +} + +// ResolveOverload resolves function overload that is mapped by +// the given overloadId. The implementation first checks if the +// dispatcher configured with the current activation defines an +// overload for overloadId, and if found it returns such overload. +// If the dispatcher does not define such overloads the function +// recursively checks the activation to find any LateBindActivation +// that might declare such overload. +func (activation *lateBindActivation) ResolveOverload(overloadId string) *functions.Overload { + + if activation.dispatcher != nil { + ovl, found := activation.dispatcher.FindOverload(overloadId) + if found { + return ovl + } + } + + return resolveOverload(overloadId, activation.vars) +} + +// ResolveOverloads returns a Dispatcher implementation that aggregates +// all function overloads definition that are accessible from the current +// activation reference. The preference is given to the overloads of the +// defined dispatcher, and then the hierarchy of activations originating +// from the configured parent activation. If there are any duplicates +func (activation *lateBindActivation) ResolveOverloads() Dispatcher { + + dispatcher := NewDispatcher() + for _, ovlId := range activation.dispatcher.OverloadIds() { + ovl, _ := activation.dispatcher.FindOverload(ovlId) + dispatcher.Add(ovl) + } + + resolveAllOverloads(dispatcher, activation.vars) + + return dispatcher +} + +// resolveOverload travels the hierarchy of activations originating from the given +// Activation implementation to find the overload associatd to overloadId. Since the +// Activation APIs allow for different types of activations and compositions we need +// to ensure that if there is any valid overload that is mapped to overloadId we can +// find it. +func resolveOverload(overloadId string, activation Activation) *functions.Overload { + + if activation == nil { + return nil + } + + switch act := activation.(type) { + case *mapActivation: + return nil + case *emptyActivation: + return nil + case *partActivation: + return resolveOverload(overloadId, act.Activation) + case *hierarchicalActivation: + ovl := resolveOverload(overloadId, act.child) + if ovl == nil { + return resolveOverload(overloadId, act.parent) + } + return ovl + case LateBindActivation: + + return act.ResolveOverload(overloadId) + default: + // this is to cater for all other implementations + // that we don't known about but that rightfully + // implement the Activation interface. + return resolveOverload(overloadId, act.Parent()) + } + +} diff --git a/interpreter/activation_test.go b/interpreter/activation_test.go index 731313804..af0070cfa 100644 --- a/interpreter/activation_test.go +++ b/interpreter/activation_test.go @@ -15,11 +15,15 @@ package interpreter import ( + "fmt" + "strings" "testing" "time" + "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" ) func TestActivation(t *testing.T) { @@ -130,3 +134,737 @@ func TestAsPartialActivation(t *testing.T) { t.Error("AsPartialActivation() failed, did not find parent partial activation") } } + +// TestNewLateBindingActivation verifies the implementation of NewLateBindingActivation. The +// expectation is for the constructor function to produce a LateBindActivation implementation +// (i.e. lateBindActivation) that is configured with the given parent activation and with a +// dispatcher declaring containing the specified function overloads. The function should return +// a nil implementation and an error in case of duplicate overload function definitions or a nil +// activation. +func TestNewLateBindingActivation(t *testing.T) { + + // expectActivation generates an expectation function that verifies that + // the outcome of NewLateBindingActivation has not generated any error and + // contains the given activation as well as the specified function overloads. + expectActivation := func(expected Activation, overloads ...*functions.Overload) func(t *testing.T, actual LateBindActivation, err error) { + + return func(t *testing.T, actual LateBindActivation, err error) { + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if actual == nil { + t.Errorf("expected non nil activation") + } else { + + lateBind, ok := actual.(*lateBindActivation) + if !ok { + t.Errorf("unexpected activation type (got: %T, want: %T)", actual, &lateBindActivation{}) + } else { + if lateBind.vars != expected { + t.Errorf("unexpected wrapped activation (got: %v, want: %v)", lateBind.vars, expected) + } + if lateBind.dispatcher == nil { + t.Errorf("expected non-nil dispatcher") + } + + actualIds := lateBind.dispatcher.OverloadIds() + if len(actualIds) != len(overloads) { + t.Errorf("number of overloads do not match (got: %d, want: %d)", len(actualIds), len(overloads)) + } else { + + for _, expOvl := range overloads { + + actOvl, found := lateBind.dispatcher.FindOverload(expOvl.Operator) + if !found { + t.Errorf("expected overload (id: %s)", expOvl.Operator) + } else { + + if expOvl != actOvl { + t.Errorf("overload (id: %s) mismatch", expOvl.Operator) + } + } + } + } + } + } + } + } + + // expectError generates an expectation function that checks that the + // outcome of NewLateBindActivation produces a nil activation and an + // error that contains the specified message. + expectError := func(msg string) func(t *testing.T, actual LateBindActivation, err error) { + + return func(t *testing.T, actual LateBindActivation, err error) { + + if actual != nil { + t.Errorf("expected nil activation") + } + if err == nil { + t.Errorf("expected non-nil error") + } else { + if !strings.Contains(err.Error(), msg) { + t.Errorf("error message (value: %s) does not contain '%s'", err.Error(), msg) + } + } + } + } + + f1_string_string := unary("f1_string_string", 0, false, func(value ref.Val) ref.Val { + return types.String("f1_string_string") + }) + + f1_string_string_string := binary("f1_string_string_string", 0, false, func(lhs ref.Val, rhs ref.Val) ref.Val { + return types.String("f1_string_string_string") + }) + + f1_varargs_string := function("f1_varargs_string", 0, false, func(args ...ref.Val) ref.Val { + return types.String("f1_varargs_string") + }) + + f2_string := function("f2_string", 0, false, func(args ...ref.Val) ref.Val { + return types.String("f2_string") + }) + + f2_string_string := &functions.Overload{ + Operator: "f2_string_string", + NonStrict: true, + Unary: func(arg ref.Val) ref.Val { + return types.String("f2_string_string") + }, + } + + actHierarchical := &hierarchicalActivation{ + parent: &emptyActivation{}, + child: &mapActivation{ + bindings: map[string]any{}, + }, + } + + actEmpty := &emptyActivation{} + + testCases := []struct { + name string + activation Activation + overloads []*functions.Overload + expect func(t *testing.T, activation LateBindActivation, err error) + }{ + { + name: "OK_No_Overloads", + activation: actEmpty, + overloads: nil, + expect: expectActivation(actEmpty), + }, + { + name: "OK_Happy_Path", + activation: actHierarchical, + overloads: []*functions.Overload{ + f1_string_string, + f1_string_string_string, + f1_varargs_string, + }, + expect: expectActivation(actHierarchical, f1_string_string, f1_string_string_string, f1_varargs_string), + }, + { + name: "ERROR_Activation_Nil", + activation: nil, + overloads: []*functions.Overload{ + { + Operator: "f2", + Function: func(values ...ref.Val) ref.Val { + return types.String("f2") + }, + NonStrict: false, + }, + }, + expect: expectError("cannot create a late bind activation with a nil activation"), + }, + { + name: "ERROR_Duplicate_Overloads", + activation: &mapActivation{}, + overloads: []*functions.Overload{ + f2_string_string, + f2_string, + f2_string_string, + }, + expect: expectError(fmt.Sprintf("overload already exists '%s'", "f2_string_string")), + }, + } + + for _, testCase := range testCases { + + t.Run(testCase.name, func(t *testing.T) { + + actual, err := NewLateBindActivation(testCase.activation, testCase.overloads...) + testCase.expect(t, actual, err) + }) + } +} + +// TestLateBindActivation_Parent verifies the implementation of lateBindActivation.Parent(). The +// expectation is for the function to return the Activation implementation configured with the +// vars field of the activation. +func TestLateBindActivation_Parent(t *testing.T) { + + actHierarchical := &hierarchicalActivation{ + parent: &mapActivation{ + bindings: map[string]any{ + "a": 5, + "b": 10, + }, + }, + child: &mapActivation{ + bindings: map[string]any{ + "a": 4, + "c": 22, + }, + }, + } + + testCases := []struct { + name string + activation func() *lateBindActivation + expect func(t *testing.T, actual Activation) + }{ + // NOTE: this test is implemented for completeness but unless + // we have access to the private type there is no way to + // produce a nil parent. + { + name: "OK_Nil_Parent", + activation: func() *lateBindActivation { + return &lateBindActivation{ + vars: nil, + dispatcher: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{}, + }, + } + }, + expect: func(t *testing.T, actual Activation) { + + if actual != nil { + t.Error("expected nil parent.") + } + }, + }, + { + name: "OK_Non_Nil_Parent", + activation: func() *lateBindActivation { + return &lateBindActivation{ + vars: actHierarchical, + dispatcher: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{}, + }, + } + }, + expect: func(t *testing.T, actual Activation) { + if actual != actHierarchical { + t.Errorf("unexpected parent (got: %v, want: %v)", actual, actHierarchical) + } + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + + candidate := testCase.activation() + actual := candidate.Parent() + testCase.expect(t, actual) + }) + } +} + +// TestLateBindActivation_ResolveName verifies the implemented behaviour of +// lateBindActivation.ResolveName(string). The expectation is for the function +// to defer all the name resolution to the configured activation. +func TestLateBindActivation_ResolveName(t *testing.T) { + + activation := func() Activation { + return &hierarchicalActivation{ + parent: &mapActivation{ + bindings: map[string]any{ + "a": 5, + "b": 10, + }, + }, + child: &mapActivation{ + bindings: map[string]any{ + "a": 4, + "c": 22, + }, + }, + } + } + + testCases := []struct { + name string + vars Activation + varName string + found bool + expected any + }{ + { + name: "TRUE_Single_Name_Occurrence", + vars: activation(), + varName: "c", + found: true, + expected: 22, + }, + { + name: "TRUE_Multiple_Name_Occurrences", + vars: activation(), + varName: "a", + found: true, + expected: 4, + }, + { + name: "FALSE_Missing_Name", + vars: activation(), + varName: "d", + found: false, + expected: nil, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + candidate := &lateBindActivation{ + vars: testCase.vars, + dispatcher: NewDispatcher(), + } + + actual, found := candidate.ResolveName(testCase.varName) + if testCase.found != found { + t.Errorf("found mistmatch for var (name: '%s', got: %v, want: %v", testCase.varName, found, testCase.found) + } + + if testCase.expected != actual { + t.Errorf("value mismatch for var (name: '%s', got: %v, want: %v)", testCase.varName, actual, testCase.expected) + } + }) + } +} + +// TestLateBindActivation_ResolveOverload verifies the implemented behaviour of +// lateBindActivation.ResolveOverload(string). The expectation is for the function +// to resolve the overload that is mapped to the given overload identifier if this +// is present. The resolution rules are as follows: +// +// - the overload is first searched in the dispatcher associated to the instance. +// - if a non nil function overload is found, it is returned. +// - if a nil overload is found, the search continues by inspecting the activation +// bound to the instance. +// - if the activation bound to the instance is an empty activation the search is +// complete and nil is returned. +// - if the activation bound to the instance is a mapActivation the search is complete +// and nil is returned. +// - if the activation bound to the instance is a hierarchical activation, first the +// child is searched to determine whether there is a LateBindActivation implementation +// in the tree that originates from the parent. +// - if the child search returns a nil overload, the parent is searched to determine +// whether there is a LateBindActivation implementation in the tree that originates +// from the child. +// - if a LateBindActivation implementation is found, the ResolveOverload(string) name +// is invoked to repeat the search detailed in the previous step. +// +// If the activation tree is exhausted and no overload is found matching the given +// identifier, nil is returned. +func TestLateBindActivation_ResolveOverload(t *testing.T) { + + nestedActivation, overloads := prepareNestedActivation() + + testCases := []struct { + name string + candidate func() *lateBindActivation + overloadId string + expected *functions.Overload + }{ + { + name: "TRUE_Simple_Case", + candidate: func() *lateBindActivation { + + return &lateBindActivation{ + vars: &emptyActivation{}, + dispatcher: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f1_string": overloads["f1_string"], + "f1_string_string": overloads["f1_string_string"], + }, + }, + } + }, + overloadId: "f1_string", + expected: overloads["f1_string"], + }, { + name: "FALSE_Simple_Case", + candidate: func() *lateBindActivation { + + return &lateBindActivation{ + vars: &emptyActivation{}, + dispatcher: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f1_string": overloads["f1_string"], + "f1_string_string": overloads["f1_string_string"], + }, + }, + } + }, + overloadId: "f1_string_string_string", + expected: nil, + }, { + name: "FALSE_Simple_Case_With_Nil", + overloadId: "f1_string", + candidate: func() *lateBindActivation { + + return &lateBindActivation{ + vars: &hierarchicalActivation{ + parent: nil, + child: &emptyActivation{}, + }, + } + }, + expected: nil, + }, { + name: "FALSE_Simple_Case_With_Partial_Activation", + overloadId: "f1_string", + candidate: func() *lateBindActivation { + + return &lateBindActivation{ + vars: &hierarchicalActivation{ + parent: &partActivation{ + Activation: &emptyActivation{}, + }, + child: &emptyActivation{}, + }, + } + }, + expected: nil, + }, { + name: "TRUE_Complex_Case_With_Nesting_Top_Level", + candidate: nestedActivation, + overloadId: "f1_string", + expected: overloads["f1_string"], + }, { + name: "TRUE_Complex_Case_With_Nesting_Top_Level_Parent", + candidate: nestedActivation, + overloadId: "f2_string", + expected: overloads["f2_string_parent"], + }, { + name: "TRUE_Complex_Case_With_Nesting_Top_Level_Shadows_Vars_Parent", + candidate: nestedActivation, + overloadId: "f3_string", + expected: overloads["f3_string"], + }, { + name: "TRUE_Complex_Case_With_Nesting_Top_Level_Shadows_Vars_Child", + candidate: nestedActivation, + overloadId: "f4_string", + expected: overloads["f4_string"], + }, { + name: "TRUE_Complex_Case_With_Nesting_Vars_Child_Shadows_Parent", + candidate: nestedActivation, + overloadId: "f5_string", + expected: overloads["f5_string_nested_child"], + }, { + name: "TRUE_Complex_Case_With_Nexting_Vars_Child_Only_Find", + candidate: nestedActivation, + overloadId: "f6_string", + expected: overloads["f6_string_nested_child"], + }, { + name: "TRUE_Complex_Case_With_Nesting_Vars_Parent_Only_Find", + candidate: nestedActivation, + overloadId: "f7_string", + expected: overloads["f7_string_nested_parent"], + }, { + name: "FALSE_Complex_Case_With_Nesting_Missing", + candidate: nestedActivation, + overloadId: "f8_string", + expected: nil, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + + activation := testCase.candidate() + actual := activation.ResolveOverload(testCase.overloadId) + if actual != testCase.expected { + t.Errorf("mismatch function for overload (id: %s, nil: %v)", testCase.overloadId, actual == nil) + } + }) + } + +} + +// TestLateBindActivation_ResolverOverloads verifies the implemented behaviour of +// latebindActivation.ResolveOverloads(). The expectation is for the function to +// generate a dispatcher that aggregates all the function overloads definition by +// following the precedence rules implemented for ResolveOverload(string) when +// duplicates are encountered. +func TestLateBindActivation_ResolveOverloads(t *testing.T) { + + nestedActivation, overloads := prepareNestedActivation() + + testCases := []struct { + name string + candidate func() *lateBindActivation + expected Dispatcher + }{ + { + name: "OK_Simple_Activation_Empty", + candidate: func() *lateBindActivation { + return &lateBindActivation{ + vars: &mapActivation{}, + dispatcher: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{}, + }, + } + }, + expected: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{}, + }, + }, + { + name: "OK_Simple_Activation_Not_Empty", + candidate: func() *lateBindActivation { + return &lateBindActivation{ + vars: &mapActivation{}, + dispatcher: &defaultDispatcher{ + parent: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f2_string": overloads["f2_string_parent"], + }, + }, + overloads: overloadMap{ + "f1_string": overloads["f1_string"], + "f3_string": overloads["f3_string"], + }, + }, + } + }, + expected: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f1_string": overloads["f1_string"], + "f2_string": overloads["f2_string_parent"], + "f3_string": overloads["f3_string"], + }, + }, + }, + { + name: "OK_Nested_Activation", + candidate: nestedActivation, + expected: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f1_string": overloads["f1_string"], + "f1_string_string": overloads["f1_string_string"], + "f1_string_string_string": overloads["f1_string_string_string"], + "f2_string": overloads["f2_string_parent"], + "f3_string": overloads["f3_string"], + "f4_string": overloads["f4_string"], + "f5_string": overloads["f5_string_nested_child"], + "f6_string": overloads["f6_string_nested_child"], + "f7_string": overloads["f7_string_nested_parent"], + }, + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + + activation := testCase.candidate() + actual := activation.ResolveOverloads() + + if actual == nil { + t.Fatal("unexpected nil reference returned by ResolveOverloads") + } + + expectedIds := testCase.expected.OverloadIds() + actualIds := actual.OverloadIds() + + if len(expectedIds) != len(actualIds) { + t.Errorf("number of overloads mismatch (got: %d, want: %d)", len(actualIds), len(expectedIds)) + } + + for _, ovlId := range expectedIds { + + expectedOverload, found := testCase.expected.FindOverload(ovlId) + if !found { + t.Fatalf("unexpected: overload (id: %s) declared but not found", ovlId) + } + actualOverload, found := actual.FindOverload(ovlId) + if !found { + t.Errorf("overload (id: %s) not found in result", ovlId) + } + if actualOverload == nil { + t.Errorf("overload (id: %s) is found, but nil", ovlId) + } + if expectedOverload != actualOverload { + t.Errorf("overload (id: %s) mismatch (got: %v, want: %v)", ovlId, actualOverload, expectedOverload) + } + } + }) + } +} + +// prepareNestedActivation generates a map of overloads and a function that produces a +// lateBindActivation reference which holds a tree of activations with implementations +// of LateBindActivation in the tree. The resulting activation is as structured as shown +// below: +// +// lateBindActivation: +// +// ├─ vars ---> hierarchicalActivation: +// │ ├─ parent ---> partActivation: +// │ └─ Activation: emptyActivation +// │ └─ child ---> hierarchicalActivation: +// │ ├─ parent ---> lateBindActivation: +// │ │ ├─ vars: mapActivation, +// │ │ └─ dispatcher: defaultDispatcher +// │ │ ├─ parent: nil +// │ │ └─ overloads: +// │ │ ├─ "f3_string" --> f3_string_nested_parent +// │ │ ├─ "f5_string" --> f5_string_nested_parent +// │ │ └─ "f7_string" --> f7_string_nested_parent +// │ └─ child ---> lateBindActivation: +// │ ├─ vars: mapActivation, +// │ └─ dispatcher: defaultDispatcher +// │ ├─ parent: nil +// │ └─ overloads +// │ ├─ "f3_string" --> f3_string_nested_child +// │ ├─ "f4_string" --> f5_string_nested_child +// │ └─ "f6_string" --> f7_string_nested_child +// └─ dispatcher: defaultDispatcher: +// ├─ parent: defaultDispatcher: +// │ ├─ parent: nil +// │ └─ overloads: +// │ ├─ "f1_string": f1_string_parent +// │ └─ "f2_string": f2_string_parent +// └─ overloads: +// ├─ "f1_string": f1_string +// ├─ "f1_string_string": f1_string_string +// ├─ "f1_string_string_string": f1_string_stirng_string +// ├─ "f3_string": f3_string +// └─ "f4_string": f4_string +func prepareNestedActivation() (func() *lateBindActivation, map[string]*functions.Overload) { + + f1_string := function("f1_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f1_string") }) + f3_string := function("f3_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f3_string") }) + f4_string := function("f4_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f4_string") }) + + // this function creates an upper case version of the original string passed as + // argument (see: TestLateBindEvalUnaryEval). + f1_string_string := unary("f1_string_string", 0, false, func(arg ref.Val) ref.Val { + text, _ := arg.(types.String) + return types.String(strings.ToUpper(string(text))) + }) + + // this function composes the two strings passed as arguments in inverse order and with + // a space in the middle (see TestLateBindEvalBinaryEval). + f1_string_string_string := binary("f1_string_string_string", 0, false, func(lhs ref.Val, rhs ref.Val) ref.Val { + + a, _ := lhs.(types.String) + b, _ := rhs.(types.String) + + return b.Add(types.String(" ")).(types.String).Add(a) + }) + f1_string_parent := function("f1_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f1_string_parent") }) + f2_string_parent := function("f2_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f2_string_parent") }) + + f3_string_nested_parent := function("f3_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f3_string_nested_parent") }) + f5_string_nested_parent := function("f5_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f5_string_nested_parent") }) + f7_string_nested_parent := function("f7_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f7_string_nested_parent") }) + + f4_string_nested_child := function("f4_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f4_string_nested_child") }) + f5_string_nested_child := function("f5_string", 0, false, func(args ...ref.Val) ref.Val { return types.String("f5_string_nested_child") }) + f6_string_nested_child := function("f6_string", 0, false, func(args ...ref.Val) ref.Val { + + var result traits.Adder = types.String("") + + for _, arg := range args { + text, _ := arg.(types.String) + result = result.Add(text).(traits.Adder) + } + return result.(ref.Val) + }) + + overloads := map[string]*functions.Overload{ + "f1_string": f1_string, + "f1_string_string": f1_string_string, + "f1_string_string_string": f1_string_string_string, + "f1_string_parent": f1_string_parent, + "f2_string_parent": f2_string_parent, + "f3_string": f3_string, + "f3_string_nested_parent": f3_string_nested_parent, + "f4_string": f4_string, + "f4_string_nested_child": f4_string_nested_child, + "f5_string_nested_child": f5_string_nested_child, + "f5_string_nested_parent": f5_string_nested_parent, + "f6_string_nested_child": f6_string_nested_child, + "f7_string_nested_parent": f7_string_nested_parent, + } + + nestedActivation := func() *lateBindActivation { + + return &lateBindActivation{ + vars: &hierarchicalActivation{ + parent: &partActivation{ + Activation: &emptyActivation{}, + }, + child: &hierarchicalActivation{ + parent: &lateBindActivation{ + vars: &mapActivation{}, + dispatcher: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f3_string": f3_string_nested_parent, + "f5_string": f5_string_nested_parent, + "f7_string": f7_string_nested_parent, + }, + }, + }, + child: &lateBindActivation{ + vars: &mapActivation{}, + dispatcher: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f4_string": f4_string_nested_child, + "f5_string": f5_string_nested_child, + "f6_string": f6_string_nested_child, + }, + }, + }, + }, + }, + dispatcher: &defaultDispatcher{ + parent: &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f1_string": f1_string_parent, + "f2_string": f2_string_parent, + }, + }, + overloads: overloadMap{ + "f1_string": f1_string, + "f1_string_string": f1_string_string, + "f1_string_string_string": f1_string_string_string, + "f3_string": f3_string, + "f4_string": f4_string, + }, + }, + } + } + + return nestedActivation, overloads + +} diff --git a/interpreter/decorators.go b/interpreter/decorators.go index 502db35fc..b14077bf1 100644 --- a/interpreter/decorators.go +++ b/interpreter/decorators.go @@ -15,6 +15,8 @@ package interpreter import ( + "reflect" + "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" @@ -270,3 +272,108 @@ func maybeOptimizeSetMembership(i Interpretable, inlist InterpretableCall) (Inte valueSet: valueSet, }, nil } + +// decLateBinding creates an InterpretableDecorator that is configured +// with the given options and transforms the Interpretable created by +// the planner with wrappers around function call nodes to defer the +// selection of the overload at evaluation time. +func decLateBinding(options ...LateBindCallOption) InterpretableDecorator { + + // initialise the configuration with the known types + // of injectors. + config := defaultInjectors(&lateBindConfig{ + injectors: map[reflect.Type]OverloadInjector{}, + }) + // add any other options to the configuration + for _, option := range options { + config = option(config) + } + // make sure that the cache is clean + config.cache = map[int64]Interpretable{} + + // return the decorator. + return func(interpretable Interpretable) (Interpretable, error) { + + return lateBind(config, interpretable) + } +} + +// lateBind implements the late binding decoration behaviour. The function +// uses a configuration to maintain a map of injectors that can be used to +// replicate and reconfigure InterpretableCall nodes with a runtime version +// of the matching overload identifier. +func lateBind(config *lateBindConfig, i Interpretable) (Interpretable, error) { + + if i == nil { + return nil, nil + } + + // have we already seen the interpretable, this is more of a safety + // guard than anything else, which may happen because evalWatchXXX + // structs wrap other Intepretable implementation, which may have been + // already processed based on the order of decorators. + id := i.ID() + if _, seen := config.cache[id]; seen { + return i, nil + } + + // we need to make sure that we process nodes that wrap other + // nodes that have the same identifiers. Therefore, we add the + // node only when we complete this scope, otherwise the recursion + // won't do anything on a node wrapping another. + defer func() { + // store the interpretable in the cache. + config.cache[id] = i + }() + + switch interpretable := i.(type) { + + case InterpretableCall: + + switch evalCall := interpretable.(type) { + + // we don't want to override the standard equality and + // and non equality behaviour. + case *evalEq: + case *evalNe: + // we don't want to double down on our own late binding + // in case we have multiple late bind calls options in + // planner. + case *evalLateBind: + return i, nil + + // all the other implementations of InterpretableCall are + // not supported. We could rely on a default behaviour, + // which relies only on InterpretableCall, but we wont be + // executing possibly additional logic that is implemented + // in the Eval method. + default: + + evalType := reflect.TypeOf(evalCall) + + injector, found := config.injectors[evalType] + if !found { + return nil, UnknownCallNodeError(id, evalCall) + } + + return &evalLateBind{ + target: evalCall, + injectOverload: injector, + flags: config.flags, + }, nil + } + + case *evalWatch: + + mapped, err := lateBind(config, interpretable.Interpretable) + if err != nil { + return nil, err + } + interpretable.Interpretable = mapped + + return interpretable, nil + } + + // all the other cases aren't relevant. + return i, nil +} diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index be57e7439..09488a0dc 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -234,6 +234,13 @@ func CompileRegexConstants(regexOptimizations ...*RegexOptimization) PlannerOpti return CustomDecorator(decRegexOptimizer(regexOptimizations...)) } +// LateBindCalls returns a PlannerOption that allows for mutating +// the Intepretable with injections for replacing at evaluation +// time the bindings to the function calls. +func LateBindCalls(options ...LateBindCallOption) PlannerOption { + return CustomDecorator(decLateBinding(options...)) +} + type exprInterpreter struct { dispatcher Dispatcher container *containers.Container diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 0f1057c42..94a919134 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -2076,6 +2076,904 @@ func TestInterpreter_PlanMapComprehensionTwoVar(t *testing.T) { } } +func TestInterpreter_LateBindCalls(t *testing.T) { + + f1 := func(t *testing.T) *decls.FunctionDecl { + decl, err := decls.NewFunction( + "f1", + decls.Overload("f1_int", []*types.Type{}, types.IntType, decls.FunctionBinding( + func(_ ...ref.Val) ref.Val { + return types.Int(37) + }, + )), + decls.Overload("f1_string_int", []*types.Type{types.StringType}, types.IntType, decls.UnaryBinding( + func(arg ref.Val) ref.Val { + return arg.(types.String).Size() + }, + )), + decls.Overload("f1_int_int_int", []*types.Type{types.IntType, types.IntType}, types.IntType, decls.BinaryBinding( + func(lhs ref.Val, rhs ref.Val) ref.Val { + return lhs.(types.Int).Add(rhs) + }, + )), + decls.Overload("f1_bool_bool_bool_int", []*types.Type{types.BoolType, types.BoolType, types.BoolType}, types.IntType, decls.FunctionBinding( + func(args ...ref.Val) ref.Val { + count := 0 + for _, arg := range args { + if arg == types.True { + count++ + } + } + return types.Int(count) + }, + )), + ) + + if err != nil { + t.Fatalf("pre-condition failed: could not create function declaration for f1 (cause: %v)", err) + } + return decl + } + + // overrides supplied at runtime with the activation. + f1_int := function("f1_int", 0, false, func(_ ...ref.Val) ref.Val { return types.Int(51) }) + + f1_string_int := unary("f1_string_int", 0, false, func(arg ref.Val) ref.Val { + size := arg.(types.String).Size().(types.Int) + return size.Multiply(types.Int(2)) + }) + f1_int_int_int := binary("f1_int_int_int", 0, false, func(lhs ref.Val, rhs ref.Val) ref.Val { + return lhs.(types.Int).Subtract(rhs) + }) + f1_bool_bool_bool_int := function("f1_bool_bool_bool_int", 0, false, func(args ...ref.Val) ref.Val { + count := 0 + for _, arg := range args { + if arg == types.False { + count++ + } + } + return types.Int(count) + }) + + // activation configures an activation that exposes the given variables + // and the supplied runtime overrides for function overloads. + activation := func(vars Activation, ovls ...*functions.Overload) Activation { + + d := &defaultDispatcher{ + overloads: overloadMap{}, + } + for _, ovl := range ovls { + d.overloads[ovl.Operator] = ovl + } + return &lateBindActivation{ + vars: vars, + dispatcher: d, + } + } + + // dummyDecorator substitutes the evalZeroArity with the + // custom type dummyEval for the purpose of demonstrating + // the handling of unknown implementations of IntepretableCall + dummyDecorator := func() PlannerOption { + + return CustomDecorator(func(i Interpretable) (Interpretable, error) { + switch expr := i.(type) { + case *evalZeroArity: + return &dummyEval{ + id: expr.id, + function: expr.function, + overload: expr.overload, + impl: expr.impl, + }, nil + default: + return i, nil + } + }) + } + + testCases := []testCase{ + // Test Group 01 - Single Function Call Expressions + // ------------------------------------------------ + // This is to verify that the very simple case works + // when we don't supply any function overload. In this + // case the presence of the decorator should not alter + // the execution. + { + name: "T01.01__OK_ZeroArity_No_Overrides", + expr: "f1()", + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: EmptyActivation(), + out: types.Int(37), + }, + { + name: "T01.02__OK_Unary_No_Overrides", + expr: `f1("hello")`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: EmptyActivation(), + out: types.Int(5), + }, + { + name: "T01.03__OK_Binary_No_Overrides", + expr: `f1(3,4)`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: EmptyActivation(), + out: types.Int(7), + }, + { + name: "T01.04__OK_VarArgs_No_Overrides", + expr: `f1(true, false, true)`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: EmptyActivation(), + out: types.Int(2), + }, + // Test Group 02 - Single Function Call Expression Overrides + // --------------------------------------------------------- + // This case is to ensure that the decorator injects and + // configures correctly the lateBindEval to replace the + // call with the one supplied via the activation. + { + name: "T02.01__OK_ZeroArity_With_Overrides", + expr: "f1()", + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int), + out: types.Int(51), + }, + { + name: "T02.02__OK_Unary_With_Overrides", + expr: `f1("hello")`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_string_int), + out: types.Int(10), + }, + { + name: "T02.03__OK_Binary_With_Overrides", + expr: `f1(3,4)`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int_int_int), + out: types.Int(-1), + }, + { + name: "T02.04__OK_VarArgs_With_Overrides", + expr: `f1(true, false, true)`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_bool_bool_bool_int), + out: types.Int(1), + }, + // TestGroup 03 - Expressions with Operators + // ----------------------------------------- + // We expect the expressions of the operators + // to be processed by the planner and decorated + // accordingly. + { + name: "T03.01__OK_Equal_With_Overrides", + // without overrides: + // - f1() -> 37 + // - f1("hello") -> 5 + // result: 37 == 5 + 32 = true + // with overrides: false (51 != 10 + 32) + expr: `f1() == f1("hello") + 32`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int, f1_string_int), + out: types.False, + }, + { + name: "T03.02__OK_Not_Equal_With_Overrides", + // without overrides: + // - f1(3,4) -> 7 + // - f1(true, true, true) -> 3 + // - f1() -> 37 + // result: 7 - 3 + 37 == 41 = false + // with overrides: true (-1 - 0 + 51 != 41) + expr: `f1(3,4) - f1(true, true, true) + f1() != 41`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int, f1_int_int_int, f1_bool_bool_bool_int), + out: types.True, + }, + { + name: "T03.03__OK_And_Or_With_Overrides", + // without overrides: + // - f1(3,4) -> 7 + // - f1(true, true, true) -> 3 + // - f1("hello") -> 5 + // result: (7 > 0) && (3 == 3 || 5 == 10): true + // with overrides: false (-1 < 0) && (0 == 3 || 10 == 10) + expr: `f1(3,4) > 0 && (f1(true, true, true) == 3 || f1("hello") == 10)`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_string_int, f1_int_int_int, f1_bool_bool_bool_int), + out: types.False, + }, + { + name: "T03.04__OK_Ternary_Operator", + // without overrides: + // - f1(3,4) -> 7 + // - f1() -> 51 + // - f1("hello") -> 5 + // result: (7 > 0 ? 51 : 5) = 51 + // with overrides: 10 (-1 > 0 ? 51 : 10) + expr: `f1(3,4) > 0 ? f1() : f1("hello")`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int, f1_string_int, f1_int_int_int), + out: types.Int(10), + }, + { + name: "T03.05__OK_List_Index_Operator", + // without overrides: + // - f1(true,true,true) -> 3 + // result: [3] = 6 + // with overrides: 1 ([0]) + expr: `[1, 3, 5, 6][f1(true,true,true)]`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_bool_bool_bool_int), + out: types.Int(1), + }, + { + name: "T03.06__OK_Map_Index_Operator", + // without overrides: + // - f1() -> 37 + // result: { 31: 1, 51: 2 }[37] = 1 + // with overrides: 2 ([51]) + expr: `{ 37: 1, 51: 2 }[f1()]`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int), + out: types.Int(2), + }, + // Test Group 04 - Object Construction + // ----------------------------------- + // We expect the expressions to initialise + // the object to be processed by the planner + // and then decorated accordingly. + { + name: "T04.01__OK_Map_Construction_Values", + // without overrides: + // - f1() -> 37 + // - f1("hello") -> 5 + // - f1(2,3) -> 5 + // result: { "a": 37, "b": 5, "c": 5 } + // with overrides: { "a": 51, "b": 10, "c": -1 } + expr: `{ "a": f1(), "b": f1("hello"), "c": f1(2,3) }`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int, f1_string_int, f1_int_int_int), + out: types.NewDynamicMap(types.DefaultTypeAdapter, map[string]ref.Val{ + "a": types.Int(51), "b": types.Int(10), "c": types.Int(-1), + }), + }, + { + name: "T04.02__OK_Map_Construction_Keys", + // without overrides: + // - f1() -> 37 + // - f1("hello") -> 5 + // - f1(2,3) -> 5 + // result: { 37: 10, 5: 3, 5: 6 } + // with overrides: { 51: 10, 10: 3, -1: 6 } + expr: `{ f1(): 10, f1("hello"): 3, f1(2,3): 6 }`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int, f1_string_int, f1_int_int_int), + out: types.NewDynamicMap(types.DefaultTypeAdapter, map[ref.Val]ref.Val{ + types.Int(51): types.Int(10), + types.Int(10): types.Int(3), + types.Int(-1): types.Int(6), + }), + }, + { + name: "T04.03__OK_List_Construction", + // without overrides: + // - f1() -> 37 + // - f1("hi") -> 2 + // - f1(2,3) -> 5 + // result: [37, 2, 5] + // with overrides: [51, 4, 3] + expr: `[ f1(), f1("hi"), f1(false, false, false) ]`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int, f1_string_int, f1_bool_bool_bool_int), + out: types.NewDynamicList(types.DefaultTypeAdapter, []ref.Val{ + types.Int(51), types.Int(4), types.Int(3), + }), + }, + { + name: "T04.04__OK_Object_Construction", + // without overrides: + // - f1() -> 37 + // result: { single_int64: 37 } + // with overrides: { single_int64: 51 } + expr: `test.TestAllTypes{single_int64: f1()}`, + container: "google.expr.proto3", + types: []proto.Message{&proto3pb.TestAllTypes{}}, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + in: activation(EmptyActivation(), f1_int), + out: &proto3pb.TestAllTypes{ + SingleInt64: 51, + }, + }, + // Test Group 05 - Macros (Comprehensions) + // --------------------------------------- + // The expectation is that the planner processes + // all the expressions and invokes the decorator + // accordingly. + { + name: "T05.01__OK_Map_Filter", + // without overrides: + // - f1() -> 37 + // result: [39, 52] + // with overrides: [52] + expr: `m.filter(k, k > f1())`, + funcs: []*decls.FunctionDecl{f1(t)}, + vars: []*decls.VariableDecl{ + decls.NewVariable("m", types.NewMapType(types.IntType, types.StringType)), + }, + in: activation(&mapActivation{ + bindings: map[string]any{ + "m": map[int]string{ + 39: "hello", + 52: "hi", + }, + }}, + f1_int, + ), + out: types.NewDynamicList(types.DefaultTypeAdapter, []ref.Val{ + types.Int(52), + }), + }, + { + name: "T05.02__OK_Map_All", + // without overrides: + // - f1() -> 37 + // result: [39 > 37, 52 > 37] = true + // with overrides: [39 < 51, 52 > 51] = false + expr: `m.all(k, k > f1())`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + vars: []*decls.VariableDecl{ + decls.NewVariable("m", types.NewMapType(types.IntType, types.StringType)), + }, + in: activation(&mapActivation{ + bindings: map[string]any{ + "m": map[int]string{ + 39: "hello", + 52: "hi", + }, + }}, + f1_int, + ), + out: types.False, + }, + { + name: "T05.03__OK_Map_ExistOne", + // without overrides: + // - f1() -> 37 + // result: [39 > 37, 45 > 37, 52 > 37] = false + // with overrides: [39 < 51, 45 < 51, 52 > 51] = true + expr: `m.exists_one(k, k > f1())`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + vars: []*decls.VariableDecl{ + decls.NewVariable("m", types.NewMapType(types.IntType, types.StringType)), + }, + in: activation(&mapActivation{ + bindings: map[string]any{ + "m": map[int]string{ + 39: "hello", + 45: "hey", + 52: "hi", + }, + }}, + f1_int, + ), + out: types.True, + }, + { + name: "T05.04__OK_Map_Exists", + // without overrides: + // - f1() -> 37 + // result: [39 > 37, 45 > 37, 52 > 37] = true + // with overrides: [39 < 51, 45 < 51, 10 < 51] = false + expr: `m.exists(k, k > f1())`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + vars: []*decls.VariableDecl{ + decls.NewVariable("m", types.NewMapType(types.IntType, types.StringType)), + }, + in: activation(&mapActivation{ + bindings: map[string]any{ + "m": map[int]string{ + 39: "hello", + 45: "hey", + 10: "hi", + }, + }}, + f1_int, + ), + out: types.False, + }, + { + name: "T05.05__OK_Map_Map", + // without overrides: + // - f1(5,3) -> 8 + // - f1(true,true,false) -> 2 + // result: [1 < 8, 2 < 8, 7 < 8] = [1 * 2, 2 * 2, 7 * 2] = [2, 4, 14] + // with overrides: + // - f1(5,3) -> 2 + // - f1(true,true,false) -> 1 + // result: [1] + // + // NOTE: ensure the expected result is only one, so we don't get flakiness + // due to variable key ordering of the map iterator. + expr: `m.map(k, k < f1(5,3), k * f1(true, true, false))`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + vars: []*decls.VariableDecl{ + decls.NewVariable("m", types.NewMapType(types.IntType, types.StringType)), + }, + in: activation(&mapActivation{ + bindings: map[string]any{ + "m": map[int]string{ + 1: "hello", + 2: "hey", + 7: "hi", + }, + }}, + f1_int_int_int, + f1_bool_bool_bool_int, + ), + out: types.NewDynamicList(types.DefaultTypeAdapter, []int{1}), + }, + { + name: "T05.06__OK_List_Filter", + // without overrides: + // - f1(6,3) -> 9 + // result: [12, 20] + // with overrides: [12, 4, 20] + expr: `l.filter(e, e > f1(6,3))`, + funcs: []*decls.FunctionDecl{f1(t)}, + vars: []*decls.VariableDecl{ + decls.NewVariable("l", types.NewListType(types.IntType)), + }, + unchecked: false, + in: activation(&mapActivation{ + bindings: map[string]any{ + "l": []int{2, 12, 4, 20}, + }}, + f1_int_int_int, + ), + out: types.NewDynamicList(types.DefaultTypeAdapter, []int{12, 4, 20}), + }, + { + name: "T05.07__OK_List_All", + // without overrides: + // - f1(6,3) -> 9 + // result: false + // with overrides: true + expr: `l.all(e, e > f1(6,3))`, + funcs: []*decls.FunctionDecl{f1(t)}, + vars: []*decls.VariableDecl{ + decls.NewVariable("l", types.NewListType(types.IntType)), + }, + unchecked: false, + in: activation(&mapActivation{ + bindings: map[string]any{ + "l": []int{5, 6, 4, 7}, + }}, + f1_int_int_int, + ), + out: types.True, + }, + { + name: "T05.08__OK_List_Exists", + // without overrides: + // - f1() -> 37 + // result: true + // with overrides: false + expr: `l.exists(e, e == f1())`, + funcs: []*decls.FunctionDecl{f1(t)}, + vars: []*decls.VariableDecl{ + decls.NewVariable("l", types.NewListType(types.IntType)), + }, + unchecked: false, + in: activation(&mapActivation{ + bindings: map[string]any{ + "l": []int{37, 6, 4, 7}, + }}, + f1_int, + ), + out: types.False, + }, + { + name: "T05.09__OK_List_Exists_One", + // without overrides: + // - f1("hello") -> 5 + // result: 4 < 5 (only one): true + // with overrides: 6,4,7 < 10: false + expr: `l.exists_one(e, e < f1("hello"))`, + funcs: []*decls.FunctionDecl{f1(t)}, + vars: []*decls.VariableDecl{ + decls.NewVariable("l", types.NewListType(types.IntType)), + }, + unchecked: false, + in: activation(&mapActivation{ + bindings: map[string]any{ + "l": []int{37, 6, 4, 7}, + }}, + f1_string_int, + ), + out: types.False, + }, + { + name: "T05.10__OK_List_Map", + // without overrides: + // - f1("hi") -> 2 + // result: [5 * len("hello"), 5 * len("hey"), 5 * len("howdy")] = [25, 15, 25] + // with overrides: [5 * 2 * len("hello"), 5 * 2 * len("howdy")] = [50, 50] + expr: `l.map(e, size(e) > f1("hi"), 5 * f1(e))`, + funcs: []*decls.FunctionDecl{f1(t)}, + vars: []*decls.VariableDecl{ + decls.NewVariable("l", types.NewListType(types.StringType)), + }, + unchecked: false, + in: activation(&mapActivation{ + bindings: map[string]any{ + "l": []string{"hello", "hi", "hey", "howdy"}, + }}, + f1_string_int, + ), + out: types.NewDynamicList(types.DefaultTypeAdapter, []int{50, 50}), + }, + // Test Group 06 - Complex Nesting + // ------------------------------- + // These are just sanity checks to ensure that when we + // have complex expressions we don't have forgotten + // something. + { + name: "T06.01__OK_Nested_Function_Calls", + // without overrides: + // - f1() -> 37 + // - f1(bool, bool, bool) -> nr of trues + // - f1("hi") -> 2, f1("hello") -> 5 + // - f1(n,m) -> n + m + // result: 37 + f1(true, true, false) = 39 + // with overrides: + // - f1() -> 51 + // - f1(bool, bool, bool) -> nr of falses + // - f1("hi") -> 4, f1("hello") -> 10 + // - f1(n,m) -> n - m + // result: 51 + f1(false, false, true) = 53 + expr: `f1() + f1( + l.all( + e, + e < f1(f1("hello"), f1("hi")) + ), + m.exists( + k, + k == f1() + ), + c > f1(a,b) + )`, + funcs: []*decls.FunctionDecl{f1(t)}, + vars: []*decls.VariableDecl{ + decls.NewVariable("l", types.NewListType(types.IntType)), + decls.NewVariable("m", types.NewMapType(types.IntType, types.StringType)), + decls.NewVariable("c", types.IntType), + decls.NewVariable("a", types.IntType), + decls.NewVariable("b", types.IntType), + }, + unchecked: false, + in: activation(&mapActivation{ + bindings: map[string]any{ + // true for l.all with no overrides, false with overrides + "l": []int{3, 4, 5, 6}, + // true for m.exists with no overrides, false with overrides + "m": map[int]string{ + 37: "x", + 45: "y", + 81: "z", + }, + // false for c > f1(a,b) with no overrides + "c": 3, + "a": 5, + "b": 4, + }}, + f1_int, + f1_string_int, + f1_int_int_int, + f1_bool_bool_bool_int, + ), + out: types.Int(53), + }, + { + name: "T06.02__OK_Nested_Runtime_Overrides", + // without overrides: + // - f1() -> 37 + // - f1(a) -> 5 + // - f1(b,c) -> 7 + // - f1(d,e,f) -> 2 + // result: 37 + 5 + 7 + 2 = 51 + // with overrides: + // - f1() -> 51 + // - f1(a) -> 10 + // - f1(b,c) -> -1 + // - f1(d,e,f) -> 1 + // result: 51 + 10 - 1 + 1 = 61 + expr: `f1() + f1(a) + f1(b,c) + f1(d,e,f)`, + funcs: []*decls.FunctionDecl{f1(t)}, + vars: []*decls.VariableDecl{ + decls.NewVariable("a", types.StringType), + decls.NewVariable("b", types.IntType), + decls.NewVariable("c", types.IntType), + decls.NewVariable("d", types.BoolType), + decls.NewVariable("e", types.BoolType), + decls.NewVariable("f", types.BoolType), + }, + unchecked: false, + in: &hierarchicalActivation{ + parent: activation( + &mapActivation{ + bindings: map[string]any{ + "a": "howdy", + "b": 3, + }, + }, + f1_int, + ), + child: &hierarchicalActivation{ + parent: activation( + &mapActivation{ + bindings: map[string]any{ + "c": 4, + "d": true, + }, + }, + f1_string_int, + ), + child: &lateBindActivation{ + dispatcher: &defaultDispatcher{ + parent: &defaultDispatcher{ + overloads: overloadMap{ + "f1_int_int_int": f1_int_int_int, + }, + }, + overloads: overloadMap{ + "f1_bool_bool_bool_int": f1_bool_bool_bool_int, + }, + }, + vars: &mapActivation{ + bindings: map[string]any{ + "e": false, + "f": true, + }, + }, + }, + }, + }, + out: types.Int(61), + }, + + // Test Group 07 - With Eval Observer and Others + // --------------------------------------------------- + // These test cases are important to ensure that when + // interpretables are wrapped by evalWatchXXX the late + // bind decorator, if added later, can still travel + // through the wrapped Inteprepretable. + { + name: "T07.01__OK_With_EvalObserver", + // without overrides: + // - f1() -> 37 + // - f2(2,b) -> 2 + b + // result: a + 37 - 2 + b = 28 + // with overrides: + // - f1() -> 51 + // - f2(2,b) -> 2 - b + // result: a + 51 - 2 - b = 66 + expr: `a + f1() - f1(2,b)`, + funcs: []*decls.FunctionDecl{f1(t)}, + vars: []*decls.VariableDecl{ + decls.NewVariable("a", types.IntType), + decls.NewVariable("b", types.IntType), + }, + extraOpts: []PlannerOption{ + EvalStateObserver(), + LateBindCalls(), + }, + in: activation( + &mapActivation{ + bindings: map[string]any{ + "a": 5, + "b": 12, + }, + }, + f1_int, + f1_int_int_int, + ), + out: types.Int(66), + }, + { + name: "T07.02__OK_With_Optimize", + expr: `f1() in [ 23, 34, 51 ]`, + funcs: []*decls.FunctionDecl{f1(t)}, + extraOpts: []PlannerOption{ + Optimize(), + LateBindCalls(), + }, + unchecked: false, + in: activation(&emptyActivation{}, f1_int), + out: types.True, + }, + + // Test Group 08 - LateBindCalls Variations + // ------------------------------------------ + // These test cases are aimed at checking that when + // we play around with LateBindCalls the outcome is + // still predictable and expected: + // - if we add two LateBindCalls options, only the + // first one will have effect. + // - if we add a custom injector, this will be honored. + { + name: "T08.01__OK_With_Custom_Injector_And_Custom_Eval", + // without overrides: + // - f1() -> 37 + // with overrides: + // - f1() -> 51 + expr: `f1()`, + funcs: []*decls.FunctionDecl{f1(t)}, + extraOpts: []PlannerOption{ + + // injects custom type. + dummyDecorator(), + // lateBind will not process this node + // and throw an error. + LateBindCalls(Injector( + &dummyEval{}, + func(i InterpretableCall, ovl *functions.Overload, _ LateBindFlags) (InterpretableCall, error) { + + de := i.(*dummyEval) + + return &dummyEval{ + id: de.id, + function: de.function, + overload: de.overload, + // we should check that the function is not nil + // but we only do this for the purpose of test. + impl: ovl.Function, + }, nil + }), + ), + }, + unchecked: false, + in: activation(&emptyActivation{}, f1_int), + out: types.Int(51), + }, + { + name: "T08.02__ERROR_With_Custom_Eval", + // without overrides: + // - f1() -> 37 + // with overrides: + // - f1() -> 51 (error) + // NOTE: since f1() is mutated from evalZeroArity to evalDummy + // we will receive an error, because there is no custom injector + // able to handle this type. + expr: `f1()`, + funcs: []*decls.FunctionDecl{f1(t)}, + extraOpts: []PlannerOption{ + + // injects custom type. + dummyDecorator(), + // lateBind will not process this node + // and throw an error. + LateBindCalls(), + }, + in: activation(&emptyActivation{}, f1_int), + progErr: fmt.Sprintf(errorUnknownCallNode, 1, &dummyEval{}), + }, + { + name: "T08.03__OK_With_Multiple_LateBind_Calls", + // without overrides: + // - f1() -> 37 + // - f1("hi") -> 2 + // - f1(4,2) -> 6 + // - f1(true,true,true) -> 3 + // result: 37 + 2 + 6 + 3 = 48 + // with overrides: + // - f1() -> 51 + // - f1("hi") -> 4 + // - f1(4,2) -> 2 + // - f1(true,true,true) -> 0 + // result: 51 + 4 + 2 + 0 = 57 + expr: `f1() + f1("hi") + f1(4,2) + f1(true,true,true)`, + funcs: []*decls.FunctionDecl{f1(t)}, + unchecked: false, + extraOpts: []PlannerOption{ + LateBindCalls(), + LateBindCalls(), // this will not take effect. + }, + in: activation( + &emptyActivation{}, + f1_int, + f1_string_int, + f1_int_int_int, + f1_bool_bool_bool_int, + ), + out: types.Int(57), + }, + } + + for _, tc := range testCases { + + t.Run(tc.name, func(t *testing.T) { + + // this is to control the use case scenarios where we + // need to explicitly add and configure LateBindCalls + if len(tc.extraOpts) == 0 { + + // if it is empty, we add the default behaviour + // otherwise we expect the test case to explicitly + // configure the decorator. + tc.extraOpts = append(tc.extraOpts, LateBindCalls()) + } + + interpretable, activation, err := program(t, &tc, tc.extraOpts...) + + if err != nil { + + if len(tc.progErr) > 0 { + + if !strings.Contains(err.Error(), tc.progErr) { + t.Fatalf("got %v, (%T), wanted program error with: %s", err.Error(), err, tc.progErr) + } + // if we have a program error, we cannot continue. + return + + } else { + + t.Fatalf("pre-condition failed: could not create program (cause: %v)", err) + } + } + got := interpretable.Eval(activation) + if len(tc.err) > 0 { + // we expect error + if !types.IsError(got) || !strings.Contains(got.(*types.Err).String(), tc.err) { + t.Fatalf("got %v (%T), wanted error: %s", got, got, tc.err) + } + } else { + want := tc.out.(ref.Val) + if got.Equal(want) != types.True { + t.Fatalf("got %v, wanted: %v", got, want) + } + } + }) + } +} + +// dummyEval is a test struct used to demonstrate +// the behaviour of the OverloadInjector or its +// absence during late binding. +type dummyEval struct { + id int64 + function string + overload string + impl func(...ref.Val) ref.Val +} + +func (de *dummyEval) ID() int64 { return de.id } +func (de *dummyEval) Eval(ctx Activation) ref.Val { return types.LabelErrNode(de.id, de.impl()) } +func (de *dummyEval) Function() string { return de.function } +func (de *dummyEval) OverloadID() string { return de.overload } +func (de *dummyEval) Args() []Interpretable { return []Interpretable{} } + func testContainer(name string) *containers.Container { cont, _ := containers.NewContainer(containers.Name(name)) return cont diff --git a/interpreter/late_binding.go b/interpreter/late_binding.go new file mode 100644 index 000000000..225918504 --- /dev/null +++ b/interpreter/late_binding.go @@ -0,0 +1,427 @@ +package interpreter + +import ( + "errors" + "fmt" + "reflect" + + "github.com/google/cel-go/common/functions" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +const ( + errorUncheckedAst = "cannot decorate an un-checked AST for late binding, unchecked ASTs are the result of env.Parse(...), while late binding requires ASTs produced by env.Compile(...) or env.Check(...)" + errorOverloadMismatch = "function overload (id: %s) has different attributes (name: %s, got: %v, want: %v)" + errorOverloadSignature = "function overload (name: %s, id: %s) is not matched (got: %s, want: %s)" + errorOverloadNotFound = "unexpected: overload (id: %s) not found" + errorUnknownCallNode = "cannot apply late binding decoration to node (id: %d, type: %T): unsupported type" + errorOverloadInjection = "runtime dispatch error (cause: %v)" + + unarySignature = "unary{ func(ref.Val) ref.Val }" + binarySignature = "binary{ func(ref.Val, ref.Val) ref.Val }" + functionSignature = "varargs{ func(...ref.Val) ref.Val }" +) + +// UncheckedAstError returns an error implementation that notifies the +// caller that the AST is unchecked and therefore it is not possible to +// apply the late binding decorator to the resulting Inteprepretable. +func UncheckedAstError() error { + return errors.New(errorUncheckedAst) +} + +// UnknownCallNodeError returns an error implementation that notifies the +// caller that the late binding decorator has encountered an IntepreptableCall +// implementation that is not known and therefore it cannot applu late binding. +func UnknownCallNodeError(id int64, callNode InterpretableCall) error { + + return fmt.Errorf(errorUnknownCallNode, id, callNode) +} + +// OverloadMismatchError returns an error implementation that contains information +// about a mismatch between the runtime supplied overload and the statically linked +// overload for the given overload identifier. +func OverloadMismatchError(overloadId string, attribute string, got any, want any) error { + + return fmt.Errorf(errorOverloadMismatch, overloadId, attribute, got, want) +} + +// OverloadSignatureError returns an error implementation that contains information about +// a signature mismatch between the function overload statically configured in the environment +// and the corresponding function overload supplied at runtime during evaluation. +func OverloadSignatureError(function string, overloadId string, got string, want string) error { + return fmt.Errorf(errorOverloadSignature, function, overloadId, got, want) +} + +// ValidateOverloads ensures that if the activation contains an overload function +// its signature matches the one associated to the same overload identifier in +// the dispatcher otherwise throws an error. If the activation defines more function +// overloads, those won't be considered in the validation. +func ValidateOverloads(original Dispatcher, activation Activation) error { + + // we create a + aggregate := NewDispatcher() + resolveAllOverloads(aggregate, activation) + + overloads := original.OverloadIds() + for _, overloadId := range overloads { + + refOvl, found := original.FindOverload(overloadId) + if !found { + return fmt.Errorf(errorOverloadNotFound, overloadId) + } + + ovl, found := aggregate.FindOverload(overloadId) + if found { + // we need to make sure that the overloads are + // matching. + + result := matchSignature(overloadId, refOvl, ovl) + if result != nil { + return result + } + } + } + + return nil +} + +// resolveAllOverloads aggregates all function overloads defined in the +// activation into a single dispatcher so that they can be easily checked +// at once when we validate the overloads. +func resolveAllOverloads(aggregate Dispatcher, activation Activation) { + + if activation == nil { + return + } + switch act := activation.(type) { + case *mapActivation: + return + case *emptyActivation: + return + case *partActivation: + resolveAllOverloads(aggregate, act.Activation) + case *hierarchicalActivation: + resolveAllOverloads(aggregate, act.child) + resolveAllOverloads(aggregate, act.parent) + case LateBindActivation: + + // the implementation of Overloads() is expected to be + // recursive, therefore we don't need to look any further. + dispatcher := act.ResolveOverloads() + for _, ovlId := range dispatcher.OverloadIds() { + ovl, found := dispatcher.FindOverload(ovlId) + if found { + // note we don't need to check an error because if there + // is an error the overload is already defined. This may + // happen because we nest multiple activation with late + // binding capabilities and one may shadow another as it + // happens for variable names. Since the activations are + // visited in the correct order this is expected behaviour. + aggregate.Add(ovl) + } + } + default: + // this is to cater for all other implementations + // that we don't known about but that rightfully + // implement the Activation interface. + resolveAllOverloads(aggregate, act.Parent()) + } + +} + +// matchSignature compares the two overload definitions and returns an error +// if the overload function does not have a matching signature with the +// reference overload. The only check we can implement is over the number of +// parameters and the attributes of the overload. +// +// The impmlementation verifies the following: +// +// - if refOvl.Unary is not nil, the expectation is that ovl.Unary is not nil. +// - if refOvl.Binry is not nil, the expectation is that ovl.Binary is not nil. +// - if refOvl.Function not nil, the expectation is that ovl.Fnuction is not nil. +// - refOvl.NotStrict and ovl.NonStrict must be the same. +// - refOvl.OperandTrait and ovl.OperandTrait must be the same. +// - refOvl.Operator and ovl.Operator must be the same. +func matchSignature(overloadId string, refOvl *functions.Overload, ovl *functions.Overload) error { + + got := "" + function := "" + + if refOvl.Unary != nil { + + if ovl.Unary == nil { + + if ovl.Binary != nil { + got = binarySignature + + } else if ovl.Function != nil { + got = functionSignature + } + return OverloadSignatureError(function, overloadId, got, unarySignature) + } + } else if refOvl.Binary != nil { + + if ovl.Binary == nil { + + if ovl.Unary != nil { + got = unarySignature + } else if ovl.Function != nil { + got = functionSignature + } + + return OverloadSignatureError(function, overloadId, got, binarySignature) + } + + } else if refOvl.Function != nil { + + if ovl.Function == nil { + + if ovl.Unary != nil { + got = unarySignature + } else if ovl.Binary != nil { + got = binarySignature + } + + return OverloadSignatureError(function, overloadId, got, functionSignature) + + } + } + + if refOvl.NonStrict != ovl.NonStrict { + + return OverloadMismatchError(overloadId, "NonStrict", ovl.NonStrict, refOvl.NonStrict) + } + if refOvl.OperandTrait != ovl.OperandTrait { + + return OverloadMismatchError(overloadId, "OperandTrait", ovl.OperandTrait, refOvl.OperandTrait) + } + if refOvl.Operator != ovl.Operator { + + // unless we test directly matchSignature, this branch can only be reached + // with misconfiguration that are unlikely to occur (see test cases for + // ValidateOverloads -> ERROR_Misconfigured_Dispatcher). + return OverloadMismatchError(overloadId, "Operator", ovl.Operator, refOvl.Operator) + } + + return nil +} + +// LateBindFlags is a bitmask that is reserved for future uses to pass parameters +// to the late binding algorithm both during the program planning phase and the +// runtime dispatch behaviour. +type LateBindFlags int + +const ( + LateBindFlagsNone LateBindFlags = iota +) + +// OverloadInjector defines the signature of the function that is used to create a replica +// of the given InterpretableCall configured with the new function Overload passed as the +// second argument. The contract defined by this signature requires implementations to +// create a new instance of the InterpretableCall, which is identical to the one passed as +// argument, except for the overload used for the function. If the injection is not possible +// the function will return an error. +type OverloadInjector func(InterpretableCall, *functions.Overload, LateBindFlags) (InterpretableCall, error) + +// evalLateBind implements the decorator pattern for function call nodes that implement +// InterpretableCall. This type is shallow wrapper around an InterpretableCall implementation +// that during evaluation looks up the overload identifier exposed by the interpretable and +// resolves it from the activation if present. It then used the configured OverloadInjector +type evalLateBind struct { + target InterpretableCall + injectOverload OverloadInjector + flags LateBindFlags +} + +// ID implements Interpretable.ID() and returns the node identifier +// of the wrapped InterpretableCall. +func (elb *evalLateBind) ID() int64 { + return elb.target.ID() +} + +// Eval implements Interpretable.Eval(Activation) and executes the late binding +// behaviour, by looking up the overload identifier associated to the wrapped +// IntepretableCall implementation in the given Activation. If a non-nil overload +// is found, it then uses the configured OverloadInjector to create a fresh copy +// of the original InterpretableCall and reconfigures it with the new overload. +// If there is no override in the activation for the overload associated to the +// InterpretableCall implementation, the original function statically linked during +// the planner execution will be executed. +func (elb *evalLateBind) Eval(ctx Activation) ref.Val { + + ovlId := elb.target.OverloadID() + ovl := resolveOverload(ovlId, ctx) + + var err error + var subject Interpretable = elb.target + if ovl != nil { + + // this creates a new instance of the original + // node, to ensure that the original remains + // unchanged. + subject, err = elb.injectOverload(elb.target, ovl, elb.flags) + if err != nil { + return types.NewErrWithNodeID(elb.target.ID(), errorOverloadInjection, err) + } + } + + return subject.Eval(ctx) + +} + +// Function implements InterpretableCall.Function() and returns the +// resolved function name configured with the wrapped InterpretableCall. +func (elb *evalLateBind) Function() string { + return elb.target.Function() +} + +// OverloadID implements InterpretableCall.OverloadID() and returns the +// resolved overload identifier configured with the wrapped InterpretableCall. +func (elb *evalLateBind) OverloadID() string { + return elb.target.OverloadID() +} + +// Args implements InterpretableCall.Args() and returns the resolved +// array of arguments configured with the wrapped InterpretableCall. +func (elb *evalLateBind) Args() []Interpretable { + return elb.target.Args() +} + +// LateBindCallOption defines the signature of an option function +// that can be used to configure the behaviour of the late binding +// algorithm. +type LateBindCallOption func(c *lateBindConfig) *lateBindConfig + +// Injector returns an option function that can be used to extend the +// behaviour of the late binding algorithm to include handing for a +// specific type of InterpretableCall, which is not part of the core +// code base and therefore unknown to the algorithm. +func Injector(t InterpretableCall, injector OverloadInjector) LateBindCallOption { + + return func(c *lateBindConfig) *lateBindConfig { + + theType := reflect.TypeOf(t) + c.injectors[theType] = injector + + return c + } +} + +// lateBindConfig defines the configuration settings for the late +// binding behaviour as well as some runtime state that is used +// by the algorithm (i.e. cache of nodes processed) +type lateBindConfig struct { + cache map[int64]Interpretable + injectors map[reflect.Type]OverloadInjector + flags LateBindFlags +} + +// defaultInjectors is implements a LateBindCallOption that is +// used to populate the injectors map with handlers for all +// known types of function call nodes. +func defaultInjectors(c *lateBindConfig) *lateBindConfig { + + c.injectors[reflect.TypeOf(&evalZeroArity{})] = injectZeroArity + c.injectors[reflect.TypeOf(&evalUnary{})] = injectUnary + c.injectors[reflect.TypeOf(&evalBinary{})] = injectBinary + c.injectors[reflect.TypeOf(&evalVarArgs{})] = injectVarArgs + + return c +} + +// injectZeroArity implements an OverloadInjector for the evalZeroArity implementation +// of InterpretableCall. This implementation expects a varargs function to be defined +// by the overload in order to be substituted to the function implementation that is +// statically linked to the node during the planning phase. +func injectZeroArity(target InterpretableCall, overload *functions.Overload, _ LateBindFlags) (InterpretableCall, error) { + + zeroArity := target.(*evalZeroArity) + + if overload.Function == nil { + + return nil, OverloadSignatureError(zeroArity.function, zeroArity.overload, "", functionSignature) + } + + return &evalZeroArity{ + id: zeroArity.id, + function: zeroArity.function, + overload: zeroArity.overload, + impl: overload.Function, + }, nil +} + +// injectUnary implements an OverloadInjector for the evalUnary implementation of +// InterpretableCall. This implementation expects a unary function to be defined +// by the overload in order to be substituted to the function implementation that +// is statically linked to the node during the planning phase. +func injectUnary(target InterpretableCall, overload *functions.Overload, _ LateBindFlags) (InterpretableCall, error) { + + unary := target.(*evalUnary) + + if overload.Unary == nil { + + return nil, OverloadSignatureError(unary.function, unary.overload, "", unarySignature) + } + + return &evalUnary{ + id: unary.id, + function: unary.function, + overload: unary.overload, + arg: unary.arg, + trait: unary.trait, + nonStrict: unary.nonStrict, + + impl: overload.Unary, + }, nil +} + +// injectBinary implements an OverloadInjector for the evalBinary implementation of +// InterpretableCall. This implementation expects a binary function to be defined by +// the overload in order to be substituted to the function implementation that is +// statically linked to the node during the planning phase. +func injectBinary(target InterpretableCall, overload *functions.Overload, _ LateBindFlags) (InterpretableCall, error) { + binary := target.(*evalBinary) + + if overload.Binary == nil { + + return nil, OverloadSignatureError(binary.function, binary.overload, "", binarySignature) + } + + return &evalBinary{ + id: binary.id, + function: binary.function, + overload: binary.overload, + lhs: binary.lhs, + rhs: binary.rhs, + trait: binary.trait, + nonStrict: binary.nonStrict, + + impl: overload.Binary, + }, nil +} + +// injectVarArgs implements an OverloadInjector for the evalVarArgs implementation of +// InterpretableCall. This implementation expects a varargs function to be defined by +// the overload in order to be substituted to the function implementation that is +// statically linked to the node during the planning phase. +func injectVarArgs(target InterpretableCall, overload *functions.Overload, _ LateBindFlags) (InterpretableCall, error) { + + varArgs := target.(*evalVarArgs) + + if overload.Function == nil { + + return nil, OverloadSignatureError(varArgs.function, varArgs.overload, "", functionSignature) + } + + return &evalVarArgs{ + id: varArgs.id, + function: varArgs.function, + overload: varArgs.overload, + args: varArgs.args, + trait: varArgs.trait, + nonStrict: varArgs.nonStrict, + + impl: overload.Function, + }, nil +} diff --git a/interpreter/late_binding_test.go b/interpreter/late_binding_test.go new file mode 100644 index 000000000..b5d829a98 --- /dev/null +++ b/interpreter/late_binding_test.go @@ -0,0 +1,1132 @@ +package interpreter + +import ( + "fmt" + "reflect" + "strings" + "testing" + + "github.com/google/cel-go/common/functions" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// testErrorFunction is a simple helper to verifies the behaviour of function that +// generate error implementations. +func testErrorFunction(t *testing.T, name string, candidate func() error, expected string) { + + actualError := candidate() + if actualError == nil { + t.Fatalf("%s returned nil, a non-nil error implementation is expected", name) + } + actual := actualError.Error() + if expected != actual { + t.Errorf("%s has unexpected message (got: %s, want: %s)", name, actual, expected) + } +} + +// TestUncheckedAstError verifies the implemented behaviour of UncheckedAstError. The expectation +// is for the function to return a non-nil error implementation that contains the message defined +// by the constant errrorUncheckedAst. +func TestUncheckedAstError(t *testing.T) { + + testErrorFunction(t, "UncheckAstError", UncheckedAstError, errorUncheckedAst) +} + +// TestUnknownCallNodeError verifies the implemented behaviour of UnknownCallNodeError. +// The expectation is for the function to return a non-nil error implementation whose +// message is configured according to the errorUnknownCallNode template. +func TestUnknownCallNodeError(t *testing.T) { + testErrorFunction( + t, + "UnknownCallNodeError", + func() error { + return UnknownCallNodeError(45, &evalBinary{}) + }, + fmt.Sprintf(errorUnknownCallNode, 45, &evalBinary{}), + ) +} + +// TestOverloadMismatchError verifies the implemented behaviour of OverloadMismatchError. +// The expectation is for the function to return an error implementation with a message +// formatted according to the errorOverloadMismatch template. +func TestOverloadMismatchError(t *testing.T) { + + testErrorFunction( + t, + "OverloadMismtachError", + func() error { return OverloadMismatchError("ovlId", "Operator", "op1", "op2") }, + fmt.Sprintf(errorOverloadMismatch, "ovlId", "Operator", "op1", "op2"), + ) +} + +// TestOverloadSignatureError verifies the implemented behaviour of the OverloadSignatureError. +// The expectation is for the function to return an error implementation with a message formatted +// according to the errorOverloadSignature template. +func TestOverloadSignatureError(t *testing.T) { + + testErrorFunction( + t, + "OverloadSignatureError", + func() error { + return OverloadSignatureError("f1", "ovlId", "func(ref.Val) ref.Val", "func(...ref.Val) ref.Val") + }, + fmt.Sprintf(errorOverloadSignature, "f1", "ovlId", "func(ref.Val) ref.Val", "func(...ref.Val) ref.Val"), + ) +} + +// TestValidateOverloads verifies the implemented behaviour of ValidateOverloads. +// the expectation is for the function to resolve all function overload definitions +// that are exposed by the given activation and compare them with the reference +// dispatcher maintaining the set of functions statically bound during expression +// parsing. For any function overload that is redefined in the activation the +// expectation is for the two function signatures to match and the associated +// overload attributes to be the same. +func TestValidateOverloads(t *testing.T) { + + f2_string_string := unary("f2_string_string", 0, true, func(arg ref.Val) ref.Val { + return types.String("f2_string_string") + }) + f2_string_string_string := binary("f2_string_string_string", 0, true, func(lhs ref.Val, rhs ref.Val) ref.Val { + return types.String("f2_string_string_string") + }) + f2_varargs_string := function("f2_varargs_string", 0, true, func(args ...ref.Val) ref.Val { + return types.String("f2_varargs_string") + }) + + f1_string_string := unary("f1_string_string", 0, true, func(arg ref.Val) ref.Val { + return types.String("f1_string_string") + }) + f1_string_string_string := binary("f1_string_string_string", 0, true, func(lhs ref.Val, rhs ref.Val) ref.Val { + return types.String("f1_string") + }) + f1_varargs_string := function("f1_varargs_string", 0, true, func(args ...ref.Val) ref.Val { + return types.String("f1_varargs_string") + }) + + // matching overloads + + f1_string_string_overload := unary("f1_string_string", 0, true, func(arg ref.Val) ref.Val { + return types.String("f1_string_string_overload") + }) + f1_string_string_string_overload := binary("f1_string_string_string", 0, true, func(lhs ref.Val, rhs ref.Val) ref.Val { + return types.String("f1_string_overload") + }) + f1_varargs_string_overload := function("f1_varargs_string", 0, true, func(args ...ref.Val) ref.Val { + return types.String("f1_varargs_string_overload") + }) + + f3_string_string := unary("f3_string_string", 0, true, func(arg ref.Val) ref.Val { + return types.String("f3_string_string") + }) + + // mismatched overloads + + f1_string_string_binary := binary("f1_string_string", 0, true, func(lhs ref.Val, rhs ref.Val) ref.Val { + return types.String("f1_string_string_binary") + }) + + f1_string_string_varargs := function("f1_string_string", 0, true, func(args ...ref.Val) ref.Val { + return types.String("f1_string_string_binary") + }) + + f1_string_string_string_unary := unary("f1_string_string_string", 0, true, func(arg ref.Val) ref.Val { + return types.String("f1_string_string_string_unary") + }) + f1_string_string_string_varargs := function("f1_string_string_string", 0, true, func(args ...ref.Val) ref.Val { + return types.String("f1_string_string_string_varargs") + }) + + f1_varargs_string_unary := unary("f1_varargs_string", 0, true, func(arg ref.Val) ref.Val { + return types.String("f1_varargs_string_unary") + }) + f1_varargs_string_binary := binary("f1_varargs_string", 0, true, func(lhs ref.Val, rhs ref.Val) ref.Val { + return types.String("f1_varargs_string_unary") + }) + + f1_string_operand_trait := unary("f1_string_string", 2, true, func(arg ref.Val) ref.Val { + return types.String("f1_string_operand_trait") + }) + + f1_string_non_strict := unary("f1_string_string", 0, false, func(arg ref.Val) ref.Val { + return types.String("f1_string_non_strict") + }) + + // referenceDispatcher generates a Dispatcher implementation that contains + // a baseline configuration for all the function overloads that are defined. + referenceDispatcher := func() Dispatcher { + + return &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f1_string_string": f1_string_string, + "f1_string_string_string": f1_string_string_string, + "f1_varargs_string": f1_varargs_string, + "f2_string_string": f2_string_string, + "f2_string_string_string": f2_string_string_string, + "f2_varargs_string": f2_varargs_string, + }, + } + } + + // candidateActivation generates a function that produces a repeatable configuration + // for the Activation implementation based on lateBindActivation. When supplied with + // a non-empty overload identifier, it also maps the given overload to such identifier. + // + // NOTE: these parameters are used to produce alterations over a baseline and produce + // different test cases, to ensure that all branches are explored. + candidateActivation := func(overloadId string, overload *functions.Overload) func() Activation { + + return func() Activation { + + dispatcher := &defaultDispatcher{ + parent: nil, + overloads: overloadMap{ + "f1_string_string": f1_string_string_overload, + "f1_string_string_string": f1_string_string_string_overload, + "f1_varargs_string": f1_varargs_string_overload, + "f3_string_string": f3_string_string, + }, + } + if len(overloadId) > 0 { + dispatcher.overloads[overloadId] = overload + } + return &lateBindActivation{ + vars: &mapActivation{}, + dispatcher: dispatcher, + } + } + } + + testCases := []struct { + name string + reference Dispatcher + candidate func() Activation + err error + }{ + { + name: "OK_Nil_Activation", + reference: referenceDispatcher(), + candidate: func() Activation { return nil }, + err: nil, + }, + { + name: "OK_Unknown_Activation_Type", + reference: referenceDispatcher(), + candidate: func() Activation { return &dummyActivation{} }, + err: nil, + }, + { + name: "OK_Matching_Overloads", + reference: referenceDispatcher(), + candidate: candidateActivation("", nil), + err: nil, + }, + + // binary function is unmatched in signature + { + name: "ERROR_Unary_Mismatch_Binary", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_string_string", f1_string_string_binary), + err: fmt.Errorf(errorOverloadSignature, "", "f1_string_string", binarySignature, unarySignature), + }, { + name: "ERROR_Unary_Mismatch_VarArgs", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_string_string", f1_string_string_varargs), + err: fmt.Errorf(errorOverloadSignature, "", "f1_string_string", functionSignature, unarySignature), + }, { + name: "ERROR_Unary_Mismatch_Nil", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_string_string", &functions.Overload{ + Operator: "f1_string_string", + OperandTrait: 0, + NonStrict: true, + }), + err: fmt.Errorf(errorOverloadSignature, "", "f1_string_string", "", unarySignature), + }, + + // binary function is unmatched in signature + { + name: "ERROR_Binary_Mismatch_Unary", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_string_string_string", f1_string_string_string_unary), + err: fmt.Errorf(errorOverloadSignature, "", "f1_string_string_string", unarySignature, binarySignature), + }, { + name: "ERROR_Binary_Mismatch_VarArgs", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_string_string_string", f1_string_string_string_varargs), + err: fmt.Errorf(errorOverloadSignature, "", "f1_string_string_string", functionSignature, binarySignature), + }, { + name: "ERROR_Binary_Mismatch_Nil", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_string_string_string", &functions.Overload{ + Operator: "f1_string_string_string", + OperandTrait: 0, + NonStrict: true, + }), + err: fmt.Errorf(errorOverloadSignature, "", "f1_string_string_string", "", binarySignature), + }, + + // varargs function is unmatched in signature + { + name: "ERROR_VarArgs_Mismatch_Unary", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_varargs_string", f1_varargs_string_unary), + err: fmt.Errorf(errorOverloadSignature, "", "f1_varargs_string", unarySignature, functionSignature), + }, { + name: "ERROR_VarArgs_Mismatch_Binary", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_varargs_string", f1_varargs_string_binary), + err: fmt.Errorf(errorOverloadSignature, "", "f1_varargs_string", binarySignature, functionSignature), + }, { + name: "ERROR_VarArgs_Mismatch_Nil", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_varargs_string", &functions.Overload{ + Operator: "f1_varargs_string", + OperandTrait: 0, + NonStrict: true, + }), + err: fmt.Errorf(errorOverloadSignature, "", "f1_varargs_string", "", functionSignature), + }, + + // unmatched attributes + { + name: "ERROR_Mismatch_OperandTrait", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_string_string", f1_string_operand_trait), + err: fmt.Errorf(errorOverloadMismatch, "f1_string_string", "OperandTrait", 2, 0), + }, { + name: "ERROR_Mismatch_NonStrict", + reference: referenceDispatcher(), + candidate: candidateActivation("f1_string_string", f1_string_non_strict), + err: fmt.Errorf(errorOverloadMismatch, "f1_string_string", "NonStrict", false, true), + }, + // NOTE: in this scenario a misconfiguration of the key of the original + // dispacher that is mapped to an overload with a different operator + // causes a signature error, when the same key has a match in the + // Activation implementation passed to ValidateOverloads + { + name: "ERROR_Misconfigured_Dispatcher", + reference: &defaultDispatcher{ + overloads: overloadMap{ + "f7_string_string": f1_string_string, + }, + }, + candidate: candidateActivation("f7_string_string", &functions.Overload{ + Operator: "f7_string_string", + OperandTrait: 0, + NonStrict: true, + Unary: func(_ ref.Val) ref.Val { + return types.String("") + }, + }), + err: fmt.Errorf(errorOverloadMismatch, "f7_string_string", "Operator", "f7_string_string", "f1_string_string"), + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + + activation := testCase.candidate() + err := ValidateOverloads(testCase.reference, activation) + + // if we expect an error + if testCase.err != nil { + if err == nil { + // it should not be nil + t.Errorf("outcome mismatch (got: , want: %s)", testCase.err.Error()) + } else if testCase.err.Error() != err.Error() { + + // it should match + t.Errorf("outcome mismatch (got: %s want: %s)", err.Error(), testCase.err.Error()) + } + } else if err != nil { + + t.Errorf("outcome mismatch (got: %s, want: )", err.Error()) + } + }) + } +} + +type dummyActivation struct { + Activation +} + +func (da *dummyActivation) ResolveName(name string) (any, bool) { + if da.Activation != nil { + return da.Activation.ResolveName(name) + } + return nil, false +} +func (da *dummyActivation) Parent() Activation { return da.Activation } + +// evalLateBindTestCase is a convenience structure used to test +// the different methods of evalLateBind. +type evalLateBindTestCase struct { + name string + target InterpretableCall + + // the attributes that follow are only required for the + // execution of tests for the Eval method. + // + injector OverloadInjector + activation Activation + // expect is used to implement post execution assertions + // for the Eval method. The signature could be more precise + // but in this way we can re-use pre-built expectation + // functions. + expect func(*testing.T, Interpretable, ref.Val) +} + +// TestEvalLateBind_ID verifies the implemented behaviour of evalLatebind.ID(). +// The expectation is for the function to return the unique identifier of the +// IntepretableCall wrapped by the structure. +func TestEvalLateBind_ID(t *testing.T) { + + testCases := testAllEvalTypes() + + for _, testCase := range testCases { + + t.Run(testCase.name, func(t *testing.T) { + + candidate := &evalLateBind{ + target: testCase.target, + } + actual := candidate.ID() + expected := candidate.target.ID() + + if actual != expected { + t.Errorf("ID() returned an unexpected value (got: %d, want: %d)", actual, expected) + } + }) + } +} + +// TestEvalLateBind_Function verifies the implemented behaviour of +// evalLatebind.Function(). The expectation is for the function to +// return the name of the function that has been configured with the +// IntepretableCall wrapped by the structure. +func TestEvalLateBind_Function(t *testing.T) { + testCases := testAllEvalTypes() + + for _, testCase := range testCases { + + t.Run(testCase.name, func(t *testing.T) { + + candidate := &evalLateBind{ + target: testCase.target, + } + actual := candidate.Function() + expected := candidate.target.Function() + + if actual != expected { + t.Errorf("Function() returned an unexpected value (got: %s, want: %ss)", actual, expected) + } + }) + } +} + +// TestEvalLateBind_OverloadID verifies the implemented behaviour of +// evalLatebind.OverloadID(). The expectation is for the function to +// return the unique identifier of the overload configured with the +// IntepretableCall wrapped by the structure. +func TestEvalLateBind_OverloadID(t *testing.T) { + testCases := testAllEvalTypes() + + for _, testCase := range testCases { + + t.Run(testCase.name, func(t *testing.T) { + + candidate := &evalLateBind{ + target: testCase.target, + } + actual := candidate.OverloadID() + expected := candidate.target.OverloadID() + + if actual != expected { + t.Errorf("OverloadID() returned an unexpected value (got: %s, want: %s)", actual, expected) + } + }) + } +} + +// TestEvalLateBind_Args verifies the implemented behaviour of +// evalLatebind.Args(). The expectation is for the function to +// return the slice of Interpretable implementation that are +// the arguments resolved for the function call configured with +// IntepretableCall wrapped by the structure. +func TestEvalLateBind_Args(t *testing.T) { + testCases := testAllEvalTypes() + + for _, testCase := range testCases { + + t.Run(testCase.name, func(t *testing.T) { + + candidate := &evalLateBind{ + target: testCase.target, + } + actual := candidate.Args() + expected := candidate.target.Args() + + if len(actual) != len(expected) { + t.Errorf("Args() returned an array of different size (got: %d, want: %d)", len(actual), len(expected)) + } + for index, e := range expected { + a := actual[index] + if a != e { + t.Errorf("Args() returned an unexpected value for index '%d' (got: %v, want: %v)", index, actual, expected) + } + } + }) + } +} + +// TestEvalLateBind_Eval verifies the implemented behaviour of evalLateBind.Eval(Activation). +// The expectation is for the method to execute runtime dispatching of the function overload +// configured with the wrapped InterpretableCall by looking up the Activation passed to the +// function. If a matching overload is found, the implementation should replicate the wrapped +// intepretable and reconfigure it with the new overload implementation prior to executing +// the evaluation. If there is any error in the injection of the overload the function will +// return such error. +func TestEvalLateBind_Eval(t *testing.T) { + + // evalZeroCall returns an evalZeroArity reference that + // is configured with a function that returns 0. + evalZeroCall := func() *evalZeroArity { + return &evalZeroArity{ + id: 50, + function: "f1", + overload: "f1_int", + impl: func(_ ...ref.Val) ref.Val { + return types.Int(0) + }, + } + } + + // evalUnaryCall returns an evalUnary reference that + // is configured with a function that computes to true + // when applied to the argument supplied. + evalUnaryCall := func() *evalUnary { + return &evalUnary{ + id: 51, + function: "f2", + overload: "f2_string_bool", + impl: func(arg ref.Val) ref.Val { + + text := arg.(types.String) + return text.Equal(types.String("hello")) + }, + arg: NewConstValue(52, types.String("hello")), + nonStrict: false, + trait: 0, + } + } + + // evalBinaryCall returns an evalBinary reference that + // is configured with a function that computes to 15 + // when applied to the arguments supplied (13, 2). + evalBinaryCall := func() *evalBinary { + return &evalBinary{ + id: 53, + function: "f3", + overload: "f3_int_int_int", + impl: func(lhs ref.Val, rhs ref.Val) ref.Val { + + l := lhs.(types.Int) + r := rhs.(types.Int) + return l + r + }, + lhs: NewConstValue(54, types.Int(13)), + rhs: NewConstValue(55, types.Int(2)), + nonStrict: false, + trait: 0, + } + } + + // evalVarArgsCall returns an evalVarArgs reference that + // is configured with a function that computes the string + // `this is fun` when applied to the arguments supplied + // (`this`, `is`, `fun`). + evalVarArgsCall := func() *evalVarArgs { + + return &evalVarArgs{ + + id: 55, + function: "f4", + overload: "f4_string_string_string_string", + impl: func(args ...ref.Val) ref.Val { + + parts := make([]string, len(args)) + for i, arg := range args { + parts[i] = arg.(types.String).Value().(string) + } + + return types.String(strings.Join(parts, " ")) + }, + args: []Interpretable{ + NewConstValue(56, types.String("this")), + NewConstValue(56, types.String("is")), + NewConstValue(56, types.String("fun")), + }, + nonStrict: false, + trait: 0, + } + } + + // activation creates a lateBindActivation and configures it with the + // given activation and overload functions. + activation := func(vars Activation, ovls ...*functions.Overload) Activation { + d := &defaultDispatcher{ + parent: nil, + overloads: overloadMap{}, + } + for _, ovl := range ovls { + d.overloads[ovl.Operator] = ovl + } + + return &lateBindActivation{vars, d} + } + + // expectOriginal generates an expectation function that checks that the wrapped + // target has remained unchanged. It does so by invoking the Eval method on an + // empty activation and compares the result with the supplied original value. + expectOriginal := func(original ref.Val) func(t *testing.T, target Interpretable, _ ref.Val) { + + return func(t *testing.T, target Interpretable, _ ref.Val) { + + actual := target.Eval(&emptyActivation{}) + if actual != original { + t.Errorf("target.Eval(Activation) returned unexpected value (got: %v, want: %v)", actual, original) + } + } + } + + testCases := []evalLateBindTestCase{ + + // Test Case Group 0 - No Overrides + // --------------------------------------------- + // The expectation is that the evaluation returns + // the same result as the result computed with the + // function statically configured with the eval + // struct. + + // Test Case 01 - evalZeroArity + // ------------------------------------------------ + // result: 0 + // execution: f1_int (pre-configured) + { + name: "OK_evalZeroArity_No_Overrides", + target: evalZeroCall(), + injector: injectZeroArity, + activation: &emptyActivation{}, + expect: expectValue(types.Int(0)), + }, + + // Test Case 02 - evalUnary + // ------------------------------------------------ + // result: true + // execution: f2_string_bool (pre-configured) + { + name: "OK_evalUnaryy_No_Overrides", + target: evalUnaryCall(), + injector: injectUnary, + activation: activation( + &emptyActivation{}, + function("f2_bool", 0, false, func(_ ...ref.Val) ref.Val { return types.False }), + ), + expect: expectValue(types.True), + }, + + // Test Case 03 - evalBinary + // ------------------------------------------------ + // result: 15 + // execution: f3_int_int_int (pre-configured) + { + name: "OK_evalBinary_No_Overrides", + target: evalBinaryCall(), + injector: injectBinary, + activation: activation( + &emptyActivation{}, + function("f7_int_int_int", 0, false, func(_ ...ref.Val) ref.Val { return types.Int(12) }), + ), + expect: expectValue(types.Int(15)), + }, + + // Test Case 04 - evalVarArgs + // ------------------------------------------------ + // result: this is fun + // execution: f4_string_string_string_string (pre-configured) + { + name: "OK_evalVarArgs_No_Overrides", + target: evalVarArgsCall(), + injector: injectVarArgs, + activation: &mapActivation{ + bindings: map[string]any{ + "a": 10, + "f": true, + }, + }, + expect: expectValue(types.String("this is fun")), + }, + + // Test Case Group 1 - Overrides (Hppy Path) + // ----------------------------------------- + // The expectation is that the result of the + // evaluation returns the same result of the + // function that is supplied at runtime with + // Activation implementation, but the original + // struct remains unchanged. + + // Test Case 11 - evalZeroArity + // ------------------------------------------------ + // result: 2 + // execution: f1_int (runtime override) + { + name: "OK_evalZeroArity_With_Overrides", + target: evalZeroCall(), + injector: injectZeroArity, + activation: activation( + &emptyActivation{}, + function("f1_int", 0, false, func(_ ...ref.Val) ref.Val { return types.Int(2) }), + ), + expect: chain( + // this is the result from the override + expectValue(types.Int(2)), + // this is the original result that we + // should stil retain in the node. + expectOriginal(types.Int(0)), + ), + }, + + // Test Case 12 - evalUnary + // ------------------------------------------------ + // result: false + // execution: f2_string_bool (runtime override) + { + name: "OK_evalUnary_With_Overrides", + target: evalUnaryCall(), + injector: injectUnary, + activation: activation( + &emptyActivation{}, + unary("f2_string_bool", 0, false, func(arg ref.Val) ref.Val { + text := arg.Value().(string) + return types.Bool(len(text) == 10) + }), + ), + expect: chain( + // this is the result from the override + expectValue(types.False), + // this is the original result that we + // should stil retain in the node. + expectOriginal(types.True), + ), + }, + + // Test Case 13 - evalBinary + // ------------------------------------------------ + // result: 2 + // execution: f3_int_int_int (runtime override) + { + name: "OK_evalBinary_With_Overrides", + target: evalBinaryCall(), + injector: injectBinary, + activation: activation( + &emptyActivation{}, + binary("f3_int_int_int", 0, false, func(lhs ref.Val, rhs ref.Val) ref.Val { return types.Int(2) }), + ), + expect: chain( + // this is the result from the override + expectValue(types.Int(2)), + // this is the original result that we + // should stil retain in the node. + expectOriginal(types.Int(15)), + ), + }, + + // Test Case 14 - evalVarArgs + // ------------------------------------------------ + // result: fun this is + // execution: f4_string_string_string_string (runtime override) + { + name: "OK_evalVarArgs_With_Overrides", + target: evalVarArgsCall(), + injector: injectVarArgs, + activation: activation( + &emptyActivation{}, + function("f4_string_string_string_string", 0, false, func(args ...ref.Val) ref.Val { + + max := len(args) + parts := make([]string, max) + for i := max; i > 0; i-- { + parts[max-i] = args[i-1].Value().(string) + } + return types.String(strings.Join(parts, " ")) + }), + ), + expect: chain( + // this is the result from the override + expectValue(types.String("fun is this")), + // this is the original result that we + // should stil retain in the node. + expectOriginal(types.String("this is fun")), + ), + }, + + // Test Case Group 2 - Overrides With Errors + // ----------------------------------------- + // The expectation is that an error is returned + // while trying to inject the overload override + // resolved at runtime because of signature + // mismatch. + + // Test Case 21 - evalZeroArity + // ------------------------------------------------ + // result: + // execution: f1_int (none) + { + name: "ERROR_evalZeroArity_With_Overrides", + target: evalZeroCall(), + injector: injectZeroArity, + activation: activation( + &emptyActivation{}, + unary("f1_int", 0, false, func(arg ref.Val) ref.Val { return types.Int(2) }), + ), + expect: chain( + // this is the result from the override + expectValue( + types.NewErrWithNodeID( + 51, + errorOverloadInjection, + OverloadSignatureError("f1", "f1_int", "", functionSignature), + ), + ), + // this is the original result that we + // should stil retain in the node. + expectOriginal(types.Int(0)), + ), + }, + + // Test Case 22 - evalUnary + // ------------------------------------------------ + // result: + // execution: f2_string_bool (none) + { + name: "ERROR_evalUnary_With_Overrides", + target: evalUnaryCall(), + injector: injectUnary, + activation: activation( + &emptyActivation{}, + function("f2_string_bool", 0, false, func(args ...ref.Val) ref.Val { + return types.False + }), + ), + expect: chain( + // this is the result from the override + expectValue( + types.NewErrWithNodeID( + 51, + errorOverloadInjection, + OverloadSignatureError("f2", "f2_string_bool", "", unarySignature), + ), + ), + // this is the original result that we + // should stil retain in the node. + expectOriginal(types.True), + ), + }, + + // Test Case 23 - evalBinary + // ------------------------------------------------ + // result: + // execution: f3_int_int_int (none) + { + name: "ERROR_evalBinary_With_Overrides", + target: evalBinaryCall(), + injector: injectBinary, + activation: activation( + &emptyActivation{}, + unary("f3_int_int_int", 0, false, func(arg ref.Val) ref.Val { return types.Int(2) }), + ), + expect: chain( + // this is the result from the override + expectValue( + types.NewErrWithNodeID( + 53, + errorOverloadInjection, + OverloadSignatureError("f3", "f3_int_int_int", "", binarySignature), + ), + ), + // this is the original result that we + // should stil retain in the node. + expectOriginal(types.Int(15)), + ), + }, + + // Test Case 24 - evalVarArgs + // ------------------------------------------------ + // result: + // execution: f4_string_string_string_string (none) + { + name: "ERROR_evalVarArgs_With_Overrides", + target: evalVarArgsCall(), + injector: injectVarArgs, + activation: activation( + &emptyActivation{}, + binary("f4_string_string_string_string", 0, false, func(lhs ref.Val, rhs ref.Val) ref.Val { return types.String("") }), + ), + expect: chain( + // this is the result from the override + expectValue( + types.NewErrWithNodeID( + 55, + errorOverloadInjection, + OverloadSignatureError("f4", "f4_string_string_string_string", "", functionSignature), + ), + ), + // this is the original result that we + // should stil retain in the node. + expectOriginal(types.String("this is fun")), + ), + }, + } + + for _, testCase := range testCases { + + t.Run(testCase.name, func(t *testing.T) { + + candidate := &evalLateBind{ + target: testCase.target, + injectOverload: testCase.injector, + } + + actual := candidate.Eval(testCase.activation) + testCase.expect(t, candidate.target, actual) + }) + } +} + +// TestInjector verifies the implemented behaviour of Injector. The expectation is for the +// function to generate a LateBindCallOption that maps the supplied OverloadInjector to the +// key resolved by the type associated to the given IntepretableCall implementation. +func TestInjector(t *testing.T) { + + // control flag to ensure that the supplied method + // is actually the one configured. + isMatchingInjector := false + + // expectation is a function that produces an OverloadInjector that sets the above flag if + // invoked. This is used as a control mechanism to ensure that we are actually invoking this + // function. We are not really interested in the implementation of the injector, rather that + // if we supply one that is the one used. + expectation := func(match *bool) func(target InterpretableCall, overload *functions.Overload, _ LateBindFlags) (InterpretableCall, error) { + + return func(target InterpretableCall, overload *functions.Overload, _ LateBindFlags) (InterpretableCall, error) { + + *match = true + + return nil, nil + } + } + + expected := expectation(&isMatchingInjector) + candidate := Injector(&evalUnary{}, expected) + + // check 1 - the LateBindCallOption should not be nil + if candidate == nil { + t.Fatal("Injector should return a non-nil LateBindCallOption") + } + + // check 2 - the Injector when invoked should modify the + // injector map with a new entry (if not there + // already). + config := &lateBindConfig{ + injectors: map[reflect.Type]OverloadInjector{}, + } + modified := candidate(config) + if len(modified.injectors) != 1 { + t.Fatalf("Injector did not add the supplied injector, injector map is empty (got: %d, want: %d)", len(modified.injectors), 1) + } + + // check 3 - the Injector should map the supplied injector + // to the key resolved by the supplied type. + key := reflect.TypeOf(&evalUnary{}) + actual, found := modified.injectors[key] + if !found { + t.Fatalf("Injector did not add the supplied injector for key '%s'", key) + } + + // check 4 - the Injector should configured he supplied injector + // and not a random method. + target := &evalUnary{ + id: 30, + arg: NewConstValue(31, types.String("hello")), + function: "f1", + overload: "f1_string_int", + impl: func(arg ref.Val) ref.Val { + return arg.(types.String).Size() + }, + trait: 0, + nonStrict: false, + } + + overload := &functions.Overload{ + Operator: "f1_string_int", + OperandTrait: 0, + NonStrict: false, + Unary: func(arg ref.Val) ref.Val { + return arg.(types.String).Size().(types.Int).Multiply(types.Int(2)) + }, + } + + actual(target, overload, LateBindFlagsNone) + if !isMatchingInjector { + t.Errorf("Injector did not configured the supplied OverloadInjector for key '%s'", key) + } + +} + +// testAllEvalTypes produces an array of test cases for the +// purpose of testing evalLateBind methods across all the +// known wrapped types. +func testAllEvalTypes() []evalLateBindTestCase { + + return []evalLateBindTestCase{ + { + name: "evalZeroArity", + target: &evalZeroArity{ + id: 45, + impl: func(_ ...ref.Val) ref.Val { + return types.Int(0) + }, + function: "f1", + overload: "f1_int", + }, + }, { + name: "evalUnary", + target: &evalUnary{ + id: 46, + arg: NewConstValue(47, types.Int(2)), + impl: func(_ ref.Val) ref.Val { + return types.Int(0) + }, + function: "f1", + overload: "f1_int_int", + nonStrict: true, + trait: 0, + }, + }, { + name: "evalBinary", + target: &evalBinary{ + id: 48, + lhs: NewConstValue(49, types.Int(2)), + rhs: NewConstValue(50, types.Int(5)), + impl: func(_ ref.Val, _ ref.Val) ref.Val { + return types.Int(3) + }, + function: "f1", + overload: "f1_int_int_int", + nonStrict: true, + trait: 0, + }, + }, { + name: "evalVarArgs", + target: &evalVarArgs{ + id: 51, + args: []Interpretable{ + NewConstValue(52, types.Int(2)), + NewConstValue(53, types.Int(5)), + NewConstValue(54, types.Int(5)), + }, + impl: func(_ ...ref.Val) ref.Val { + return types.Int(3) + }, + function: "f1", + overload: "f1_int_int_int_int", + nonStrict: true, + trait: 0, + }, + }, + } +} + +// unary is a convenience function to produce a reference to functions.Overload configured with +// the given parameters. This function sets the Overload.Unary to the given function. +func unary(operator string, operandTrait int, nonStrict bool, function functions.UnaryOp) *functions.Overload { + + return &functions.Overload{ + Operator: operator, + OperandTrait: operandTrait, + NonStrict: nonStrict, + Unary: function, + } +} + +// binary is a convenience function to produce a reference to functions.Overload configured with +// the given parameters. This function sets the Overload.Binary to the given function. +func binary(operator string, operandTrait int, nonStrict bool, function functions.BinaryOp) *functions.Overload { + + return &functions.Overload{ + Operator: operator, + OperandTrait: operandTrait, + NonStrict: nonStrict, + Binary: function, + } +} + +// function is a convenience function to produce a reference to functions.Overload configured with +// the given parameters. This function sets the Overload.Function to the given function. +func function(operator string, operandTrait int, nonStrict bool, function functions.FunctionOp) *functions.Overload { + + return &functions.Overload{ + Operator: operator, + OperandTrait: operandTrait, + NonStrict: nonStrict, + Function: function, + } +} + +// expectValue is a convenience function that produces an expectation function that checks +// that the value returned by the execution of Interpretable.Eval(Activation) matches the +// given expected value otherwise it fails the test. +func expectValue(expected ref.Val) func(t *testing.T, target Interpretable, actual ref.Val) { + + return func(t *testing.T, _ Interpretable, actual ref.Val) { + + t.Helper() + + if expected.Equal(actual) == types.False { + t.Errorf("unexpected value (got: %v, want: %v)", actual, expected) + } + } +} + +// expectType generates an expectation function that is used to validate that the name +// of Interpretable expression is the same of the name of the given type. +func expectType(expected Interpretable) func(t *testing.T, target Interpretable, actual ref.Val) { + + return func(t *testing.T, target Interpretable, _ ref.Val) { + + t.Helper() + + expectedType := reflect.TypeOf(expected).Name() + actualType := reflect.TypeOf(target).Name() + + if expectedType != actualType { + t.Errorf("unexpected type: (got: %s, want: %s)", actualType, expectedType) + } + } +} + +// chain is a convenience function that can be used to run in sequence multiple expectation +// functions that are passed as argument. The function returns an expectation function that +// executes all the functions in the checks array, with the actual arguments passed to the +// function. +func chain(checks ...func(t *testing.T, target Interpretable, actual ref.Val)) func(t *testing.T, target Interpretable, actual ref.Val) { + + return func(t *testing.T, target Interpretable, actual ref.Val) { + + for _, check := range checks { + check(t, target, actual) + } + } +}