Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,25 @@ func (e *Env) ToConfig(name string) (*env.Config, error) {
}
}

// Serialize validators
for _, val := range e.Validators() {
// Only add configurable validators to the env.Config as all others are
// expected to be implicitly enabled via extension libraries.
if confVal, ok := val.(ConfigurableASTValidator); ok {
conf.AddValidators(confVal.ToConfig())
}
}

// Serialize features
for featID, enabled := range e.features {
featName, found := featureNameByID(featID)
if !found {
// If the feature isn't named, it isn't intended to be publicly exposed
continue
}
conf.AddFeatures(env.NewFeature(featName, enabled))
}

return conf, nil
}

Expand Down Expand Up @@ -541,7 +560,7 @@ func (e *Env) Functions() map[string]*decls.FunctionDecl {

// Variables returns the set of variables associated with the environment.
//
// The returned slice is a copy: callers may reorder, overwrite, or append to
// it without affecting the Env. A plain re-slice such as e.variables[:] would
// share the backing array, so writes through the result would leak back into
// the environment's declarations.
func (e *Env) Variables() []*decls.VariableDecl {
	// The full slice expression [:0:0] yields a zero-capacity slice, which
	// forces append to allocate fresh storage; a nil input stays nil.
	return append(e.variables[:0:0], e.variables...)
}

// HasValidator returns whether a specific ASTValidator has been configured in the environment.
Expand All @@ -554,6 +573,11 @@ func (e *Env) HasValidator(name string) bool {
return false
}

// Validators returns the set of ASTValidators configured on the environment.
//
// The returned slice is a copy, so callers may modify or extend it without
// changing the validators the Env actually runs. A plain re-slice
// (e.validators[:]) would alias the Env's internal backing array.
func (e *Env) Validators() []ASTValidator {
	// Clone via a zero-capacity full slice expression so append must allocate;
	// a nil validator list is returned as nil.
	return append(e.validators[:0:0], e.validators...)
}

// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
//
// This form of Parse creates a Source value for the input `txt` and forwards to the
Expand Down
215 changes: 202 additions & 13 deletions cel/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"testing"

"github.com/google/cel-go/common"
"github.com/google/cel-go/common/ast"
"github.com/google/cel-go/common/decls"
"github.com/google/cel-go/common/env"
"github.com/google/cel-go/common/operators"
"github.com/google/cel-go/common/types"
Expand Down Expand Up @@ -401,6 +403,30 @@ func TestEnvToConfig(t *testing.T) {
},
want: env.NewConfig("context proto").SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")),
},
{
name: "feature flags",
opts: []EnvOption{
DefaultUTCTimeZone(false),
EnableMacroCallTracking(),
},
want: env.NewConfig("feature flags").AddFeatures(
env.NewFeature("cel.feature.macro_call_tracking", true),
),
},
{
name: "validators",
opts: []EnvOption{
ExtendedValidations(),
ASTValidators(ValidateComprehensionNestingLimit(1)),
},
want: env.NewConfig("validators").AddValidators(
env.NewValidator("cel.validator.duration"),
env.NewValidator("cel.validator.timestamp"),
env.NewValidator("cel.validator.matches"),
env.NewValidator("cel.validator.homogeneous_literals"),
env.NewValidator("cel.validator.comprehension_nesting_limit").SetConfig(map[string]any{"limit": 1}),
),
},
}

