Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions cel/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,38 @@ func Function(name string, opts ...FunctionOpt) EnvOption {
}
}

// OverloadSelector is a predicate over a function's overloads: an overload is selected
// (kept) when the selector returns true.
//
// Selectors are consumed by the FunctionDecl.Subset method.
type OverloadSelector = decls.OverloadSelector

// IncludeOverloads returns an OverloadSelector implementing an allow-list: only the
// overloads whose ids appear in overloadIDs are selected.
func IncludeOverloads(overloadIDs ...string) OverloadSelector {
	allowList := decls.IncludeOverloads(overloadIDs...)
	return allowList
}

// ExcludeOverloads returns an OverloadSelector implementing a deny-list: every overload
// is selected except those whose ids appear in overloadIDs.
func ExcludeOverloads(overloadIDs ...string) OverloadSelector {
	denyList := decls.ExcludeOverloads(overloadIDs...)
	return denyList
}

// FunctionDecls provides one or more fully formed function declarations to be added to the environment.
//
// When a declaration's name is already present in the environment, the existing declaration is
// merged with the new one via FunctionDecl.Merge; if Merge returns an error, the option fails
// with that error. Nil declarations (for example, the result of calling Subset on a nil
// FunctionDecl) are ignored rather than stored.
func FunctionDecls(funcs ...*decls.FunctionDecl) EnvOption {
	return func(e *Env) (*Env, error) {
		for _, fn := range funcs {
			if fn == nil {
				// Never place a nil declaration into the environment's function table.
				continue
			}
			if existing, found := e.functions[fn.Name()]; found {
				merged, err := existing.Merge(fn)
				if err != nil {
					return nil, err
				}
				fn = merged
			}
			e.functions[fn.Name()] = fn
		}
		return e, nil
	}
}

// FunctionOpt defines a functional option for configuring a function declaration.
type FunctionOpt = decls.FunctionOpt

Expand Down
182 changes: 182 additions & 0 deletions cel/decls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/google/cel-go/common/functions"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/overloads"
"github.com/google/cel-go/common/stdlib"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/google/cel-go/common/types/traits"
Expand Down Expand Up @@ -780,6 +781,187 @@ func TestExprDeclToDeclarationInvalid(t *testing.T) {
}
}

func TestFunctionDeclExcludeOverloads(t *testing.T) {
funcs := []*decls.FunctionDecl{}
for _, fn := range stdlib.Functions() {
if fn.Name() == operators.Add {
fn = fn.Subset(ExcludeOverloads(overloads.AddList, overloads.AddBytes, overloads.AddString))
}
funcs = append(funcs, fn)
}
env, err := NewCustomEnv(FunctionDecls(funcs...))
if err != nil {
t.Fatalf("NewCustomEnv() failed: %v", err)
}

successTests := []struct {
name string
expr string
want ref.Val
}{
{
name: "ints",
expr: "1 + 1",
want: types.Int(2),
},
{
name: "doubles",
expr: "1.5 + 1.5",
want: types.Double(3.0),
},
{
name: "uints",
expr: "1u + 2u",
want: types.Uint(3),
},
{
name: "timestamp plus duration",
expr: "timestamp('2001-01-01T00:00:00Z') + duration('1h') == timestamp('2001-01-01T01:00:00Z')",
want: types.True,
},
{
name: "durations",
expr: "duration('1h') + duration('1m') == duration('1h1m')",
want: types.True,
},
}
for _, tst := range successTests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(NoVars())
if err != nil {
t.Fatalf("prg.Eval() errored: %v", err)
}
if out.Equal(tc.want) != types.True {
t.Errorf("Eval() got %v, wanted %v", out, tc.want)
}
})
}
failureTests := []struct {
name string
expr string
}{
{
name: "strings",
expr: "'a' + 'b'",
},
{
name: "bytes",
expr: "b'123' + b'456'",
},
{
name: "lists",
expr: "[1] + [2, 3]",
},
}
for _, tst := range failureTests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
_, iss := env.Compile(tc.expr)
if iss.Err() == nil {
t.Error("env.Compile() got ast, wanted error")
}
})
}
}