for _, tst := range tests {
Expand Down Expand Up @@ -430,11 +456,12 @@ func TestEnvFromConfig(t *testing.T) {
out ref.Val
}
tests := []struct {
name string
beforeOpts []EnvOption
afterOpts []EnvOption
conf *env.Config
exprs []exprCase
name string
beforeOpts []EnvOption
afterOpts []EnvOption
conf *env.Config
confHandlers []ConfigOptionFactory
exprs []exprCase
}{
{
name: "std env",
Expand Down Expand Up @@ -617,18 +644,138 @@ func TestEnvFromConfig(t *testing.T) {
},
},
},
{
name: "extensions - config factory",
conf: env.NewConfig("extensions").
AddExtensions(env.NewExtension("plus", math.MaxUint32)),
confHandlers: []ConfigOptionFactory{
func(a any) (EnvOption, bool) {
ext, ok := a.(*env.Extension)
if !ok || ext.Name != "plus" {
return nil, false
}
return Function("plus", Overload("plus_int_int", []*Type{IntType, IntType}, IntType,
decls.BinaryBinding(func(lhs, rhs ref.Val) ref.Val {
l := lhs.(types.Int)
r := rhs.(types.Int)
return l + r
}))), true
},
},
exprs: []exprCase{
{
name: "plus",
expr: "plus(1, 2)",
out: types.Int(3),
},
},
},
{
name: "features",
conf: env.NewConfig("features").
AddVariables(
env.NewVariable("m",
env.NewTypeDesc("map", env.NewTypeDesc("string"), env.NewTypeDesc("string")))).
AddFeatures(
env.NewFeature("cel.feature.backtick_escape_syntax", true),
env.NewFeature("cel.feature.unknown_feature_name", true)),
exprs: []exprCase{
{
name: "optional key",
expr: "m.`key-name` == 'value'",
in: map[string]any{"m": map[string]string{"key-name": "value"}},
out: types.True,
},
},
},
{
name: "validators",
conf: env.NewConfig("validators").
AddVariables(
env.NewVariable("m",
env.NewTypeDesc("map", env.NewTypeDesc("string"), env.NewTypeDesc("string"))),
).
AddValidators(
env.NewValidator(durationValidatorName),
env.NewValidator(timestampValidatorName),
env.NewValidator(regexValidatorName),
env.NewValidator(homogeneousValidatorName),
env.NewValidator(nestingLimitValidatorName).SetConfig(map[string]any{"limit": 0}),
),
exprs: []exprCase{
{
name: "bad duration",
expr: "duration('1')",
iss: errors.New("invalid duration"),
},
{
name: "bad timestamp",
expr: "timestamp('1')",
iss: errors.New("invalid timestamp"),
},
{
name: "bad regex",
expr: "'hello'.matches('?^()')",
iss: errors.New("invalid matches"),
},
{
name: "mixed type list",
expr: "[1, 2.0]",
iss: errors.New("expected type 'int'"),
},
{
name: "disabled comprehension",
expr: "[1, 2].exists(x, x % 2 == 0)",
iss: errors.New("comprehension exceeds nesting limit"),
},
},
},
{
name: "validators - config factory",
conf: env.NewConfig("validators").
AddValidators(
env.NewValidator("cel.validators.return_type").SetConfig(map[string]any{"type_name": "string"}),
),
confHandlers: []ConfigOptionFactory{
func(a any) (EnvOption, bool) {
val, ok := a.(*env.Validator)
if !ok || val.Name != "cel.validators.return_type" {
return nil, false
}
typeName, found := val.ConfigValue("type_name")
if !found {
return func(*Env) (*Env, error) {
return nil, fmt.Errorf("invalid validator: %s missing config parameter 'type_name'", val.Name)
}, true
}
return func(e *Env) (*Env, error) {
t, err := env.NewTypeDesc(typeName.(string)).AsCELType(e.CELTypeProvider())
if err != nil {
return nil, err
}
return ASTValidators(returnTypeValidator{returnType: t})(e)
}, true
},
},
exprs: []exprCase{
{
name: "string - ok",
expr: "'hello'",
out: types.String("hello"),
},
{
name: "int - error",
expr: "1",
iss: errors.New("unsupported return type: int, want string"),
},
},
},
}
for _, tst := range tests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
opts := tc.beforeOpts
opts = append(opts, FromConfig(tc.conf, func(elem any) (EnvOption, bool) {
if ext, ok := elem.(*env.Extension); ok && ext.Name == "optional" {
ver, _ := ext.GetVersion()
return OptionalTypes(OptionalTypesVersion(ver)), true
}
return nil, false
}))
opts = append(opts, FromConfig(tc.conf, tc.confHandlers...))
opts = append(opts, tc.afterOpts...)
var e *Env
var err error
Expand Down Expand Up @@ -679,6 +826,16 @@ func TestEnvFromConfigErrors(t *testing.T) {
conf *env.Config
want error
}{
{
name: "bad container",
conf: env.NewConfig("bad container").SetContainer(".hello.world"),
want: errors.New("container name must not contain"),
},
{
name: "colliding imports",
conf: env.NewConfig("colliding imports").AddImports(env.NewImport("pkg.ImportName"), env.NewImport("pkg2.ImportName")),
want: errors.New("abbreviation collides"),
},
{
name: "invalid subset",
conf: env.NewConfig("invalid subset").SetStdLib(env.NewLibrarySubset().SetDisableMacros(true)),
Expand Down Expand Up @@ -707,9 +864,21 @@ func TestEnvFromConfigErrors(t *testing.T) {
{
name: "unrecognized extension",
conf: env.NewConfig("unrecognized extension").
AddExtensions(env.NewExtension("optional", math.MaxUint32)),
AddExtensions(env.NewExtension("unrecognized", math.MaxUint32)),
want: errors.New("unrecognized extension"),
},
{
name: "invalid validator config",
conf: env.NewConfig("invalid validator config").
AddValidators(env.NewValidator("cel.validator.comprehension_nesting_limit")),
want: errors.New("invalid validator"),
},
{
name: "invalid validator config type",
conf: env.NewConfig("invalid validator config").
AddValidators(env.NewValidator("cel.validator.comprehension_nesting_limit").SetConfig(map[string]any{"limit": 2.0})),
want: errors.New("invalid validator"),
},
}
for _, tst := range tests {
tc := tst
Expand Down Expand Up @@ -829,6 +998,26 @@ func mustContextProto(t *testing.T, pb proto.Message) Activation {
return ctx
}

// returnTypeValidator is a test ASTValidator that reports an error when the
// checked type of an expression's root node does not match returnType.
type returnTypeValidator struct {
	// returnType is the CEL type the expression is required to produce. In
	// TestEnvFromConfig it is resolved from the validator's "type_name"
	// config value.
	returnType *Type
}

// Name returns the identifier used to register and serialize this validator.
func (returnTypeValidator) Name() string {
	const validatorName = "cel.validators.return_type"
	return validatorName
}

// Validate reports an error on the root expression when its checked type is
// not exactly the configured return type.
//
// Types are compared with IsExactType rather than by pointer identity. Two
// separately constructed *Type values can describe the same CEL type (for
// example a parameterized list or a resolved message type) without sharing
// an address, and pointer comparison would reject them.
//
// The Env and ValidatorConfig arguments are not used by this validator.
func (v returnTypeValidator) Validate(_ *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) {
	rootID := a.Expr().ID()
	actual := a.GetType(rootID)
	if actual.IsExactType(v.returnType) {
		return
	}
	iss.ReportErrorAtID(rootID,
		"unsupported return type: %s, want %s",
		actual, v.returnType.TypeName())
}

// ToConfig serializes the validator as an env.Validator named after Name(),
// recording the expected return type under the "type_name" config key.
func (v returnTypeValidator) ToConfig() *env.Validator {
	settings := map[string]any{
		"type_name": v.returnType.TypeName(),
	}
	return env.NewValidator(v.Name()).SetConfig(settings)
}

// customLegacyProvider is a test type that holds a legacy ref.TypeProvider.
// NOTE(review): the methods defined on this type are not visible in this
// excerpt; presumably they delegate to provider — confirm before relying on it.
type customLegacyProvider struct {
	// provider is the wrapped ref.TypeProvider.
	provider ref.TypeProvider
}
Expand Down
Loading