func TestFunctionDeclIncludeOverloads(t *testing.T) {
funcs := []*decls.FunctionDecl{}
for _, fn := range stdlib.Functions() {
if fn.Name() == operators.Add {
fn = fn.Subset(IncludeOverloads(overloads.AddInt64, overloads.AddDouble))
}
funcs = append(funcs, fn)
}
env, err := NewCustomEnv(FunctionDecls(funcs...))
if err != nil {
t.Fatalf("NewCustomEnv() failed: %v", err)
}

successTests := []struct {
name string
expr string
want ref.Val
}{
{
name: "ints",
expr: "1 + 1",
want: types.Int(2),
},
{
name: "doubles",
expr: "1.5 + 1.5",
want: types.Double(3.0),
},
}
for _, tst := range successTests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(NoVars())
if err != nil {
t.Fatalf("prg.Eval() errored: %v", err)
}
if out.Equal(tc.want) != types.True {
t.Errorf("Eval() got %v, wanted %v", out, tc.want)
}
})
}
failureTests := []struct {
name string
expr string
}{
{
name: "strings",
expr: "'a' + 'b'",
},
{
name: "bytes",
expr: "b'123' + b'456'",
},
{
name: "lists",
expr: "[1] + [2, 3]",
},
{
name: "uints",
expr: "1u + 2u",
},
{
name: "timestamp plus duration",
expr: "timestamp('2001-01-01T00:00:00Z') + duration('1h') == timestamp('2001-01-01T01:00:00Z')",
},
{
name: "durations",
expr: "duration('1h') + duration('1m') == duration('1h1m')",
},
}
for _, tst := range failureTests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
_, iss := env.Compile(tc.expr)
if iss.Err() == nil {
t.Error("env.Compile() got ast, wanted error")
}
})
}
}

func testParse(t testing.TB, env *Env, expr string, want any) {
t.Helper()
ast, iss := env.Parse(expr)
Expand Down
54 changes: 54 additions & 0 deletions common/decls/decls.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,60 @@ func (f *FunctionDecl) Merge(other *FunctionDecl) (*FunctionDecl, error) {
return merged, nil
}

// OverloadSelector selects an overload associated with a given function when it returns true.
//
// Used in combination with the Subset method.
type OverloadSelector func(overload *OverloadDecl) bool

// IncludeOverloads defines an OverloadSelector which allow-lists a set of overloads by their ids.
//
// Only overloads whose id appears in overloadIDs are selected, and a nil overload is never
// selected. The ids are copied when the selector is built, so mutating a slice that was
// passed as overloadIDs... afterwards does not change the selector's behavior.
func IncludeOverloads(overloadIDs ...string) OverloadSelector {
	ids := overloadIDSet(overloadIDs)
	return func(overload *OverloadDecl) bool {
		if overload == nil {
			return false
		}
		_, listed := ids[overload.id]
		return listed
	}
}

// ExcludeOverloads defines an OverloadSelector which deny-lists a set of overloads by their ids.
//
// Every non-nil overload is selected unless its id appears in overloadIDs; a nil overload is
// never selected. As with IncludeOverloads, the ids are copied when the selector is built.
func ExcludeOverloads(overloadIDs ...string) OverloadSelector {
	ids := overloadIDSet(overloadIDs)
	return func(overload *OverloadDecl) bool {
		if overload == nil {
			return false
		}
		_, listed := ids[overload.id]
		return !listed
	}
}

// overloadIDSet snapshots overloadIDs into a set, detaching it from the caller's slice and
// giving constant-time membership checks.
func overloadIDSet(overloadIDs []string) map[string]struct{} {
	set := make(map[string]struct{}, len(overloadIDs))
	for _, id := range overloadIDs {
		set[id] = struct{}{}
	}
	return set
}

// Subset returns a new function declaration which contains only the overloads for which
// selector returns true, in the same order as they were declared on f.
//
// The name, singleton, disableTypeGuards and state fields are copied from f unchanged; the
// selected *OverloadDecl values are shared with f rather than copied. f itself is not modified.
// A nil receiver yields nil, and a nil selector selects every overload.
func (f *FunctionDecl) Subset(selector OverloadSelector) *FunctionDecl {
	if f == nil {
		return nil
	}
	if selector == nil {
		// Treat a missing selector as "keep everything" instead of panicking on the call below.
		selector = func(*OverloadDecl) bool { return true }
	}
	overloads := make(map[string]*OverloadDecl, len(f.overloadOrdinals))
	overloadOrdinals := make([]string, 0, len(f.overloadOrdinals))
	// Walk the ordinals rather than the map so the subset keeps declaration order.
	for _, oID := range f.overloadOrdinals {
		overload := f.overloads[oID]
		if selector(overload) {
			overloads[oID] = overload
			overloadOrdinals = append(overloadOrdinals, oID)
		}
	}
	return &FunctionDecl{
		name:              f.Name(),
		overloads:         overloads,
		singleton:         f.singleton,
		disableTypeGuards: f.disableTypeGuards,
		state:             f.state,
		overloadOrdinals:  overloadOrdinals,
	}
}

// AddOverload ensures that the new overload does not collide with an existing overload signature;
// however, if the function signatures are identical, the implementation may be rewritten as it's
// difficult to compare functions by object identity.
Expand Down