From 766076f5f122e7bbb111c786e1559cbf0b9b972f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 20 Jul 2023 10:44:11 -0700 Subject: [PATCH 01/31] Bump word-wrap from 1.2.3 to 1.2.4 in /repl/appengine/web (#783) Bumps [word-wrap](https://github.com/jonschlinkert/word-wrap) from 1.2.3 to 1.2.4. - [Release notes](https://github.com/jonschlinkert/word-wrap/releases) - [Commits](https://github.com/jonschlinkert/word-wrap/compare/1.2.3...1.2.4) --- updated-dependencies: - dependency-name: word-wrap dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- repl/appengine/web/package-lock.json | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/repl/appengine/web/package-lock.json b/repl/appengine/web/package-lock.json index 29d7ce57a..1bced413f 100644 --- a/repl/appengine/web/package-lock.json +++ b/repl/appengine/web/package-lock.json @@ -13274,9 +13274,9 @@ "dev": true }, "node_modules/word-wrap": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.3.tgz", - "integrity": "sha512-Hz/mrNwitNRh/HUAtM/VT/5VH+ygD6DV7mYKZAtHOrbs8U7lvPS6xf7EJKMF0uW1KJCl0H701g3ZGus+muE5vQ==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/word-wrap/-/word-wrap-1.2.4.tgz", + "integrity": "sha512-2V81OA4ugVo5pRo46hAoD2ivUJx8jXmWXfUkY4KFNw0hEptvN0QfH3K4nHiwzGeKl5rFKedV48QVoqYavy4YpA==", "dev": true, "engines": { "node": ">=0.10.0" From 965e9c818f37cb335b6b04e181a87d124038743d Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 31 Jul 2023 13:19:04 -0400 Subject: [PATCH 02/31] String format validator (#775) * String format validator * Remove unused method from prior string validation --- cel/decls.go | 40 -- cel/options.go | 2 + cel/program.go | 60 +-- cel/validator.go | 11 - common/ast/BUILD.bazel | 1 + common/ast/expr.go | 9 +- ext/BUILD.bazel | 2 + 
ext/formatting.go | 878 ++++++++++++++++++++++++++++++++++++++ ext/strings.go | 419 ++---------------- ext/strings_test.go | 35 +- interpreter/BUILD.bazel | 1 - interpreter/formatting.go | 383 ----------------- 12 files changed, 949 insertions(+), 892 deletions(-) create mode 100644 ext/formatting.go delete mode 100644 interpreter/formatting.go diff --git a/cel/decls.go b/cel/decls.go index 0f9501341..b59e3708d 100644 --- a/cel/decls.go +++ b/cel/decls.go @@ -353,43 +353,3 @@ func ExprDeclToDeclaration(d *exprpb.Decl) (EnvOption, error) { return nil, fmt.Errorf("unsupported decl: %v", d) } } - -func typeValueToKind(tv ref.Type) (Kind, error) { - switch tv { - case types.BoolType: - return BoolKind, nil - case types.DoubleType: - return DoubleKind, nil - case types.IntType: - return IntKind, nil - case types.UintType: - return UintKind, nil - case types.ListType: - return ListKind, nil - case types.MapType: - return MapKind, nil - case types.StringType: - return StringKind, nil - case types.BytesType: - return BytesKind, nil - case types.DurationType: - return DurationKind, nil - case types.TimestampType: - return TimestampKind, nil - case types.NullType: - return NullTypeKind, nil - case types.TypeType: - return TypeKind, nil - default: - switch tv.TypeName() { - case "dyn": - return DynKind, nil - case "google.protobuf.Any": - return AnyKind, nil - case "optional": - return OpaqueKind, nil - default: - return 0, fmt.Errorf("no known conversion for type of %s", tv.TypeName()) - } - } -} diff --git a/cel/options.go b/cel/options.go index d47f55d8e..a6bff0dc4 100644 --- a/cel/options.go +++ b/cel/options.go @@ -447,6 +447,8 @@ const ( OptTrackCost EvalOption = 1 << iota // OptCheckStringFormat enables compile-time checking of string.format calls for syntax/cardinality. + // + // Deprecated: use ext.ValidateFormatString() as this option is now a no-op. 
OptCheckStringFormat EvalOption = 1 << iota ) diff --git a/cel/program.go b/cel/program.go index 11c5c447e..60fbfca46 100644 --- a/cel/program.go +++ b/cel/program.go @@ -19,7 +19,7 @@ import ( "fmt" "sync" - celast "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/interpreter" @@ -148,7 +148,7 @@ func (p *prog) clone() *prog { // ProgramOption values. // // If the program cannot be configured the prog will be nil, with a non-nil error response. -func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { +func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) { // Build the dispatcher, interpreter, and default program value. disp := interpreter.NewDispatcher() @@ -208,34 +208,6 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { if len(p.regexOptimizations) > 0 { decorators = append(decorators, interpreter.CompileRegexConstants(p.regexOptimizations...)) } - // Enable compile-time checking of syntax/cardinality for string.format calls. - if p.evalOpts&OptCheckStringFormat == OptCheckStringFormat { - var isValidType func(id int64, validTypes ...ref.Type) (bool, error) - if ast.IsChecked() { - isValidType = func(id int64, validTypes ...ref.Type) (bool, error) { - t := ast.typeMap[id] - if t.Kind() == DynKind { - return true, nil - } - for _, vt := range validTypes { - k, err := typeValueToKind(vt) - if err != nil { - return false, err - } - if t.Kind() == k { - return true, nil - } - } - return false, nil - } - } else { - // if the AST isn't type-checked, short-circuit validation - isValidType = func(id int64, validTypes ...ref.Type) (bool, error) { - return true, nil - } - } - decorators = append(decorators, interpreter.InterpolateFormattedString(isValidType)) - } // Enable exhaustive eval, state tracking and cost tracking last since they require a factory. 
if p.evalOpts&(OptExhaustiveEval|OptTrackState|OptTrackCost) != 0 { @@ -263,18 +235,18 @@ func newProgram(e *Env, ast *Ast, opts []ProgramOption) (Program, error) { decs = append(decs, interpreter.Observe(observers...)) } - return p.clone().initInterpretable(ast, decs) + return p.clone().initInterpretable(a, decs) } return newProgGen(factory) } - return p.initInterpretable(ast, decorators) + return p.initInterpretable(a, decorators) } -func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) { +func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) { // Unchecked programs do not contain type and reference information and may be slower to execute. - if !ast.IsChecked() { + if !a.IsChecked() { interpretable, err := - p.interpreter.NewUncheckedInterpretable(ast.Expr(), decs...) + p.interpreter.NewUncheckedInterpretable(a.Expr(), decs...) if err != nil { return nil, err } @@ -283,12 +255,7 @@ func (p *prog) initInterpretable(ast *Ast, decs []interpreter.InterpretableDecor } // When the AST has been checked it contains metadata that can be used to speed up program execution. - checked := &celast.CheckedAST{ - Expr: ast.Expr(), - SourceInfo: ast.SourceInfo(), - TypeMap: ast.typeMap, - ReferenceMap: ast.refMap, - } + checked := astToCheckedAST(a) interpretable, err := p.interpreter.NewInterpretable(checked, decs...) 
if err != nil { return nil, err @@ -558,9 +525,16 @@ func (p *evalActivationPool) Put(value any) { p.Pool.Put(a) } -var ( - emptyEvalState = interpreter.NewEvalState() +func astToCheckedAST(a *Ast) *ast.CheckedAST { + return &ast.CheckedAST{ + Expr: a.Expr(), + SourceInfo: a.SourceInfo(), + TypeMap: a.typeMap, + ReferenceMap: a.refMap, + } +} +var ( // activationPool is an internally managed pool of Activation values that wrap map[string]any inputs activationPool = newEvalActivationPool() diff --git a/cel/validator.go b/cel/validator.go index 78b311381..b5877d76f 100644 --- a/cel/validator.go +++ b/cel/validator.go @@ -244,17 +244,6 @@ func (homogeneousAggregateLiteralValidator) Name() string { return homogeneousValidatorName } -// Configure implements the ASTValidatorConfigurer interface and currently sets the list of standard -// and exempt functions from homogeneous aggregate literal checks. -// -// TODO: Move this call into the string.format() ASTValidator once ported. -func (homogeneousAggregateLiteralValidator) Configure(c MutableValidatorConfig) error { - emptyList := []string{} - exemptFunctions := c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, emptyList).([]string) - exemptFunctions = append(exemptFunctions, "format") - return c.Set(HomogeneousAggregateLiteralExemptFunctions, exemptFunctions) -} - // Validate validates that all lists and map literals have homogeneous types, i.e. don't contain dyn types. 
// // This validator makes an exception for list and map literals which occur at any level of nesting within diff --git a/common/ast/BUILD.bazel b/common/ast/BUILD.bazel index 7269cdff5..002327b9b 100644 --- a/common/ast/BUILD.bazel +++ b/common/ast/BUILD.bazel @@ -5,6 +5,7 @@ package( "//cel:__subpackages__", "//checker:__subpackages__", "//common:__subpackages__", + "//ext:__subpackages__", "//interpreter:__subpackages__", ], licenses = ["notice"], # Apache 2.0 diff --git a/common/ast/expr.go b/common/ast/expr.go index b63884a60..489c177b3 100644 --- a/common/ast/expr.go +++ b/common/ast/expr.go @@ -78,7 +78,7 @@ func KindMatcher(kind ExprKind) ExprMatcher { } // FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose -// function name is equal to `funcName`. +// function name is equal to `funcName` regardless of whether it is a member or a global function. func FunctionMatcher(funcName string) ExprMatcher { return func(e NavigableExpr) bool { if e.Kind() != CallKind { @@ -217,6 +217,9 @@ type NavigableCallExpr interface { // FunctionName returns the name of the function. FunctionName() string + // IsMemberFunction returns whether the call has a non-nil target indicating it is a member function + IsMemberFunction() bool + // Target returns the target of the expression if one is present. 
Target() NavigableExpr @@ -444,6 +447,10 @@ func (call navigableCallImpl) FunctionName() string { return call.ToExpr().GetCallExpr().GetFunction() } +func (call navigableCallImpl) IsMemberFunction() bool { + return call.ToExpr().GetCallExpr().GetTarget() != nil +} + func (call navigableCallImpl) Target() NavigableExpr { t := call.ToExpr().GetCallExpr().GetTarget() if t != nil { diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index ebdf7d01e..c6e57392c 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -8,6 +8,7 @@ go_library( name = "go_default_library", srcs = [ "encoders.go", + "formatting.go", "guards.go", "lists.go", "math.go", @@ -21,6 +22,7 @@ go_library( deps = [ "//cel:go_default_library", "//checker/decls:go_default_library", + "//common/ast:go_default_library", "//common/overloads:go_default_library", "//common/types:go_default_library", "//common/types/pb:go_default_library", diff --git a/ext/formatting.go b/ext/formatting.go new file mode 100644 index 000000000..178316959 --- /dev/null +++ b/ext/formatting.go @@ -0,0 +1,878 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ext + +import ( + "errors" + "fmt" + "math" + "sort" + "strconv" + "strings" + "unicode" + + "golang.org/x/text/language" + "golang.org/x/text/message" + + "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +type clauseImpl func(ref.Val, string) (string, error) + +func clauseForType(argType ref.Type) (clauseImpl, error) { + switch argType { + case types.IntType, types.UintType: + return formatDecimal, nil + case types.StringType, types.BytesType, types.BoolType, types.NullType, types.TypeType: + return FormatString, nil + case types.TimestampType, types.DurationType: + // special case to ensure timestamps/durations get printed as CEL literals + return func(arg ref.Val, locale string) (string, error) { + argStrVal := arg.ConvertToType(types.StringType) + argStr := argStrVal.Value().(string) + if arg.Type() == types.TimestampType { + return fmt.Sprintf("timestamp(%q)", argStr), nil + } + if arg.Type() == types.DurationType { + return fmt.Sprintf("duration(%q)", argStr), nil + } + return "", fmt.Errorf("cannot convert argument of type %s to timestamp/duration", arg.Type().TypeName()) + }, nil + case types.ListType: + return formatList, nil + case types.MapType: + return formatMap, nil + case types.DoubleType: + // avoid formatFixed so we can output a period as the decimal separator in order + // to always be a valid CEL literal + return func(arg ref.Val, locale string) (string, error) { + argDouble, ok := arg.Value().(float64) + if !ok { + return "", fmt.Errorf("couldn't convert %s to float64", arg.Type().TypeName()) + } + fmtStr := fmt.Sprintf("%%.%df", defaultPrecision) + return fmt.Sprintf(fmtStr, argDouble), nil + }, nil + case types.TypeType: + return func(arg ref.Val, locale string) (string, error) { + return fmt.Sprintf("type(%s)", arg.Value().(string)), nil + }, nil + default: + return nil, 
fmt.Errorf("no formatting function for %s", argType.TypeName()) + } +} + +func formatList(arg ref.Val, locale string) (string, error) { + argList := arg.(traits.Lister) + argIterator := argList.Iterator() + var listStrBuilder strings.Builder + _, err := listStrBuilder.WriteRune('[') + if err != nil { + return "", fmt.Errorf("error writing to list string: %w", err) + } + for argIterator.HasNext() == types.True { + member := argIterator.Next() + memberFormat, err := clauseForType(member.Type()) + if err != nil { + return "", err + } + unquotedStr, err := memberFormat(member, locale) + if err != nil { + return "", err + } + str := quoteForCEL(member, unquotedStr) + _, err = listStrBuilder.WriteString(str) + if err != nil { + return "", fmt.Errorf("error writing to list string: %w", err) + } + if argIterator.HasNext() == types.True { + _, err = listStrBuilder.WriteString(", ") + if err != nil { + return "", fmt.Errorf("error writing to list string: %w", err) + } + } + } + _, err = listStrBuilder.WriteRune(']') + if err != nil { + return "", fmt.Errorf("error writing to list string: %w", err) + } + return listStrBuilder.String(), nil +} + +func formatMap(arg ref.Val, locale string) (string, error) { + argMap := arg.(traits.Mapper) + argIterator := argMap.Iterator() + type mapPair struct { + key string + value string + } + argPairs := make([]mapPair, argMap.Size().Value().(int64)) + i := 0 + for argIterator.HasNext() == types.True { + key := argIterator.Next() + var keyFormat clauseImpl + switch key.Type() { + case types.StringType, types.BoolType: + keyFormat = FormatString + case types.IntType, types.UintType: + keyFormat = formatDecimal + default: + return "", fmt.Errorf("no formatting function for map key of type %s", key.Type().TypeName()) + } + unquotedKeyStr, err := keyFormat(key, locale) + if err != nil { + return "", err + } + keyStr := quoteForCEL(key, unquotedKeyStr) + value, found := argMap.Find(key) + if !found { + return "", fmt.Errorf("could not find key: 
%q", key) + } + valueFormat, err := clauseForType(value.Type()) + if err != nil { + return "", err + } + unquotedValueStr, err := valueFormat(value, locale) + if err != nil { + return "", err + } + valueStr := quoteForCEL(value, unquotedValueStr) + argPairs[i] = mapPair{keyStr, valueStr} + i++ + } + sort.SliceStable(argPairs, func(x, y int) bool { + return argPairs[x].key < argPairs[y].key + }) + var mapStrBuilder strings.Builder + _, err := mapStrBuilder.WriteRune('{') + if err != nil { + return "", fmt.Errorf("error writing to map string: %w", err) + } + for i, entry := range argPairs { + _, err = mapStrBuilder.WriteString(fmt.Sprintf("%s:%s", entry.key, entry.value)) + if err != nil { + return "", fmt.Errorf("error writing to map string: %w", err) + } + if i < len(argPairs)-1 { + _, err = mapStrBuilder.WriteString(", ") + if err != nil { + return "", fmt.Errorf("error writing to map string: %w", err) + } + } + } + _, err = mapStrBuilder.WriteRune('}') + if err != nil { + return "", fmt.Errorf("error writing to map string: %w", err) + } + return mapStrBuilder.String(), nil +} + +// quoteForCEL takes a formatted, unquoted value and quotes it in a manner suitable +// for embedding directly in CEL. +func quoteForCEL(refVal ref.Val, unquotedValue string) string { + switch refVal.Type() { + case types.StringType: + return fmt.Sprintf("%q", unquotedValue) + case types.BytesType: + return fmt.Sprintf("b%q", unquotedValue) + case types.DoubleType: + // special case to handle infinity/NaN + num := refVal.Value().(float64) + if math.IsInf(num, 1) || math.IsInf(num, -1) || math.IsNaN(num) { + return fmt.Sprintf("%q", unquotedValue) + } + return unquotedValue + default: + return unquotedValue + } +} + +// FormatString returns the string representation of a CEL value. +// +// It is used to implement the %s specifier in the (string).format() extension function. 
+func FormatString(arg ref.Val, locale string) (string, error) { + switch arg.Type() { + case types.ListType: + return formatList(arg, locale) + case types.MapType: + return formatMap(arg, locale) + case types.IntType, types.UintType, types.DoubleType, + types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType: + argStrVal := arg.ConvertToType(types.StringType) + argStr, ok := argStrVal.Value().(string) + if !ok { + return "", fmt.Errorf("could not convert argument %q to string", argStrVal) + } + return argStr, nil + case types.NullType: + return "null", nil + default: + return "", stringFormatError(runtimeID, arg.Type().TypeName()) + } +} + +func formatDecimal(arg ref.Val, locale string) (string, error) { + switch arg.Type() { + case types.IntType: + argInt, ok := arg.ConvertToType(types.IntType).Value().(int64) + if !ok { + return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value()) + } + return fmt.Sprintf("%d", argInt), nil + case types.UintType: + argInt, ok := arg.ConvertToType(types.UintType).Value().(uint64) + if !ok { + return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value()) + } + return fmt.Sprintf("%d", argInt), nil + default: + return "", decimalFormatError(runtimeID, arg.Type().TypeName()) + } +} + +func matchLanguage(locale string) (language.Tag, error) { + matcher, err := makeMatcher(locale) + if err != nil { + return language.Und, err + } + tag, _ := language.MatchStrings(matcher, locale) + return tag, nil +} + +func makeMatcher(locale string) (language.Matcher, error) { + tags := make([]language.Tag, 0) + tag, err := language.Parse(locale) + if err != nil { + return nil, err + } + tags = append(tags, tag) + return language.NewMatcher(tags), nil +} + +type stringFormatter struct{} + +func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) { + return FormatString(arg, locale) +} + +func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, 
error) { + return formatDecimal(arg, locale) +} + +func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) { + if precision == nil { + precision = new(int) + *precision = defaultPrecision + } + return func(arg ref.Val, locale string) (string, error) { + strException := false + if arg.Type() == types.StringType { + argStr := arg.Value().(string) + if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" { + strException = true + } + } + if arg.Type() != types.DoubleType && !strException { + return "", fixedPointFormatError(runtimeID, arg.Type().TypeName()) + } + argFloatVal := arg.ConvertToType(types.DoubleType) + argFloat, ok := argFloatVal.Value().(float64) + if !ok { + return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value()) + } + fmtStr := fmt.Sprintf("%%.%df", *precision) + + matchedLocale, err := matchLanguage(locale) + if err != nil { + return "", fmt.Errorf("error matching locale: %w", err) + } + return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil + } +} + +func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) { + if precision == nil { + precision = new(int) + *precision = defaultPrecision + } + return func(arg ref.Val, locale string) (string, error) { + strException := false + if arg.Type() == types.StringType { + argStr := arg.Value().(string) + if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" { + strException = true + } + } + if arg.Type() != types.DoubleType && !strException { + return "", scientificFormatError(runtimeID, arg.Type().TypeName()) + } + argFloatVal := arg.ConvertToType(types.DoubleType) + argFloat, ok := argFloatVal.Value().(float64) + if !ok { + return "", fmt.Errorf("could not convert \"%v\" to float64", argFloatVal.Value()) + } + matchedLocale, err := matchLanguage(locale) + if err != nil { + return "", fmt.Errorf("error matching locale: %w", err) + } + fmtStr := fmt.Sprintf("%%%de", *precision) + 
return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil + } +} + +func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) { + switch arg.Type() { + case types.IntType: + argInt := arg.Value().(int64) + // locale is intentionally unused as integers formatted as binary + // strings are locale-independent + return fmt.Sprintf("%b", argInt), nil + case types.UintType: + argInt := arg.Value().(uint64) + return fmt.Sprintf("%b", argInt), nil + case types.BoolType: + argBool := arg.Value().(bool) + if argBool { + return "1", nil + } + return "0", nil + default: + return "", binaryFormatError(runtimeID, arg.Type().TypeName()) + } +} + +func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) { + return func(arg ref.Val, locale string) (string, error) { + fmtStr := "%x" + if useUpper { + fmtStr = "%X" + } + switch arg.Type() { + case types.StringType, types.BytesType: + if arg.Type() == types.BytesType { + return fmt.Sprintf(fmtStr, arg.Value().([]byte)), nil + } + return fmt.Sprintf(fmtStr, arg.Value().(string)), nil + case types.IntType: + argInt, ok := arg.Value().(int64) + if !ok { + return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value()) + } + return fmt.Sprintf(fmtStr, argInt), nil + case types.UintType: + argInt, ok := arg.Value().(uint64) + if !ok { + return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value()) + } + return fmt.Sprintf(fmtStr, argInt), nil + default: + return "", hexFormatError(runtimeID, arg.Type().TypeName()) + } + } +} + +func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) { + switch arg.Type() { + case types.IntType: + argInt := arg.Value().(int64) + return fmt.Sprintf("%o", argInt), nil + case types.UintType: + argInt := arg.Value().(uint64) + return fmt.Sprintf("%o", argInt), nil + default: + return "", octalFormatError(runtimeID, arg.Type().TypeName()) + } +} + +// stringFormatValidator implements the cel.ASTValidator interface 
allowing for static validation +// of string.format calls. +type stringFormatValidator struct{} + +// Name returns the name of the validator. +func (stringFormatValidator) Name() string { + return "cel.lib.ext.validate.functions.string.format" +} + +// Configure implements the ASTValidatorConfigurer interface and augments the list of functions to skip +// during homogeneous aggregate literal type-checks. +func (stringFormatValidator) Configure(config cel.MutableValidatorConfig) error { + functions := config.GetOrDefault(cel.HomogeneousAggregateLiteralExemptFunctions, []string{}).([]string) + functions = append(functions, "format") + return config.Set(cel.HomogeneousAggregateLiteralExemptFunctions, functions) +} + +// Validate parses all literal format strings and type checks the format clause against the argument +// at the corresponding ordinal within the list literal argument to the function, if one is specified. +func (stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.CheckedAST, iss *cel.Issues) { + root := ast.NavigateCheckedAST(a) + formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs) + for _, e := range formatCallExprs { + call := e.AsCall() + formatStr := call.Target().AsLiteral().Value().(string) + args := call.Args()[0].AsList().Elements() + formatCheck := &stringFormatChecker{ + args: args, + typeMap: a.TypeMap, + } + // use a placeholder locale, since locale doesn't affect syntax + _, err := parseFormatString(formatStr, formatCheck, formatCheck, "en_US") + if err != nil { + iss.ReportErrorAtID(getErrorExprID(e.ID(), err), err.Error()) + continue + } + seenArgs := formatCheck.argsRequested + if len(args) > seenArgs { + iss.ReportErrorAtID(e.ID(), + "too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(args)) + } + } +} + +// getErrorExprID determines which list literal argument triggered a type-disagreement for the +// purposes of more accurate error message 
reports. +func getErrorExprID(id int64, err error) int64 { + fmtErr, ok := err.(formatError) + if ok { + return fmtErr.id + } + wrapped := errors.Unwrap(err) + if wrapped != nil { + return getErrorExprID(id, wrapped) + } + return id +} + +// matchConstantFormatStringWithListLiteralArgs matches all valid expression nodes for string +// format checking. +func matchConstantFormatStringWithListLiteralArgs(e ast.NavigableExpr) bool { + if e.Kind() != ast.CallKind { + return false + } + call := e.AsCall() + // TODO: expose the OverloadID() method on the NavigableCallExpr to refine this check. + if !call.IsMemberFunction() || call.FunctionName() != "format" { + return false + } + formatString := call.Target() + if formatString.Kind() != ast.LiteralKind && formatString.Type() != cel.StringType { + return false + } + args := call.Args() + if len(args) != 1 { + return false + } + formatArgs := args[0] + return formatArgs.Kind() == ast.ListKind +} + +// stringFormatChecker implements the formatStringInterpolater interface +type stringFormatChecker struct { + args []ast.NavigableExpr + argsRequested int + currArgIndex int64 + typeMap map[int64]*cel.Type +} + +func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) { + formatArg := c.args[c.currArgIndex] + valid, badID := c.verifyString(formatArg) + if !valid { + return "", stringFormatError(badID, c.typeOf(badID).TypeName()) + } + return "", nil +} + +func (c *stringFormatChecker) Decimal(arg ref.Val, locale string) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType) + if !valid { + return "", decimalFormatError(id, c.typeOf(id).TypeName()) + } + return "", nil +} + +func (c *stringFormatChecker) Fixed(precision *int) func(ref.Val, string) (string, error) { + return func(arg ref.Val, locale string) (string, error) { + id := c.args[c.currArgIndex].ID() + // we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values 
+ valid := c.verifyTypeOneOf(id, types.DoubleType, types.StringType) + if !valid { + return "", fixedPointFormatError(id, c.typeOf(id).TypeName()) + } + return "", nil + } +} + +func (c *stringFormatChecker) Scientific(precision *int) func(ref.Val, string) (string, error) { + return func(arg ref.Val, locale string) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.DoubleType, types.StringType) + if !valid { + return "", scientificFormatError(id, c.typeOf(id).TypeName()) + } + return "", nil + } +} + +func (c *stringFormatChecker) Binary(arg ref.Val, locale string) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.BoolType) + if !valid { + return "", binaryFormatError(id, c.typeOf(id).TypeName()) + } + return "", nil +} + +func (c *stringFormatChecker) Hex(useUpper bool) func(ref.Val, string) (string, error) { + return func(arg ref.Val, locale string) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType, types.StringType, types.BytesType) + if !valid { + return "", hexFormatError(id, c.typeOf(id).TypeName()) + } + return "", nil + } +} + +func (c *stringFormatChecker) Octal(arg ref.Val, locale string) (string, error) { + id := c.args[c.currArgIndex].ID() + valid := c.verifyTypeOneOf(id, types.IntType, types.UintType) + if !valid { + return "", octalFormatError(id, c.typeOf(id).TypeName()) + } + return "", nil +} + +func (c *stringFormatChecker) Arg(index int64) (ref.Val, error) { + c.argsRequested++ + c.currArgIndex = index + // return a dummy value - this is immediately passed to back to us + // through one of the FormatCallback functions, so anything will do + return types.Int(0), nil +} + +func (c *stringFormatChecker) Size() int64 { + return int64(len(c.args)) +} + +func (c *stringFormatChecker) typeOf(id int64) *cel.Type { + t, found := c.typeMap[id] + if found { + return t + } + 
return cel.DynType +} + +func (c *stringFormatChecker) verifyTypeOneOf(id int64, validTypes ...*cel.Type) bool { + t := c.typeOf(id) + if t == cel.DynType { + return true + } + for _, vt := range validTypes { + // Only check runtime type compatibility without delving deeper into parameterized types + if t.Kind() == vt.Kind() { + return true + } + } + return false +} + +func (c *stringFormatChecker) verifyString(sub ast.NavigableExpr) (bool, int64) { + paramA := cel.TypeParamType("A") + paramB := cel.TypeParamType("B") + subVerified := c.verifyTypeOneOf(sub.ID(), + cel.ListType(paramA), cel.MapType(paramA, paramB), + cel.IntType, cel.UintType, cel.DoubleType, cel.BoolType, cel.StringType, + cel.TimestampType, cel.BytesType, cel.DurationType, cel.TypeType, cel.NullType) + if !subVerified { + return false, sub.ID() + } + if sub.Kind() != ast.ListKind && sub.Kind() != ast.MapKind { + return true, sub.ID() + } + members := sub.Children() + for _, m := range members { + // recursively verify if we're dealing with a list/map + verified, id := c.verifyString(m) + if !verified { + return false, id + } + } + return true, sub.ID() +} + +// helper routines for reporting common errors during string formatting static validation and +// runtime execution. 
+ +func binaryFormatError(id int64, badType string) error { + return newFormatError(id, "only integers and bools can be formatted as binary, was given %s", badType) +} + +func decimalFormatError(id int64, badType string) error { + return newFormatError(id, "decimal clause can only be used on integers, was given %s", badType) +} + +func fixedPointFormatError(id int64, badType string) error { + return newFormatError(id, "fixed-point clause can only be used on doubles, was given %s", badType) +} + +func hexFormatError(id int64, badType string) error { + return newFormatError(id, "only integers, byte buffers, and strings can be formatted as hex, was given %s", badType) +} + +func octalFormatError(id int64, badType string) error { + return newFormatError(id, "octal clause can only be used on integers, was given %s", badType) +} + +func scientificFormatError(id int64, badType string) error { + return newFormatError(id, "scientific clause can only be used on doubles, was given %s", badType) +} + +func stringFormatError(id int64, badType string) error { + return newFormatError(id, "string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", badType) +} + +type formatError struct { + id int64 + msg string +} + +func newFormatError(id int64, msg string, args ...any) error { + return formatError{ + id: id, + msg: fmt.Sprintf(msg, args...), + } +} + +func (e formatError) Error() string { + return e.msg +} + +func (e formatError) Is(target error) bool { + return e.msg == target.Error() +} + +// stringArgList implements the formatListArgs interface. 
+type stringArgList struct { + args traits.Lister +} + +func (c *stringArgList) Arg(index int64) (ref.Val, error) { + if index >= c.args.Size().Value().(int64) { + return nil, fmt.Errorf("index %d out of range", index) + } + return c.args.Get(types.Int(index)), nil +} + +func (c *stringArgList) Size() int64 { + return c.args.Size().Value().(int64) +} + +// formatStringInterpolator is an interface that allows user-defined behavior +// for formatting clause implementations, as well as argument retrieval. +// Each function is expected to support the appropriate types as laid out in +// the string.format documentation, and to return an error if given an inappropriate type. +type formatStringInterpolator interface { + // String takes a ref.Val and a string representing the current locale identifier + // and returns the Val formatted as a string, or an error if one occurred. + String(ref.Val, string) (string, error) + + // Decimal takes a ref.Val and a string representing the current locale identifier + // and returns the Val formatted as a decimal integer, or an error if one occurred. + Decimal(ref.Val, string) (string, error) + + // Fixed takes an int pointer representing precision (or nil if none was given) and + // returns a function operating in a similar manner to String and Decimal, taking a + // ref.Val and locale and returning the appropriate string. A closure is returned + // so precision can be set without needing an additional function call/configuration. + Fixed(*int) func(ref.Val, string) (string, error) + + // Scientific functions identically to Fixed, except the string returned from the closure + // is expected to be in scientific notation. + Scientific(*int) func(ref.Val, string) (string, error) + + // Binary takes a ref.Val and a string representing the current locale identifier + // and returns the Val formatted as a binary integer, or an error if one occurred. 
+ Binary(ref.Val, string) (string, error) + + // Hex takes a boolean that, if true, indicates the hex string output by the returned + // closure should use uppercase letters for A-F. + Hex(bool) func(ref.Val, string) (string, error) + + // Octal takes a ref.Val and a string representing the current locale identifier and + // returns the Val formatted in octal, or an error if one occurred. + Octal(ref.Val, string) (string, error) +} + +// formatListArgs is an interface that allows user-defined list-like datatypes to be used +// for formatting clause implementations. +type formatListArgs interface { + // Arg returns the ref.Val at the given index, or an error if one occurred. + Arg(int64) (ref.Val, error) + + // Size returns the length of the argument list. + Size() int64 +} + +// parseFormatString formats a string according to the string.format syntax, taking the clause implementations +// from the provided FormatCallback and the args from the given FormatList. +func parseFormatString(formatStr string, callback formatStringInterpolator, list formatListArgs, locale string) (string, error) { + i := 0 + argIndex := 0 + var builtStr strings.Builder + for i < len(formatStr) { + if formatStr[i] == '%' { + if i+1 < len(formatStr) && formatStr[i+1] == '%' { + err := builtStr.WriteByte('%') + if err != nil { + return "", fmt.Errorf("error writing format string: %w", err) + } + i += 2 + continue + } else { + argAny, err := list.Arg(int64(argIndex)) + if err != nil { + return "", err + } + if i+1 >= len(formatStr) { + return "", errors.New("unexpected end of string") + } + if int64(argIndex) >= list.Size() { + return "", fmt.Errorf("index %d out of range", argIndex) + } + numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale) + if refErr != nil { + return "", refErr + } + _, err = builtStr.WriteString(val) + if err != nil { + return "", fmt.Errorf("error writing format string: %w", err) + } + i += numRead + argIndex++ + } + } else { + err 
:= builtStr.WriteByte(formatStr[i]) + if err != nil { + return "", fmt.Errorf("error writing format string: %w", err) + } + i++ + } + } + return builtStr.String(), nil +} + +// parseAndFormatClause parses the format clause at the start of the given string with val, and returns +// how many characters were consumed and the substituted string form of val, or an error if one occurred. +func parseAndFormatClause(formatStr string, val ref.Val, callback formatStringInterpolator, list formatListArgs, locale string) (int, string, error) { + i := 1 + read, formatter, err := parseFormattingClause(formatStr[i:], callback) + i += read + if err != nil { + return -1, "", newParseFormatError("could not parse formatting clause", err) + } + + valStr, err := formatter(val, locale) + if err != nil { + return -1, "", newParseFormatError("error during formatting", err) + } + return i, valStr, nil +} + +func parseFormattingClause(formatStr string, callback formatStringInterpolator) (int, clauseImpl, error) { + i := 0 + read, precision, err := parsePrecision(formatStr[i:]) + i += read + if err != nil { + return -1, nil, fmt.Errorf("error while parsing precision: %w", err) + } + r := rune(formatStr[i]) + i++ + switch r { + case 's': + return i, callback.String, nil + case 'd': + return i, callback.Decimal, nil + case 'f': + return i, callback.Fixed(precision), nil + case 'e': + return i, callback.Scientific(precision), nil + case 'b': + return i, callback.Binary, nil + case 'x', 'X': + return i, callback.Hex(unicode.IsUpper(r)), nil + case 'o': + return i, callback.Octal, nil + default: + return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r) + } +} + +func parsePrecision(formatStr string) (int, *int, error) { + i := 0 + if formatStr[i] != '.' 
{ + return i, nil, nil + } + i++ + var buffer strings.Builder + for { + if i >= len(formatStr) { + return -1, nil, errors.New("could not find end of precision specifier") + } + if !isASCIIDigit(rune(formatStr[i])) { + break + } + buffer.WriteByte(formatStr[i]) + i++ + } + precision, err := strconv.Atoi(buffer.String()) + if err != nil { + return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err) + } + return i, &precision, nil +} + +func isASCIIDigit(r rune) bool { + return r <= unicode.MaxASCII && unicode.IsDigit(r) +} + +type parseFormatError struct { + msg string + wrapped error +} + +func newParseFormatError(msg string, wrapped error) error { + return parseFormatError{msg: msg, wrapped: wrapped} +} + +func (e parseFormatError) Error() string { + return fmt.Sprintf("%s: %s", e.msg, e.wrapped.Error()) +} + +func (e parseFormatError) Is(target error) bool { + return e.Error() == target.Error() +} + +func (e parseFormatError) Unwrap() error { + return e.wrapped +} + +const ( + runtimeID = int64(-1) +) diff --git a/ext/strings.go b/ext/strings.go index 88c119f2b..935c9b6a1 100644 --- a/ext/strings.go +++ b/ext/strings.go @@ -21,19 +21,16 @@ import ( "fmt" "math" "reflect" - "sort" "strings" "unicode" "unicode/utf8" "golang.org/x/text/language" - "golang.org/x/text/message" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" - "github.com/google/cel-go/interpreter" ) const ( @@ -271,7 +268,10 @@ const ( // 'TacoCat'.upperAscii() // returns 'TACOCAT' // 'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII' func Strings(options ...StringsOption) cel.EnvOption { - s := &stringLib{version: math.MaxUint32} + s := &stringLib{ + version: math.MaxUint32, + validateFormat: true, + } for _, o := range options { s = o(s) } @@ -279,8 +279,9 @@ func Strings(options ...StringsOption) cel.EnvOption { } type stringLib struct { - locale string - version 
uint32 + locale string + version uint32 + validateFormat bool } // LibraryName implements the SingletonLibrary interface method. @@ -317,6 +318,17 @@ func StringsVersion(version uint32) StringsOption { } } +// StringsValidateFormatCalls validates type-checked ASTs to ensure that string.format() calls have +// valid formatting clauses and valid argument types for each clause. +// +// Enabled by default. +func StringsValidateFormatCalls(value bool) StringsOption { + return func(s *stringLib) *stringLib { + s.validateFormat = value + return s + } +} + // CompileOptions implements the Library interface method. func (lib *stringLib) CompileOptions() []cel.EnvOption { formatLocale := "en_US" @@ -440,13 +452,15 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption { cel.FunctionBinding(func(args ...ref.Val) ref.Val { s := string(args[0].(types.String)) formatArgs := args[1].(traits.Lister) - return stringOrError(interpreter.ParseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale)) + return stringOrError(parseFormatString(s, &stringFormatter{}, &stringArgList{formatArgs}, formatLocale)) }))), cel.Function("strings.quote", cel.Overload("strings_quote", []*cel.Type{cel.StringType}, cel.StringType, cel.UnaryBinding(func(str ref.Val) ref.Val { s := str.(types.String) return stringOrError(quote(string(s))) - })))) + }))), + + cel.ASTValidators(stringFormatValidator{})) } if lib.version >= 2 { @@ -471,7 +485,7 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption { cel.UnaryBinding(func(list ref.Val) ref.Val { l, err := list.ConvertToNative(stringListType) if err != nil { - return types.NewErr(err.Error()) + return types.WrapErr(err) } return stringOrError(join(l.([]string))) })), @@ -479,13 +493,16 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption { cel.BinaryBinding(func(list, delim ref.Val) ref.Val { l, err := list.ConvertToNative(stringListType) if err != nil { - return types.NewErr(err.Error()) + return types.WrapErr(err) } d := 
delim.(types.String) return stringOrError(joinSeparator(l.([]string), string(d))) }))), ) } + if lib.validateFormat { + opts = append(opts, cel.ASTValidators(stringFormatValidator{})) + } return opts } @@ -661,238 +678,6 @@ func joinValSeparator(strs traits.Lister, separator string) (string, error) { return sb.String(), nil } -type clauseImpl func(ref.Val, string) (string, error) - -func clauseForType(argType ref.Type) (clauseImpl, error) { - switch argType { - case types.IntType, types.UintType: - return formatDecimal, nil - case types.StringType, types.BytesType, types.BoolType, types.NullType, types.TypeType: - return FormatString, nil - case types.TimestampType, types.DurationType: - // special case to ensure timestamps/durations get printed as CEL literals - return func(arg ref.Val, locale string) (string, error) { - argStrVal := arg.ConvertToType(types.StringType) - argStr := argStrVal.Value().(string) - if arg.Type() == types.TimestampType { - return fmt.Sprintf("timestamp(%q)", argStr), nil - } - if arg.Type() == types.DurationType { - return fmt.Sprintf("duration(%q)", argStr), nil - } - return "", fmt.Errorf("cannot convert argument of type %s to timestamp/duration", arg.Type().TypeName()) - }, nil - case types.ListType: - return formatList, nil - case types.MapType: - return formatMap, nil - case types.DoubleType: - // avoid formatFixed so we can output a period as the decimal separator in order - // to always be a valid CEL literal - return func(arg ref.Val, locale string) (string, error) { - argDouble, ok := arg.Value().(float64) - if !ok { - return "", fmt.Errorf("couldn't convert %s to float64", arg.Type().TypeName()) - } - fmtStr := fmt.Sprintf("%%.%df", defaultPrecision) - return fmt.Sprintf(fmtStr, argDouble), nil - }, nil - case types.TypeType: - return func(arg ref.Val, locale string) (string, error) { - return fmt.Sprintf("type(%s)", arg.Value().(string)), nil - }, nil - default: - return nil, fmt.Errorf("no formatting function for %s", 
argType.TypeName()) - } -} - -func formatList(arg ref.Val, locale string) (string, error) { - argList := arg.(traits.Lister) - argIterator := argList.Iterator() - var listStrBuilder strings.Builder - _, err := listStrBuilder.WriteRune('[') - if err != nil { - return "", fmt.Errorf("error writing to list string: %w", err) - } - for argIterator.HasNext() == types.True { - member := argIterator.Next() - memberFormat, err := clauseForType(member.Type()) - if err != nil { - return "", err - } - unquotedStr, err := memberFormat(member, locale) - if err != nil { - return "", err - } - str := quoteForCEL(member, unquotedStr) - _, err = listStrBuilder.WriteString(str) - if err != nil { - return "", fmt.Errorf("error writing to list string: %w", err) - } - if argIterator.HasNext() == types.True { - _, err = listStrBuilder.WriteString(", ") - if err != nil { - return "", fmt.Errorf("error writing to list string: %w", err) - } - } - } - _, err = listStrBuilder.WriteRune(']') - if err != nil { - return "", fmt.Errorf("error writing to list string: %w", err) - } - return listStrBuilder.String(), nil -} - -func formatMap(arg ref.Val, locale string) (string, error) { - argMap := arg.(traits.Mapper) - argIterator := argMap.Iterator() - type mapPair struct { - key string - value string - } - argPairs := make([]mapPair, argMap.Size().Value().(int64)) - i := 0 - for argIterator.HasNext() == types.True { - key := argIterator.Next() - var keyFormat clauseImpl - switch key.Type() { - case types.StringType, types.BoolType: - keyFormat = FormatString - case types.IntType, types.UintType: - keyFormat = formatDecimal - default: - return "", fmt.Errorf("no formatting function for map key of type %s", key.Type().TypeName()) - } - unquotedKeyStr, err := keyFormat(key, locale) - if err != nil { - return "", err - } - keyStr := quoteForCEL(key, unquotedKeyStr) - value, found := argMap.Find(key) - if !found { - return "", fmt.Errorf("could not find key: %q", key) - } - valueFormat, err := 
clauseForType(value.Type()) - if err != nil { - return "", err - } - unquotedValueStr, err := valueFormat(value, locale) - if err != nil { - return "", err - } - valueStr := quoteForCEL(value, unquotedValueStr) - argPairs[i] = mapPair{keyStr, valueStr} - i++ - } - sort.SliceStable(argPairs, func(x, y int) bool { - return argPairs[x].key < argPairs[y].key - }) - var mapStrBuilder strings.Builder - _, err := mapStrBuilder.WriteRune('{') - if err != nil { - return "", fmt.Errorf("error writing to map string: %w", err) - } - for i, entry := range argPairs { - _, err = mapStrBuilder.WriteString(fmt.Sprintf("%s:%s", entry.key, entry.value)) - if err != nil { - return "", fmt.Errorf("error writing to map string: %w", err) - } - if i < len(argPairs)-1 { - _, err = mapStrBuilder.WriteString(", ") - if err != nil { - return "", fmt.Errorf("error writing to map string: %w", err) - } - } - } - _, err = mapStrBuilder.WriteRune('}') - if err != nil { - return "", fmt.Errorf("error writing to map string: %w", err) - } - return mapStrBuilder.String(), nil -} - -// quoteForCEL takes a formatted, unquoted value and quotes it in a manner -// suitable for embedding directly in CEL. -func quoteForCEL(refVal ref.Val, unquotedValue string) string { - switch refVal.Type() { - case types.StringType: - return fmt.Sprintf("%q", unquotedValue) - case types.BytesType: - return fmt.Sprintf("b%q", unquotedValue) - case types.DoubleType: - // special case to handle infinity/NaN - num := refVal.Value().(float64) - if math.IsInf(num, 1) || math.IsInf(num, -1) || math.IsNaN(num) { - return fmt.Sprintf("%q", unquotedValue) - } - return unquotedValue - default: - return unquotedValue - } -} - -// FormatString returns the string representation of a CEL value. -// It is used to implement the %s specifier in the (string).format() extension -// function. 
-func FormatString(arg ref.Val, locale string) (string, error) { - switch arg.Type() { - case types.ListType: - return formatList(arg, locale) - case types.MapType: - return formatMap(arg, locale) - case types.IntType, types.UintType, types.DoubleType, - types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType: - argStrVal := arg.ConvertToType(types.StringType) - argStr, ok := argStrVal.Value().(string) - if !ok { - return "", fmt.Errorf("could not convert argument %q to string", argStrVal) - } - return argStr, nil - case types.NullType: - return "null", nil - default: - return "", fmt.Errorf("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given %s", arg.Type().TypeName()) - } -} - -func formatDecimal(arg ref.Val, locale string) (string, error) { - switch arg.Type() { - case types.IntType: - argInt, ok := arg.ConvertToType(types.IntType).Value().(int64) - if !ok { - return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value()) - } - return fmt.Sprintf("%d", argInt), nil - case types.UintType: - argInt, ok := arg.ConvertToType(types.UintType).Value().(uint64) - if !ok { - return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value()) - } - return fmt.Sprintf("%d", argInt), nil - default: - return "", fmt.Errorf("decimal clause can only be used on integers, was given %s", arg.Type().TypeName()) - } -} - -func matchLanguage(locale string) (language.Tag, error) { - matcher, err := makeMatcher(locale) - if err != nil { - return language.Und, err - } - tag, _ := language.MatchStrings(matcher, locale) - return tag, nil -} - -func makeMatcher(locale string) (language.Matcher, error) { - tags := make([]language.Tag, 0) - tag, err := language.Parse(locale) - if err != nil { - return nil, err - } - tags = append(tags, tag) - return language.NewMatcher(tags), nil -} - // quote implements a string quoting function. 
The string will be wrapped in // double quotes, and all valid CEL escape sequences will be escaped to show up // literally if printed. If the input contains any invalid UTF-8, the invalid runes @@ -940,156 +725,6 @@ func sanitize(s string) string { return sanitizedStringBuilder.String() } -type stringFormatter struct{} - -func (c *stringFormatter) String(arg ref.Val, locale string) (string, error) { - return FormatString(arg, locale) -} - -func (c *stringFormatter) Decimal(arg ref.Val, locale string) (string, error) { - return formatDecimal(arg, locale) -} - -func (c *stringFormatter) Fixed(precision *int) func(ref.Val, string) (string, error) { - if precision == nil { - precision = new(int) - *precision = defaultPrecision - } - return func(arg ref.Val, locale string) (string, error) { - strException := false - if arg.Type() == types.StringType { - argStr := arg.Value().(string) - if argStr == "NaN" || argStr == "Infinity" || argStr == "-Infinity" { - strException = true - } - } - if arg.Type() != types.DoubleType && !strException { - return "", fmt.Errorf("fixed-point clause can only be used on doubles, was given %s", arg.Type().TypeName()) - } - argFloatVal := arg.ConvertToType(types.DoubleType) - argFloat, ok := argFloatVal.Value().(float64) - if !ok { - return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value()) - } - fmtStr := fmt.Sprintf("%%.%df", *precision) - - matchedLocale, err := matchLanguage(locale) - if err != nil { - return "", fmt.Errorf("error matching locale: %w", err) - } - return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil - } -} - -func (c *stringFormatter) Scientific(precision *int) func(ref.Val, string) (string, error) { - if precision == nil { - precision = new(int) - *precision = defaultPrecision - } - return func(arg ref.Val, locale string) (string, error) { - strException := false - if arg.Type() == types.StringType { - argStr := arg.Value().(string) - if argStr == "NaN" || argStr == "Infinity" 
|| argStr == "-Infinity" { - strException = true - } - } - if arg.Type() != types.DoubleType && !strException { - return "", fmt.Errorf("scientific clause can only be used on doubles, was given %s", arg.Type().TypeName()) - } - argFloatVal := arg.ConvertToType(types.DoubleType) - argFloat, ok := argFloatVal.Value().(float64) - if !ok { - return "", fmt.Errorf("could not convert \"%s\" to float64", argFloatVal.Value()) - } - matchedLocale, err := matchLanguage(locale) - if err != nil { - return "", fmt.Errorf("error matching locale: %w", err) - } - fmtStr := fmt.Sprintf("%%%de", *precision) - return message.NewPrinter(matchedLocale).Sprintf(fmtStr, argFloat), nil - } -} - -func (c *stringFormatter) Binary(arg ref.Val, locale string) (string, error) { - switch arg.Type() { - case types.IntType: - argInt := arg.Value().(int64) - // locale is intentionally unused as integers formatted as binary - // strings are locale-independent - return fmt.Sprintf("%b", argInt), nil - case types.UintType: - argInt := arg.Value().(uint64) - return fmt.Sprintf("%b", argInt), nil - case types.BoolType: - argBool := arg.Value().(bool) - if argBool { - return "1", nil - } - return "0", nil - default: - return "", fmt.Errorf("only integers and bools can be formatted as binary, was given %s", arg.Type().TypeName()) - } -} - -func (c *stringFormatter) Hex(useUpper bool) func(ref.Val, string) (string, error) { - return func(arg ref.Val, locale string) (string, error) { - fmtStr := "%x" - if useUpper { - fmtStr = "%X" - } - switch arg.Type() { - case types.StringType, types.BytesType: - if arg.Type() == types.BytesType { - return fmt.Sprintf(fmtStr, arg.Value().([]byte)), nil - } - return fmt.Sprintf(fmtStr, arg.Value().(string)), nil - case types.IntType: - argInt, ok := arg.Value().(int64) - if !ok { - return "", fmt.Errorf("could not convert \"%s\" to int64", arg.Value()) - } - return fmt.Sprintf(fmtStr, argInt), nil - case types.UintType: - argInt, ok := arg.Value().(uint64) - if !ok { - 
return "", fmt.Errorf("could not convert \"%s\" to uint64", arg.Value()) - } - return fmt.Sprintf(fmtStr, argInt), nil - default: - return "", fmt.Errorf("only integers, byte buffers, and strings can be formatted as hex, was given %s", arg.Type().TypeName()) - } - } -} - -func (c *stringFormatter) Octal(arg ref.Val, locale string) (string, error) { - switch arg.Type() { - case types.IntType: - argInt := arg.Value().(int64) - return fmt.Sprintf("%o", argInt), nil - case types.UintType: - argInt := arg.Value().(uint64) - return fmt.Sprintf("%o", argInt), nil - default: - return "", fmt.Errorf("octal clause can only be used on integers, was given %s", arg.Type().TypeName()) - } -} - -type stringArgList struct { - args traits.Lister -} - -func (c *stringArgList) Arg(index int64) (ref.Val, error) { - if index >= c.args.Size().Value().(int64) { - return nil, fmt.Errorf("index %d out of range", index) - } - return c.args.Get(types.Int(index)), nil -} - -func (c *stringArgList) ArgSize() int64 { - return c.args.Size().Value().(int64) -} - var ( stringListType = reflect.TypeOf([]string{}) ) diff --git a/ext/strings_test.go b/ext/strings_test.go index b8fa2ebeb..002226eb2 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -340,7 +340,7 @@ func TestStrings(t *testing.T) { out.Value(), tc.err, tc.expr) } if !strings.Contains(err.Error(), tc.err) { - t.Errorf("got error %v, wanted error %s for expr: %s", err, tc.err, tc.expr) + t.Errorf("got %q, expected error to contain %q for expr: %s", err, tc.err, tc.expr) } } else if err != nil { t.Fatal(err) @@ -1018,14 +1018,14 @@ func TestStringFormat(t *testing.T) { format: "%s", formatArgs: "[1, 2, ext.TestAllTypes{PbVal: test.TestAllTypes{}}]", skipCompileCheck: true, - err: "error during formatting: no formatting function for ext.TestAllTypes", + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", 
}, { name: "object inside map", format: "%s", formatArgs: `{1: "a", 2: ext.TestAllTypes{}}`, skipCompileCheck: true, - err: "error during formatting: no formatting function for ext.TestAllTypes", + err: "error during formatting: string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps, was given ext.TestAllTypes", }, { name: "null not allowed for %d", @@ -1117,7 +1117,7 @@ func TestStringFormat(t *testing.T) { name: "compile-time %d check", format: "int is %d", formatArgs: "5.2", - err: "error during formatting: integer clause can only be used on integers", + err: "error during formatting: decimal clause can only be used on integers", }, { name: "compile-time %f check", @@ -1162,7 +1162,7 @@ func TestStringFormat(t *testing.T) { err: "error during formatting: octal clause can only be used on integers", }, } - evalExpr := func(env *cel.Env, expr string, evalArgs any, skipCompileCheck bool, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) { + evalExpr := func(env *cel.Env, expr string, evalArgs any, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) { t.Logf("evaluating expr: %s", expr) parsedAst, issues := env.Parse(expr) if issues.Err() != nil { @@ -1170,23 +1170,16 @@ func TestStringFormat(t *testing.T) { } checkedAst, issues := env.Check(parsedAst) if issues.Err() != nil { - t.Fatalf("env.Check(%v) failed: %v", expr, issues.Err()) + return nil, issues.Err() } evalOpts := make([]cel.ProgramOption, 0) - if !skipCompileCheck { - evalOpts = append(evalOpts, cel.EvalOptions(cel.OptCheckStringFormat)) - } costTracker := &noopCostEstimator{} if expectedRuntimeCost != 0 { evalOpts = append(evalOpts, cel.CostTracking(costTracker)) } program, err := env.Program(checkedAst, evalOpts...) 
if err != nil { - if skipCompileCheck { - t.Fatal(err) - } else { - return nil, err - } + return nil, err } actualEstimatedCost, err := env.EstimateCost(checkedAst, costTracker) @@ -1254,36 +1247,36 @@ func TestStringFormat(t *testing.T) { } return opts } - buildOpts := func(locale string, variables []cel.EnvOption) []cel.EnvOption { + buildOpts := func(skipCompileCheck bool, locale string, variables []cel.EnvOption) []cel.EnvOption { opts := []cel.EnvOption{ - Strings(StringsLocale(locale)), + Strings(StringsLocale(locale), StringsValidateFormatCalls(!skipCompileCheck)), cel.Container("ext"), cel.Abbrevs("google.expr.proto3.test"), cel.Types(&proto3pb.TestAllTypes{}), - cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals()), NativeTypes( reflect.TypeOf(&TestNestedType{}), reflect.ValueOf(&TestAllTypes{}), ), } + opts = append(opts, cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals())) opts = append(opts, variables...) return opts } runCase := func(format, formatArgs, locale string, dynArgs map[string]any, skipCompileCheck bool, expectedRuntimeCost uint64, expectedEstimatedCost checker.CostEstimate, t *testing.T) (ref.Val, error) { - env, err := cel.NewEnv(buildOpts(locale, buildVariables(dynArgs))...) + env, err := cel.NewEnv(buildOpts(skipCompileCheck, locale, buildVariables(dynArgs))...) 
if err != nil { t.Fatalf("cel.NewEnv() failed: %v", err) } expr := fmt.Sprintf("%q.format([%s])", format, formatArgs) if len(dynArgs) == 0 { - return evalExpr(env, expr, cel.NoVars(), skipCompileCheck, expectedRuntimeCost, expectedEstimatedCost, t) + return evalExpr(env, expr, cel.NoVars(), expectedRuntimeCost, expectedEstimatedCost, t) } - return evalExpr(env, expr, dynArgs, skipCompileCheck, expectedRuntimeCost, expectedEstimatedCost, t) + return evalExpr(env, expr, dynArgs, expectedRuntimeCost, expectedEstimatedCost, t) } checkCase := func(output ref.Val, expectedOutput string, err error, expectedErr string, t *testing.T) { if err != nil { if expectedErr != "" { - if err.Error() != expectedErr { + if !strings.Contains(err.Error(), expectedErr) { t.Fatalf("expected %q as error message, got %q", expectedErr, err.Error()) } } else { diff --git a/interpreter/BUILD.bazel b/interpreter/BUILD.bazel index 3a5219eb5..220e23d47 100644 --- a/interpreter/BUILD.bazel +++ b/interpreter/BUILD.bazel @@ -14,7 +14,6 @@ go_library( "decorators.go", "dispatcher.go", "evalstate.go", - "formatting.go", "interpretable.go", "interpreter.go", "optimizations.go", diff --git a/interpreter/formatting.go b/interpreter/formatting.go deleted file mode 100644 index e3f753374..000000000 --- a/interpreter/formatting.go +++ /dev/null @@ -1,383 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package interpreter - -import ( - "errors" - "fmt" - "strconv" - "strings" - "unicode" - - "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" -) - -type typeVerifier func(int64, ...ref.Type) (bool, error) - -// InterpolateFormattedString checks the syntax and cardinality of any string.format calls present in the expression and reports -// any errors at compile time. -func InterpolateFormattedString(verifier typeVerifier) InterpretableDecorator { - return func(inter Interpretable) (Interpretable, error) { - call, ok := inter.(InterpretableCall) - if !ok { - return inter, nil - } - if call.OverloadID() != "string_format" { - return inter, nil - } - args := call.Args() - if len(args) != 2 { - return nil, fmt.Errorf("wrong number of arguments to string.format (expected 2, got %d)", len(args)) - } - fmtStrInter, ok := args[0].(InterpretableConst) - if !ok { - return inter, nil - } - var fmtArgsInter InterpretableConstructor - fmtArgsInter, ok = args[1].(InterpretableConstructor) - if !ok { - return inter, nil - } - if fmtArgsInter.Type() != types.ListType { - // don't necessarily return an error since the list may be DynType - return inter, nil - } - formatStr := fmtStrInter.Value().Value().(string) - initVals := fmtArgsInter.InitVals() - - formatCheck := &formatCheck{ - args: initVals, - verifier: verifier, - } - // use a placeholder locale, since locale doesn't affect syntax - _, err := ParseFormatString(formatStr, formatCheck, formatCheck, "en_US") - if err != nil { - return nil, err - } - seenArgs := formatCheck.argsRequested - if len(initVals) > seenArgs { - return nil, fmt.Errorf("too many arguments supplied to string.format (expected %d, got %d)", seenArgs, len(initVals)) - } - return inter, nil - } -} - -type formatCheck struct { - args []Interpretable - argsRequested int - curArgIndex int64 - enableCheckArgTypes bool - verifier typeVerifier -} - -func (c *formatCheck) String(arg ref.Val, locale string) (string, error) { - 
valid, err := verifyString(c.args[c.curArgIndex], c.verifier) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("string clause can only be used on strings, bools, bytes, ints, doubles, maps, lists, types, durations, and timestamps") - } - return "", nil -} - -func (c *formatCheck) Decimal(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.IntType, types.UintType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("integer clause can only be used on integers") - } - return "", nil -} - -func (c *formatCheck) Fixed(precision *int) func(ref.Val, string) (string, error) { - return func(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - // we allow StringType since "NaN", "Infinity", and "-Infinity" are also valid values - valid, err := c.verifier(id, types.DoubleType, types.StringType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("fixed-point clause can only be used on doubles") - } - return "", nil - } -} - -func (c *formatCheck) Scientific(precision *int) func(ref.Val, string) (string, error) { - return func(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.DoubleType, types.StringType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("scientific clause can only be used on doubles") - } - return "", nil - } -} - -func (c *formatCheck) Binary(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.IntType, types.UintType, types.BoolType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("only integers and bools can be formatted as binary") - } - return "", nil -} - -func (c *formatCheck) Hex(useUpper bool) func(ref.Val, string) (string, error) { - return func(arg ref.Val, locale string) (string, error) { - 
id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.IntType, types.UintType, types.StringType, types.BytesType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("only integers, byte buffers, and strings can be formatted as hex") - } - return "", nil - } -} - -func (c *formatCheck) Octal(arg ref.Val, locale string) (string, error) { - id := c.args[c.curArgIndex].ID() - valid, err := c.verifier(id, types.IntType, types.UintType) - if err != nil { - return "", err - } - if !valid { - return "", errors.New("octal clause can only be used on integers") - } - return "", nil -} - -func (c *formatCheck) Arg(index int64) (ref.Val, error) { - c.argsRequested++ - c.curArgIndex = index - // return a dummy value - this is immediately passed to back to us - // through one of the FormatCallback functions, so anything will do - return types.Int(0), nil -} - -func (c *formatCheck) ArgSize() int64 { - return int64(len(c.args)) -} - -func verifyString(sub Interpretable, verifier typeVerifier) (bool, error) { - subVerified, err := verifier(sub.ID(), - types.ListType, types.MapType, types.IntType, types.UintType, types.DoubleType, - types.BoolType, types.StringType, types.TimestampType, types.BytesType, types.DurationType, types.TypeType, types.NullType) - if err != nil { - return false, err - } - if !subVerified { - return false, nil - } - con, ok := sub.(InterpretableConstructor) - if ok { - members := con.InitVals() - for _, m := range members { - // recursively verify if we're dealing with a list/map - verified, err := verifyString(m, verifier) - if err != nil { - return false, err - } - if !verified { - return false, nil - } - } - } - return true, nil - -} - -// FormatStringInterpolator is an interface that allows user-defined behavior -// for formatting clause implementations, as well as argument retrieval. 
-// Each function is expected to support the appropriate types as laid out in -// the string.format documentation, and to return an error if given an inappropriate type. -type FormatStringInterpolator interface { - // String takes a ref.Val and a string representing the current locale identifier - // and returns the Val formatted as a string, or an error if one occurred. - String(ref.Val, string) (string, error) - - // Decimal takes a ref.Val and a string representing the current locale identifier - // and returns the Val formatted as a decimal integer, or an error if one occurred. - Decimal(ref.Val, string) (string, error) - - // Fixed takes an int pointer representing precision (or nil if none was given) and - // returns a function operating in a similar manner to String and Decimal, taking a - // ref.Val and locale and returning the appropriate string. A closure is returned - // so precision can be set without needing an additional function call/configuration. - Fixed(*int) func(ref.Val, string) (string, error) - - // Scientific functions identically to Fixed, except the string returned from the closure - // is expected to be in scientific notation. - Scientific(*int) func(ref.Val, string) (string, error) - - // Binary takes a ref.Val and a string representing the current locale identifier - // and returns the Val formatted as a binary integer, or an error if one occurred. - Binary(ref.Val, string) (string, error) - - // Hex takes a boolean that, if true, indicates the hex string output by the returned - // closure should use uppercase letters for A-F. - Hex(bool) func(ref.Val, string) (string, error) - - // Octal takes a ref.Val and a string representing the current locale identifier and - // returns the Val formatted in octal, or an error if one occurred. - Octal(ref.Val, string) (string, error) -} - -// FormatList is an interface that allows user-defined list-like datatypes to be used -// for formatting clause implementations. 
-type FormatList interface { - // Arg returns the ref.Val at the given index, or an error if one occurred. - Arg(int64) (ref.Val, error) - // ArgSize returns the length of the argument list. - ArgSize() int64 -} - -type clauseImpl func(ref.Val, string) (string, error) - -// ParseFormatString formats a string according to the string.format syntax, taking the clause implementations -// from the provided FormatCallback and the args from the given FormatList. -func ParseFormatString(formatStr string, callback FormatStringInterpolator, list FormatList, locale string) (string, error) { - i := 0 - argIndex := 0 - var builtStr strings.Builder - for i < len(formatStr) { - if formatStr[i] == '%' { - if i+1 < len(formatStr) && formatStr[i+1] == '%' { - err := builtStr.WriteByte('%') - if err != nil { - return "", fmt.Errorf("error writing format string: %w", err) - } - i += 2 - continue - } else { - argAny, err := list.Arg(int64(argIndex)) - if err != nil { - return "", err - } - if i+1 >= len(formatStr) { - return "", errors.New("unexpected end of string") - } - if int64(argIndex) >= list.ArgSize() { - return "", fmt.Errorf("index %d out of range", argIndex) - } - numRead, val, refErr := parseAndFormatClause(formatStr[i:], argAny, callback, list, locale) - if refErr != nil { - return "", refErr - } - _, err = builtStr.WriteString(val) - if err != nil { - return "", fmt.Errorf("error writing format string: %w", err) - } - i += numRead - argIndex++ - } - } else { - err := builtStr.WriteByte(formatStr[i]) - if err != nil { - return "", fmt.Errorf("error writing format string: %w", err) - } - i++ - } - } - return builtStr.String(), nil -} - -// parseAndFormatClause parses the format clause at the start of the given string with val, and returns -// how many characters were consumed and the substituted string form of val, or an error if one occurred. 
-func parseAndFormatClause(formatStr string, val ref.Val, callback FormatStringInterpolator, list FormatList, locale string) (int, string, error) { - i := 1 - read, formatter, err := parseFormattingClause(formatStr[i:], callback) - i += read - if err != nil { - return -1, "", fmt.Errorf("could not parse formatting clause: %s", err) - } - - valStr, err := formatter(val, locale) - if err != nil { - return -1, "", fmt.Errorf("error during formatting: %s", err) - } - return i, valStr, nil -} - -func parseFormattingClause(formatStr string, callback FormatStringInterpolator) (int, clauseImpl, error) { - i := 0 - read, precision, err := parsePrecision(formatStr[i:]) - i += read - if err != nil { - return -1, nil, fmt.Errorf("error while parsing precision: %w", err) - } - r := rune(formatStr[i]) - i++ - switch r { - case 's': - return i, callback.String, nil - case 'd': - return i, callback.Decimal, nil - case 'f': - return i, callback.Fixed(precision), nil - case 'e': - return i, callback.Scientific(precision), nil - case 'b': - return i, callback.Binary, nil - case 'x', 'X': - return i, callback.Hex(unicode.IsUpper(r)), nil - case 'o': - return i, callback.Octal, nil - default: - return -1, nil, fmt.Errorf("unrecognized formatting clause \"%c\"", r) - } -} - -func parsePrecision(formatStr string) (int, *int, error) { - i := 0 - if formatStr[i] != '.' 
{ - return i, nil, nil - } - i++ - var buffer strings.Builder - for { - if i >= len(formatStr) { - return -1, nil, errors.New("could not find end of precision specifier") - } - if !isASCIIDigit(rune(formatStr[i])) { - break - } - buffer.WriteByte(formatStr[i]) - i++ - } - precision, err := strconv.Atoi(buffer.String()) - if err != nil { - return -1, nil, fmt.Errorf("error while converting precision to integer: %w", err) - } - return i, &precision, nil -} - -func isASCIIDigit(r rune) bool { - return r <= unicode.MaxASCII && unicode.IsDigit(r) -} From 5359cfdef2be0ce8039efbe5670b96dd9e45852e Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 31 Jul 2023 13:19:21 -0400 Subject: [PATCH 03/31] Nil safety checks for cel.Ast (#784) --- cel/env.go | 18 +++++++++++++++--- cel/env_test.go | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/cel/env.go b/cel/env.go index 5cbb86a75..e8d6e2ea8 100644 --- a/cel/env.go +++ b/cel/env.go @@ -47,16 +47,25 @@ type Ast struct { // Expr returns the proto serializable instance of the parsed/checked expression. func (ast *Ast) Expr() *exprpb.Expr { + if ast == nil { + return nil + } return ast.expr } // IsChecked returns whether the Ast value has been successfully type-checked. func (ast *Ast) IsChecked() bool { + if ast == nil { + return false + } return ast.typeMap != nil && len(ast.typeMap) > 0 } // SourceInfo returns character offset and newline position information about expression elements. 
func (ast *Ast) SourceInfo() *exprpb.SourceInfo { + if ast == nil { + return nil + } return ast.info } @@ -65,9 +74,6 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo { // // Deprecated: use OutputType func (ast *Ast) ResultType() *exprpb.Type { - if !ast.IsChecked() { - return chkdecls.Dyn - } out := ast.OutputType() t, err := TypeToExprType(out) if err != nil { @@ -79,6 +85,9 @@ func (ast *Ast) ResultType() *exprpb.Type { // OutputType returns the output type of the expression if the Ast has been type-checked, else // returns cel.DynType as the parse step cannot infer types. func (ast *Ast) OutputType() *Type { + if ast == nil { + return types.ErrorType + } t, found := ast.typeMap[ast.expr.GetId()] if !found { return DynType @@ -89,6 +98,9 @@ func (ast *Ast) OutputType() *Type { // Source returns a view of the input used to create the Ast. This source may be complete or // constructed from the SourceInfo. func (ast *Ast) Source() Source { + if ast == nil { + return nil + } return ast.source } diff --git a/cel/env_test.go b/cel/env_test.go index c8f7ebee6..663727961 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -29,6 +29,25 @@ import ( exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) +func TestAstNil(t *testing.T) { + var ast *Ast + if ast.IsChecked() { + t.Error("ast.IsChecked() returned true for nil ast") + } + if ast.Expr() != nil { + t.Errorf("ast.Expr() got %v, wanted nil", ast.Expr()) + } + if ast.SourceInfo() != nil { + t.Errorf("ast.SourceInfo() got %v, wanted nil", ast.SourceInfo()) + } + if ast.OutputType() != types.ErrorType { + t.Errorf("ast.OutputType() got %v, wanted error type", ast.OutputType()) + } + if ast.Source() != nil { + t.Errorf("ast.Source() got %v, wanted nil", ast.Source()) + } +} + func TestIssuesNil(t *testing.T) { var iss *Issues iss = iss.Append(iss) From 20720f3f21a9be7bc082d7a06a7494b78497eaff Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Fri, 4 Aug 2023 11:06:42 -0400 Subject: [PATCH 04/31] Fix cost 
estimates to propagate result sizes. (#787) --- checker/cost.go | 26 +++++++++++++++++++++++--- checker/cost_test.go | 16 ++++++++++++++++ 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/checker/cost.go b/checker/cost.go index d6a799a8d..d02f2628a 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -535,14 +535,34 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args if est := c.estimator.EstimateCallCost(function, overloadID, target, args); est != nil { callEst := *est - return CallEstimate{CostEstimate: callEst.Add(argCostSum())} + return CallEstimate{CostEstimate: callEst.Add(argCostSum()), ResultSize: est.ResultSize} } switch overloadID { // O(n) functions - case overloads.StartsWithString, overloads.EndsWithString, overloads.StringToBytes, overloads.BytesToString, overloads.ExtQuoteString, overloads.ExtFormatString: - if overloadID == overloads.ExtFormatString { + case overloads.ExtFormatString: + if target != nil { + // ResultSize not calculated because we can't bound the max size. return CallEstimate{CostEstimate: c.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} } + case overloads.StringToBytes: + if len(args) == 1 { + sz := c.sizeEstimate(args[0]) + // ResultSize max is when each char converts to 4 bytes. + return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min, Max: sz.Max * 4}} + } + case overloads.BytesToString: + if len(args) == 1 { + sz := c.sizeEstimate(args[0]) + // ResultSize min is when 4 bytes convert to 1 char. + return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min / 4, Max: sz.Max}} + } + case overloads.ExtQuoteString: + if len(args) == 1 { + sz := c.sizeEstimate(args[0]) + // ResultSize max is when each char is escaped. 2 quote chars always added. 
+ return CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum()), ResultSize: &SizeEstimate{Min: sz.Min + 2, Max: sz.Max*2 + 2}} + } + case overloads.StartsWithString, overloads.EndsWithString: if len(args) == 1 { return CallEstimate{CostEstimate: c.sizeEstimate(args[0]).MultiplyByCostFactor(common.StringTraversalCostFactor).Add(argCostSum())} } diff --git a/checker/cost_test.go b/checker/cost_test.go index c5ac6fc1b..9b47167ed 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -261,6 +261,14 @@ func TestCost(t *testing.T) { expr: `string(input)`, wanted: CostEstimate{Min: 1, Max: 51}, }, + { + name: "bytes to string conversion equality", + decls: []*exprpb.Decl{decls.NewVar("input", decls.Bytes)}, + hints: map[string]int64{"input": 500}, + // equality check ensures that the resultSize calculation is included in cost + expr: `string(input) == string(input)`, + wanted: CostEstimate{Min: 3, Max: 152}, + }, { name: "string to bytes conversion", vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)}, @@ -268,6 +276,14 @@ func TestCost(t *testing.T) { expr: `bytes(input)`, wanted: CostEstimate{Min: 1, Max: 51}, }, + { + name: "string to bytes conversion equality", + decls: []*exprpb.Decl{decls.NewVar("input", decls.String)}, + hints: map[string]int64{"input": 500}, + // equality check ensures that the resultSize calculation is included in cost + expr: `bytes(input) == bytes(input)`, + wanted: CostEstimate{Min: 3, Max: 302}, + }, { name: "int to string conversion", expr: `string(1)`, From a6388a3a20eb094b1ed1dfda1273d26a2655df50 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 9 Aug 2023 17:07:13 -0400 Subject: [PATCH 05/31] Split Expr from NavigableExpr with interpreter support (#788) * Split Expr from NavigableExpr with interpreter support * Handle errors during proto transformation in the type-checker --- cel/env.go | 41 +- cel/io.go | 32 +- cel/program.go | 37 +- cel/validator.go 
| 37 +- checker/checker.go | 30 +- checker/checker_test.go | 8 +- checker/cost.go | 20 +- common/ast/BUILD.bazel | 7 + common/ast/ast.go | 382 ++++++++----- common/ast/ast_test.go | 220 ++++---- common/ast/conversion.go | 615 +++++++++++++++++++++ common/ast/conversion_test.go | 298 ++++++++++ common/ast/expr.go | 949 ++++++++++++++++++-------------- common/ast/expr_test.go | 691 ++++++++++++----------- common/ast/factory.go | 193 +++++++ common/ast/navigable.go | 489 ++++++++++++++++ common/ast/navigable_test.go | 446 +++++++++++++++ common/types/provider.go | 4 +- ext/formatting.go | 106 ++-- interpreter/interpreter.go | 26 +- interpreter/interpreter_test.go | 147 +++-- interpreter/planner.go | 279 ++++------ interpreter/prune_test.go | 18 +- 23 files changed, 3685 insertions(+), 1390 deletions(-) create mode 100644 common/ast/conversion.go create mode 100644 common/ast/conversion_test.go create mode 100644 common/ast/factory.go create mode 100644 common/ast/navigable.go create mode 100644 common/ast/navigable_test.go diff --git a/cel/env.go b/cel/env.go index e8d6e2ea8..49f79ed2d 100644 --- a/cel/env.go +++ b/cel/env.go @@ -38,9 +38,9 @@ type Source = common.Source // Ast representing the checked or unchecked expression, its source, and related metadata such as // source position information. type Ast struct { - expr *exprpb.Expr - info *exprpb.SourceInfo source Source + expr celast.Expr + info *celast.SourceInfo refMap map[int64]*celast.ReferenceInfo typeMap map[int64]*types.Type } @@ -50,7 +50,8 @@ func (ast *Ast) Expr() *exprpb.Expr { if ast == nil { return nil } - return ast.expr + pbExpr, _ := celast.ExprToProto(ast.expr) + return pbExpr } // IsChecked returns whether the Ast value has been successfully type-checked. 
@@ -66,7 +67,8 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo { if ast == nil { return nil } - return ast.info + pbInfo, _ := celast.SourceInfoToProto(ast.info) + return pbInfo } // ResultType returns the output type of the expression if the Ast has been type-checked, else @@ -88,7 +90,7 @@ func (ast *Ast) OutputType() *Type { if ast == nil { return types.ErrorType } - t, found := ast.typeMap[ast.expr.GetId()] + t, found := ast.typeMap[ast.expr.ID()] if !found { return DynType } @@ -227,10 +229,10 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { // detailed than the information provided by Check), is returned to the caller. ast = &Ast{ source: ast.Source(), - expr: res.Expr, - info: res.SourceInfo, - refMap: res.ReferenceMap, - typeMap: res.TypeMap} + expr: res.Expr(), + info: res.SourceInfo(), + refMap: res.ReferenceMap(), + typeMap: res.TypeMap()} // Generate a validator configuration from the set of configured validators. vConfig := newValidatorConfig() @@ -433,10 +435,20 @@ func (e *Env) ParseSource(src Source) (*Ast, *Issues) { } // Manually create the Ast to ensure that the text source information is propagated on // subsequent calls to Check. + expr, err := celast.ProtoToExpr(res.GetExpr()) + if err != nil { + errs.ReportError(common.NoLocation, "invalid parsed expression: %v", err) + return nil, &Issues{errs: errs} + } + info, err := celast.ProtoToSourceInfo(res.GetSourceInfo()) + if err != nil { + errs.ReportError(common.NoLocation, "invalid source info: %v", err) + return nil, &Issues{errs: errs} + } return &Ast{ source: src, - expr: res.GetExpr(), - info: res.GetSourceInfo()}, nil + expr: expr, + info: info}, nil } // Program generates an evaluable instance of the Ast within the environment (Env). @@ -554,12 +566,7 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { // EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and // extension functions provided by estimator. 
func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) { - checked := &celast.CheckedAST{ - Expr: ast.Expr(), - SourceInfo: ast.SourceInfo(), - TypeMap: ast.typeMap, - ReferenceMap: ast.refMap, - } + checked := celast.NewCheckedAST(celast.NewAST(ast.expr, ast.info), ast.typeMap, ast.refMap) return checker.Cost(checked, estimator, opts...) } diff --git a/cel/io.go b/cel/io.go index 80f63140e..5107f564d 100644 --- a/cel/io.go +++ b/cel/io.go @@ -47,16 +47,16 @@ func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast { // // Prefer CheckedExprToAst if loading expressions from storage. func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) { - checkedAST, err := ast.CheckedExprToCheckedAST(checkedExpr) + checkedAST, err := ast.ToAST(checkedExpr) if err != nil { return nil, err } return &Ast{ - expr: checkedAST.Expr, - info: checkedAST.SourceInfo, source: src, - refMap: checkedAST.ReferenceMap, - typeMap: checkedAST.TypeMap, + expr: checkedAST.Expr(), + info: checkedAST.SourceInfo(), + refMap: checkedAST.ReferenceMap(), + typeMap: checkedAST.TypeMap(), }, nil } @@ -67,13 +67,11 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) { if !a.IsChecked() { return nil, fmt.Errorf("cannot convert unchecked ast") } - cAst := &ast.CheckedAST{ - Expr: a.expr, - SourceInfo: a.info, - ReferenceMap: a.refMap, - TypeMap: a.typeMap, + checked, err := astToExprAST(a) + if err != nil { + return nil, err } - return ast.CheckedASTToCheckedExpr(cAst) + return ast.ToProto(checked) } // ParsedExprToAst converts a parsed expression proto message to an Ast. @@ -89,16 +87,14 @@ func ParsedExprToAst(parsedExpr *exprpb.ParsedExpr) *Ast { // // Prefer ParsedExprToAst if loading expressions from storage. 
func ParsedExprToAstWithSource(parsedExpr *exprpb.ParsedExpr, src Source) *Ast { - si := parsedExpr.GetSourceInfo() - if si == nil { - si = &exprpb.SourceInfo{} - } + info, _ := ast.ProtoToSourceInfo(parsedExpr.GetSourceInfo()) if src == nil { - src = common.NewInfoSource(si) + src = common.NewInfoSource(parsedExpr.GetSourceInfo()) } + e, _ := ast.ProtoToExpr(parsedExpr.GetExpr()) return &Ast{ - expr: parsedExpr.GetExpr(), - info: si, + expr: e, + info: info, source: src, } } diff --git a/cel/program.go b/cel/program.go index 60fbfca46..2f829a2a3 100644 --- a/cel/program.go +++ b/cel/program.go @@ -243,20 +243,12 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) { } func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) { - // Unchecked programs do not contain type and reference information and may be slower to execute. - if !a.IsChecked() { - interpretable, err := - p.interpreter.NewUncheckedInterpretable(a.Expr(), decs...) - if err != nil { - return nil, err - } - p.interpretable = interpretable - return p, nil + // When the AST has been exprAST it contains metadata that can be used to speed up program execution. + exprAST, err := astToExprAST(a) + if err != nil { + return nil, err } - - // When the AST has been checked it contains metadata that can be used to speed up program execution. - checked := astToCheckedAST(a) - interpretable, err := p.interpreter.NewInterpretable(checked, decs...) + interpretable, err := p.interpreter.NewInterpretable(exprAST, decs...) 
if err != nil { return nil, err } @@ -525,13 +517,20 @@ func (p *evalActivationPool) Put(value any) { p.Pool.Put(a) } -func astToCheckedAST(a *Ast) *ast.CheckedAST { - return &ast.CheckedAST{ - Expr: a.Expr(), - SourceInfo: a.SourceInfo(), - TypeMap: a.typeMap, - ReferenceMap: a.refMap, +func astToExprAST(a *Ast) (*ast.AST, error) { + expr, err := ast.ProtoToExpr(a.Expr()) + if err != nil { + return nil, err + } + info, err := ast.ProtoToSourceInfo(a.SourceInfo()) + if err != nil { + return nil, err + } + baseAST := ast.NewAST(expr, info) + if !a.IsChecked() { + return baseAST, nil } + return ast.NewCheckedAST(baseAST, a.typeMap, a.refMap), nil } var ( diff --git a/cel/validator.go b/cel/validator.go index b5877d76f..59df3f4c3 100644 --- a/cel/validator.go +++ b/cel/validator.go @@ -21,8 +21,6 @@ import ( "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/overloads" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) const ( @@ -69,7 +67,7 @@ type ASTValidator interface { // // See individual validators for more information on their configuration keys and configuration // properties. - Validate(*Env, ValidatorConfig, *ast.CheckedAST, *Issues) + Validate(*Env, ValidatorConfig, *ast.AST, *Issues) } // ValidatorConfig provides an accessor method for querying validator configuration state. @@ -180,7 +178,7 @@ func ValidateComprehensionNestingLimit(limit int) ASTValidator { return nestingLimitValidator{limit: limit} } -type argChecker func(env *Env, call, arg ast.NavigableExpr) error +type argChecker func(env *Env, call, arg ast.Expr) error func newFormatValidator(funcName string, argNum int, check argChecker) formatValidator { return formatValidator{ @@ -203,8 +201,8 @@ func (v formatValidator) Name() string { // Validate searches the AST for uses of a given function name with a constant argument and performs a check // on whether the argument is a valid literal value. 
-func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) { - root := ast.NavigateCheckedAST(a) +func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) { + root := ast.NavigateAST(a) funcCalls := ast.MatchDescendants(root, ast.FunctionMatcher(v.funcName)) for _, call := range funcCalls { callArgs := call.AsCall().Args() @@ -221,8 +219,12 @@ func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, } } -func evalCall(env *Env, call, arg ast.NavigableExpr) error { - ast := ParsedExprToAst(&exprpb.ParsedExpr{Expr: call.ToExpr()}) +func evalCall(env *Env, call, arg ast.Expr) error { + ast := &Ast{ + expr: call, + info: ast.NewSourceInfo(nil), + source: nil, + } prg, err := env.Program(ast) if err != nil { return err @@ -231,7 +233,7 @@ func evalCall(env *Env, call, arg ast.NavigableExpr) error { return err } -func compileRegex(_ *Env, _, arg ast.NavigableExpr) error { +func compileRegex(_ *Env, _, arg ast.Expr) error { pattern := arg.AsLiteral().Value().(string) _, err := regexp.Compile(pattern) return err @@ -248,10 +250,10 @@ func (homogeneousAggregateLiteralValidator) Name() string { // // This validator makes an exception for list and map literals which occur at any level of nesting within // string format calls. 
-func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.CheckedAST, iss *Issues) { +func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) { var exemptedFunctions []string exemptedFunctions = c.GetOrDefault(HomogeneousAggregateLiteralExemptFunctions, exemptedFunctions).([]string) - root := ast.NavigateCheckedAST(a) + root := ast.NavigateAST(a) listExprs := ast.MatchDescendants(root, ast.KindMatcher(ast.ListKind)) for _, listExpr := range listExprs { if inExemptFunction(listExpr, exemptedFunctions) { @@ -262,7 +264,7 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig optIndices := l.OptionalIndices() var elemType *Type for i, e := range elements { - et := e.Type() + et := a.GetType(e.ID()) if isOptionalIndex(i, optIndices) { et = et.Parameters()[0] } @@ -285,9 +287,10 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig entries := m.Entries() var keyType, valType *Type for _, e := range entries { - key, val := e.Key(), e.Value() - kt, vt := key.Type(), val.Type() - if e.IsOptional() { + mapEntry := e.AsMapEntry() + key, val := mapEntry.Key(), mapEntry.Value() + kt, vt := a.GetType(key.ID()), a.GetType(val.ID()) + if mapEntry.IsOptional() { vt = vt.Parameters()[0] } if keyType == nil && valType == nil { @@ -342,8 +345,8 @@ func (v nestingLimitValidator) Name() string { return "cel.lib.std.validate.comprehension_nesting_limit" } -func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.CheckedAST, iss *Issues) { - root := ast.NavigateCheckedAST(a) +func (v nestingLimitValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Issues) { + root := ast.NavigateAST(a) comprehensions := ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) if len(comprehensions) <= v.limit { return diff --git a/checker/checker.go b/checker/checker.go index 720e4fa96..20eaa050b 100644 --- a/checker/checker.go 
+++ b/checker/checker.go @@ -40,12 +40,14 @@ type checker struct { } // Check performs type checking, giving a typed AST. -// The input is a ParsedExpr proto and an env which encapsulates -// type binding of variables, declarations of built-in functions, -// descriptions of protocol buffers, and a registry for errors. -// Returns a CheckedExpr proto, which might not be usable if -// there are errors in the error registry. -func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast.CheckedAST, *common.Errors) { +// +// The input is a ParsedExpr proto and an env which encapsulates type binding of variables, +// declarations of built-in functions, descriptions of protocol buffers, and a registry for +// errors. +// +// Returns a type-checked AST, which might not be usable if there are errors in the error +// registry. +func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast.AST, *common.Errors) { errs := common.NewErrors(source) c := checker{ env: env, @@ -64,13 +66,15 @@ func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast. 
for id, t := range c.types { m[id] = substitute(c.mappings, t, true) } - - return &ast.CheckedAST{ - Expr: parsedExpr.GetExpr(), - SourceInfo: parsedExpr.GetSourceInfo(), - TypeMap: m, - ReferenceMap: c.references, - }, errs + e, err := ast.ProtoToExpr(parsedExpr.GetExpr()) + if err != nil { + errs.ReportError(common.NoLocation, err.Error()) + } + info, err := ast.ProtoToSourceInfo(parsedExpr.GetSourceInfo()) + if err != nil { + errs.ReportError(common.NoLocation, err.Error()) + } + return ast.NewCheckedAST(ast.NewAST(e, info), m, c.references), errs } func (c *checker) check(e *exprpb.Expr) { diff --git a/checker/checker_test.go b/checker/checker_test.go index 032bb5eaf..e2ae6e801 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -2350,7 +2350,7 @@ func TestCheck(t *testing.T) { t.Errorf("Expected error not thrown: %s", tc.err) } - actual := cAst.TypeMap[pAst.Expr.Id] + actual := cAst.GetType(pAst.Expr.Id) if tc.err == "" { if actual == nil || !actual.IsEquivalentType(tc.outType) { t.Error(test.DiffMessage("Type Error", actual, tc.outType)) @@ -2358,7 +2358,7 @@ func TestCheck(t *testing.T) { } if tc.out != "" { - chkExpr, err := ast.CheckedASTToCheckedExpr(cAst) + chkExpr, err := ast.ToProto(cAst) if err != nil { t.Fatalf("CheckedAstToCheckedExpr() failed: %v", err) } @@ -2445,7 +2445,7 @@ func BenchmarkCheck(b *testing.B) { b.Errorf("Expected error not thrown: %s", tc.err) } - actual := cAst.TypeMap[pAst.Expr.Id] + actual := cAst.GetType(pAst.Expr.Id) if tc.err == "" { if actual == nil || !actual.IsEquivalentType(tc.outType) { b.Error(test.DiffMessage("Type Error", actual, tc.outType)) @@ -2453,7 +2453,7 @@ func BenchmarkCheck(b *testing.B) { } if tc.out != "" { - chkExpr, err := ast.CheckedASTToCheckedExpr(cAst) + chkExpr, err := ast.ToProto(cAst) if err != nil { b.Fatalf("CheckedAstToCheckedExpr() failed: %v", err) } diff --git a/checker/cost.go b/checker/cost.go index d02f2628a..30d477c82 100644 --- a/checker/cost.go +++ b/checker/cost.go 
@@ -261,7 +261,7 @@ type coster struct { iterRanges iterRangeScopes // computedSizes tracks the computed sizes of call results. computedSizes map[int64]SizeEstimate - checkedAST *ast.CheckedAST + checkedAST *ast.AST estimator CostEstimator // presenceTestCost will either be a zero or one based on whether has() macros count against cost computations. presenceTestCost CostEstimate @@ -304,7 +304,7 @@ func PresenceTestHasCost(hasCost bool) CostOption { } // Cost estimates the cost of the parsed and type checked CEL expression. -func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) { +func Cost(checker *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) { c := &coster{ checkedAST: checker, estimator: estimator, @@ -319,7 +319,11 @@ func Cost(checker *ast.CheckedAST, estimator CostEstimator, opts ...CostOption) return CostEstimate{}, err } } - return c.cost(checker.Expr), nil + epb, err := ast.ExprToProto(checker.Expr()) + if err != nil { + return CostEstimate{}, err + } + return c.cost(epb), nil } func (c *coster) cost(e *exprpb.Expr) CostEstimate { @@ -353,7 +357,7 @@ func (c *coster) costIdent(e *exprpb.Expr) CostEstimate { // build and track the field path if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok { - switch c.checkedAST.TypeMap[iterRange].Kind() { + switch c.checkedAST.GetType(iterRange).Kind() { case types.ListKind: c.addPath(e, append(c.exprPath[iterRange], "@items")) case types.MapKind: @@ -405,8 +409,8 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate { argTypes[i] = c.newAstNode(arg) } - ref := c.checkedAST.ReferenceMap[e.GetId()] - if ref == nil || len(ref.OverloadIDs) == 0 { + overloadIDs := c.checkedAST.GetOverloadIDs(e.GetId()) + if len(overloadIDs) == 0 { return CostEstimate{} } var targetType AstNode @@ -419,7 +423,7 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate { // Pick a cost estimate range that covers all the overload cost estimation 
ranges fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0} var resultSize *SizeEstimate - for _, overload := range ref.OverloadIDs { + for _, overload := range overloadIDs { overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts) fnCost = fnCost.Union(overloadCost.CostEstimate) if overloadCost.ResultSize != nil { @@ -644,7 +648,7 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args } func (c *coster) getType(e *exprpb.Expr) *types.Type { - return c.checkedAST.TypeMap[e.GetId()] + return c.checkedAST.GetType(e.GetId()) } func (c *coster) getPath(e *exprpb.Expr) []string { diff --git a/common/ast/BUILD.bazel b/common/ast/BUILD.bazel index 002327b9b..a29f201fa 100644 --- a/common/ast/BUILD.bazel +++ b/common/ast/BUILD.bazel @@ -15,10 +15,14 @@ go_library( name = "go_default_library", srcs = [ "ast.go", + "conversion.go", "expr.go", + "factory.go", + "navigable.go", ], importpath = "github.com/google/cel-go/common/ast", deps = [ + "//common:go_default_library", "//common/types:go_default_library", "//common/types/ref:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", @@ -30,7 +34,9 @@ go_test( name = "go_default_test", srcs = [ "ast_test.go", + "conversion_test.go", "expr_test.go", + "navigable_test.go", ], embed = [ ":go_default_library", @@ -49,5 +55,6 @@ go_test( "//test/proto3pb:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", + "@org_golang_google_protobuf//encoding/prototext:go_default_library", ], ) \ No newline at end of file diff --git a/common/ast/ast.go b/common/ast/ast.go index b3c150793..652272789 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -16,74 +16,264 @@ package ast import ( - "fmt" - + "github.com/google/cel-go/common" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" +) - 
structpb "google.golang.org/protobuf/types/known/structpb" +// AST contains a protobuf expression and source info along with CEL-native type and reference information. +type AST struct { + expr Expr + sourceInfo *SourceInfo + typeMap map[int64]*types.Type + refMap map[int64]*ReferenceInfo + checked bool +} - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" -) +// Expr returns the root ast.Expr value in the AST. +func (a *AST) Expr() Expr { + if a == nil { + return nilExpr + } + return a.expr +} -// CheckedAST contains a protobuf expression and source info along with CEL-native type and reference information. -type CheckedAST struct { - Expr *exprpb.Expr - SourceInfo *exprpb.SourceInfo - TypeMap map[int64]*types.Type - ReferenceMap map[int64]*ReferenceInfo -} - -// CheckedASTToCheckedExpr converts a CheckedAST to a CheckedExpr protobouf. -func CheckedASTToCheckedExpr(ast *CheckedAST) (*exprpb.CheckedExpr, error) { - refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap)) - for id, ref := range ast.ReferenceMap { - r, err := ReferenceInfoToReferenceExpr(ref) - if err != nil { - return nil, err - } - refMap[id] = r +// SourceInfo returns the source metadata associated with the parse / type-check passes. +func (a *AST) SourceInfo() *SourceInfo { + if a == nil { + return nil } - typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap)) - for id, typ := range ast.TypeMap { - t, err := types.TypeToExprType(typ) - if err != nil { - return nil, err - } - typeMap[id] = t - } - return &exprpb.CheckedExpr{ - Expr: ast.Expr, - SourceInfo: ast.SourceInfo, - ReferenceMap: refMap, - TypeMap: typeMap, - }, nil -} - -// CheckedExprToCheckedAST converts a CheckedExpr protobuf to a CheckedAST instance. 
-func CheckedExprToCheckedAST(checked *exprpb.CheckedExpr) (*CheckedAST, error) { - refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap())) - for id, ref := range checked.GetReferenceMap() { - r, err := ReferenceExprToReferenceInfo(ref) - if err != nil { - return nil, err + return a.sourceInfo +} + +// GetType returns the type for the expression at the given id, if one exists, else types.DynType. +func (a *AST) GetType(id int64) *types.Type { + if t, found := a.TypeMap()[id]; found { + return t + } + return types.DynType +} + +// TypeMap returns the map of expression ids to type-checked types. +// +// If the AST is not type-checked, the map will be empty. +func (a *AST) TypeMap() map[int64]*types.Type { + if a == nil { + return map[int64]*types.Type{} + } + return a.typeMap +} + +// GetOverloadIDs returns the set of overload function names for a given expression id. +// +// If the expression id is not a function call, or the AST is not type-checked, the result will be empty. +func (a *AST) GetOverloadIDs(id int64) []string { + if ref, found := a.ReferenceMap()[id]; found { + return ref.OverloadIDs + } + return []string{} +} + +// ReferenceMap returns the map of expression id to identifier, constant, and function references. +func (a *AST) ReferenceMap() map[int64]*ReferenceInfo { + if a == nil { + return map[int64]*ReferenceInfo{} + } + return a.refMap +} + +// IsChecked returns whether the AST is type-checked. +func (a *AST) IsChecked() bool { + return a != nil && a.checked +} + +// NewAST creates a base AST instance with an ast.Expr and ast.SourceInfo value. +func NewAST(e Expr, sourceInfo *SourceInfo) *AST { + if e == nil { + e = nilExpr + } + return &AST{ + expr: e, + sourceInfo: sourceInfo, + typeMap: make(map[int64]*types.Type), + refMap: make(map[int64]*ReferenceInfo), + checked: false, + } +} + +// NewCheckedAST wraps an parsed AST and augments it with type and reference metadata. 
+func NewCheckedAST(parsed *AST, typeMap map[int64]*types.Type, refMap map[int64]*ReferenceInfo) *AST { + return &AST{ + expr: parsed.Expr(), + sourceInfo: parsed.SourceInfo(), + typeMap: typeMap, + refMap: refMap, + checked: len(typeMap) != 0 || len(refMap) != 0, + } +} + +// NewSourceInfo creates a simple SourceInfo object from an input common.Source value. +func NewSourceInfo(src common.Source) *SourceInfo { + var lineOffsets []int32 + var desc string + if src != nil { + desc = src.Description() + lineOffsets = src.LineOffsets() + } + return &SourceInfo{ + desc: desc, + lines: lineOffsets, + offsetRanges: make(map[int64]OffsetRange), + macroCalls: make(map[int64]Expr), + } +} + +// SourceInfo records basic information about the expression as a textual input and +// as a parsed expression value. +type SourceInfo struct { + syntax string + desc string + lines []int32 + offsetRanges map[int64]OffsetRange + macroCalls map[int64]Expr +} + +// SyntaxVersion returns the syntax version associated with the text expression. +func (s *SourceInfo) SyntaxVersion() string { + if s == nil { + return "" + } + return s.syntax +} + +// Description provides information about where the expression came from. +func (s *SourceInfo) Description() string { + if s == nil { + return "" + } + return s.desc +} + +// LineOffsets returns a list of the 0-based character offsets in the input text where newlines appear. +func (s *SourceInfo) LineOffsets() []int32 { + if s == nil { + return []int32{} + } + return s.lines +} + +// MacroCalls returns a map of expression id to ast.Expr value where the id represents the expression +// node where the macro was inserted into the AST, and the ast.Expr value represents the original call +// signature which was replaced. 
+func (s *SourceInfo) MacroCalls() map[int64]Expr { + if s == nil { + return map[int64]Expr{} + } + return s.macroCalls +} + +// GetMacroCall returns the original ast.Expr value for the given expression if it was generated via +// a macro replacement. +// +// Note, parsing options must be enabled to track macro calls before this method will return a value. +func (s *SourceInfo) GetMacroCall(id int64) (Expr, bool) { + e, found := s.MacroCalls()[id] + return e, found +} + +// SetMacroCall records a macro call at a specific location. +func (s *SourceInfo) SetMacroCall(id int64, e Expr) { + if s == nil { + return + } + s.macroCalls[id] = e +} + +// OffsetRanges returns a map of expression id to OffsetRange values where the range indicates either: +// the start and end position in the input stream where the expression occurs, or the start position +// only. If the range only captures start position, the stop position of the range will be equal to +// the start. +func (s *SourceInfo) OffsetRanges() map[int64]OffsetRange { + if s == nil { + return map[int64]OffsetRange{} + } + return s.offsetRanges +} + +// GetOffsetRange retrieves an OffsetRange for the given expression id if one exists. +func (s *SourceInfo) GetOffsetRange(id int64) (OffsetRange, bool) { + if s == nil { + return OffsetRange{}, false + } + o, found := s.offsetRanges[id] + return o, found +} + +// SetOffsetRange sets the OffsetRange for the given expression id. +func (s *SourceInfo) SetOffsetRange(id int64, o OffsetRange) { + if s == nil { + return + } + s.offsetRanges[id] = o +} + +// GetStartLocation calculates the human-readable 1-based line and 0-based column of the first character +// of the expression node at the id. 
+func (s *SourceInfo) GetStartLocation(id int64) common.Location { + if o, found := s.GetOffsetRange(id); found { + line := 1 + col := int(o.Start) + for _, lineOffset := range s.LineOffsets() { + if lineOffset < o.Start { + line++ + col = int(o.Start - lineOffset) + } else { + break + } } - refMap[id] = r + return common.NewLocation(line, col) } - typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap())) - for id, typ := range checked.GetTypeMap() { - t, err := types.ExprTypeToType(typ) - if err != nil { - return nil, err + return common.NoLocation +} + +// GetStopLocation calculates the human-readable 1-based line and 0-based column of the last character for +// the expression node at the given id. +// +// If the SourceInfo was generated from a serialized protobuf representation, the stop location will +// be identical to the start location for the expression. +func (s *SourceInfo) GetStopLocation(id int64) common.Location { + if o, found := s.GetOffsetRange(id); found { + line := 1 + col := int(o.Stop) + for _, lineOffset := range s.LineOffsets() { + if lineOffset < o.Stop { + line++ + col = int(o.Stop - lineOffset) + } else { + break + } } - typeMap[id] = t + return common.NewLocation(line, col) } - return &CheckedAST{ - Expr: checked.GetExpr(), - SourceInfo: checked.GetSourceInfo(), - ReferenceMap: refMap, - TypeMap: typeMap, - }, nil + return common.NoLocation +} + +// ComputeOffset calculates the 0-based character offset from a 1-based line and 0-based column. +func (s *SourceInfo) ComputeOffset(line, col int32) int32 { + if line == 1 { + return col + } + if line < 1 || line > int32(len(s.LineOffsets())) { + return -1 + } + offset := s.LineOffsets()[line-2] + return offset + col +} + +// OffsetRange captures the start and stop positions of a section of text in the input expression. 
+type OffsetRange struct { + Start int32 + Stop int32 } // ReferenceInfo contains a CEL native representation of an identifier reference which may refer to @@ -148,79 +338,3 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool { } return true } - -// ReferenceInfoToReferenceExpr converts a ReferenceInfo instance to a protobuf Reference suitable for serialization. -func ReferenceInfoToReferenceExpr(info *ReferenceInfo) (*exprpb.Reference, error) { - c, err := ValToConstant(info.Value) - if err != nil { - return nil, err - } - return &exprpb.Reference{ - Name: info.Name, - OverloadId: info.OverloadIDs, - Value: c, - }, nil -} - -// ReferenceExprToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance. -func ReferenceExprToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) { - v, err := ConstantToVal(ref.GetValue()) - if err != nil { - return nil, err - } - return &ReferenceInfo{ - Name: ref.GetName(), - OverloadIDs: ref.GetOverloadId(), - Value: v, - }, nil -} - -// ValToConstant converts a CEL-native ref.Val to a protobuf Constant. -// -// Only simple scalar types are supported by this method. 
-func ValToConstant(v ref.Val) (*exprpb.Constant, error) { - if v == nil { - return nil, nil - } - switch v.Type() { - case types.BoolType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil - case types.BytesType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil - case types.DoubleType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil - case types.IntType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil - case types.NullType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil - case types.StringType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil - case types.UintType: - return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil - } - return nil, fmt.Errorf("unsupported constant kind: %v", v.Type()) -} - -// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val. 
-func ConstantToVal(c *exprpb.Constant) (ref.Val, error) { - if c == nil { - return nil, nil - } - switch c.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - return types.Bool(c.GetBoolValue()), nil - case *exprpb.Constant_BytesValue: - return types.Bytes(c.GetBytesValue()), nil - case *exprpb.Constant_DoubleValue: - return types.Double(c.GetDoubleValue()), nil - case *exprpb.Constant_Int64Value: - return types.Int(c.GetInt64Value()), nil - case *exprpb.Constant_NullValue: - return types.NullValue, nil - case *exprpb.Constant_StringValue: - return types.String(c.GetStringValue()), nil - case *exprpb.Constant_Uint64Value: - return types.Uint(c.GetUint64Value()), nil - } - return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind()) -} diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go index 1e0409145..eef521827 100644 --- a/common/ast/ast_test.go +++ b/common/ast/ast_test.go @@ -15,71 +15,139 @@ package ast_test import ( + "fmt" "reflect" "testing" - "time" - "google.golang.org/protobuf/proto" - - chkdecls "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) -func TestConvertAST(t *testing.T) { - goAST := &ast.CheckedAST{ - Expr: &exprpb.Expr{}, - SourceInfo: &exprpb.SourceInfo{}, - TypeMap: map[int64]*types.Type{ - 1: types.BoolType, - 2: types.DynType, - }, - ReferenceMap: map[int64]*ast.ReferenceInfo{ - 1: ast.NewFunctionReference(overloads.LogicalNot), - 2: ast.NewIdentReference("TRUE", types.True), - }, +func TestASTNilSafety(t *testing.T) { + ex, err := ast.ProtoToExpr(nil) + if err != nil { + t.Fatalf("ast.ProtoToExpr() failed: %v", err) } - - exprAST := &exprpb.CheckedExpr{ - Expr: &exprpb.Expr{}, - SourceInfo: &exprpb.SourceInfo{}, - TypeMap: 
map[int64]*exprpb.Type{ - 1: chkdecls.Bool, - 2: chkdecls.Dyn, - }, - ReferenceMap: map[int64]*exprpb.Reference{ - 1: {OverloadId: []string{overloads.LogicalNot}}, - 2: { - Name: "TRUE", - Value: &exprpb.Constant{ - ConstantKind: &exprpb.Constant_BoolValue{BoolValue: true}, - }, - }, - }, + info, err := ast.ProtoToSourceInfo(nil) + if err != nil { + t.Fatalf("ast.ProtoToSourceInfo() failed: %v", err) + } + tests := []*ast.AST{ + nil, + ast.NewAST(nil, nil), + ast.NewCheckedAST(nil, nil, nil), + ast.NewCheckedAST(ast.NewAST(nil, nil), nil, nil), + ast.NewAST(ex, info), + ast.NewCheckedAST(ast.NewAST(ex, info), map[int64]*types.Type{}, map[int64]*ast.ReferenceInfo{}), + } + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + testAST := tc + if testAST.Expr().ID() != 0 { + t.Errorf("Expr().ID() got %v, wanted 0", testAST.Expr().ID()) + } + if testAST.SourceInfo().SyntaxVersion() != "" { + t.Errorf("SourceInfo().SyntaxVersion() got %s, wanted empty string", testAST.SourceInfo().SyntaxVersion()) + } + if testAST.IsChecked() { + t.Error("IsChecked() returned true, wanted false") + } + if testAST.GetType(testAST.Expr().ID()) != types.DynType { + t.Errorf("GetType() got %v, wanted dyn", testAST.GetType(testAST.Expr().ID())) + } + if len(testAST.GetOverloadIDs(testAST.Expr().ID())) != 0 { + t.Errorf("GetOverloadIDs() got %v, wanted empty set", testAST.GetOverloadIDs(testAST.Expr().ID())) + } + }) } +} - checkedAST, err := ast.CheckedExprToCheckedAST(exprAST) - if err != nil { - t.Fatalf("CheckedExprToCheckedAST() failed: %v", err) +func TestSourceInfo(t *testing.T) { + src := common.NewStringSource("a\n? 
b\n: c", "custom description") + info := ast.NewSourceInfo(src) + if info.Description() != "custom description" { + t.Errorf("Description() got %s, wanted 'custom description'", info.Description()) + } + if len(info.LineOffsets()) != 3 { + t.Errorf("LineOffsets() got %v, wanted 3 offsets", info.LineOffsets()) + } + info.SetOffsetRange(1, ast.OffsetRange{Start: 0, Stop: 1}) // a + info.SetOffsetRange(2, ast.OffsetRange{Start: 4, Stop: 5}) // b + info.SetOffsetRange(3, ast.OffsetRange{Start: 8, Stop: 9}) // c + if !reflect.DeepEqual(info.GetStartLocation(1), common.NewLocation(1, 0)) { + t.Errorf("info.GetStartLocation(1) got %v, wanted line 1, col 0", info.GetStartLocation(1)) + } + if !reflect.DeepEqual(info.GetStopLocation(1), common.NewLocation(1, 1)) { + t.Errorf("info.GetStopLocation(1) got %v, wanted line 1, col 1", info.GetStopLocation(1)) } - if !reflect.DeepEqual(checkedAST.ReferenceMap, goAST.ReferenceMap) || - !reflect.DeepEqual(checkedAST.TypeMap, goAST.TypeMap) { - t.Errorf("conversion to AST did not produce identical results: got %v, wanted %v", checkedAST, goAST) + if !reflect.DeepEqual(info.GetStartLocation(2), common.NewLocation(2, 2)) { + t.Errorf("info.GetStartLocation(2) got %v, wanted line 2, col 2", info.GetStartLocation(2)) } - if !checkedAST.ReferenceMap[1].Equals(goAST.ReferenceMap[1]) || - !checkedAST.ReferenceMap[2].Equals(goAST.ReferenceMap[2]) { - t.Error("converted reference info values not equal") + if !reflect.DeepEqual(info.GetStopLocation(2), common.NewLocation(2, 3)) { + t.Errorf("info.GetStopLocation(2) got %v, wanted line 2, col 3", info.GetStopLocation(2)) } - checkedExpr, err := ast.CheckedASTToCheckedExpr(goAST) + if !reflect.DeepEqual(info.GetStartLocation(3), common.NewLocation(3, 2)) { + t.Errorf("info.GetStartLocation(3) got %v, wanted line 2, col 2", info.GetStartLocation(3)) + } + if !reflect.DeepEqual(info.GetStopLocation(3), common.NewLocation(3, 3)) { + t.Errorf("info.GetStopLocation(3) got %v, wanted line 2, col 3", 
info.GetStopLocation(3)) + } + if info.ComputeOffset(3, 2) != 8 { + t.Errorf("info.ComputeOffset(3, 2) got %d, wanted 8", info.ComputeOffset(3, 2)) + } +} + +func TestSourceInfoNilSafety(t *testing.T) { + info, err := ast.ProtoToSourceInfo(nil) if err != nil { - t.Fatalf("CheckedASTToCheckedExpr() failed: %v", err) + t.Fatalf("ast.ProtoToSourceInfo() failed: %v", err) } - if !proto.Equal(checkedExpr, exprAST) { - t.Errorf("conversion to protobuf did not produce identical results: got %v, wanted %v", checkedExpr, exprAST) + tests := []*ast.SourceInfo{ + nil, + info, + ast.NewSourceInfo(nil), + } + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + testInfo := tc + if testInfo.SyntaxVersion() != "" { + t.Errorf("SyntaxVersion() got %s, wanted empty string", testInfo.SyntaxVersion()) + } + if testInfo.Description() != "" { + t.Errorf("Description() got %s, wanted empty string", testInfo.Description()) + } + if len(testInfo.LineOffsets()) != 0 { + t.Errorf("LineOffsets() got %v, wanted empty list", testInfo.LineOffsets()) + } + if len(testInfo.MacroCalls()) != 0 { + t.Errorf("MacroCalls() got %v, wanted empty map", testInfo.MacroCalls()) + } + if call, found := testInfo.GetMacroCall(0); found { + t.Errorf("GetMacroCall(0) got %v, wanted not found", call) + } + if r, found := testInfo.GetOffsetRange(0); found { + t.Errorf("GetOffsetRange(0) got %v, wanted not found", r) + } + if loc := testInfo.GetStartLocation(0); loc != common.NoLocation { + t.Errorf("GetStartLocation(0) got %v, wanted no location", loc) + } + if loc := testInfo.GetStopLocation(0); loc != common.NoLocation { + t.Errorf("GetStopLocation(0) got %v, wanted no location", loc) + } + if off := testInfo.ComputeOffset(1, 0); off != 0 { + t.Errorf("ComputeOffset(1, 0) got %d, wanted 0", off) + } + if off := testInfo.ComputeOffset(-2, 0); off != -1 { + t.Errorf("ComputeOffset(-2, 0) got %d, wanted -1", off) + } + if off := testInfo.ComputeOffset(2, 0); off != -1 { + 
t.Errorf("ComputeOffset(2, 0) got %d, wanted -1", off) + } + }) } } @@ -173,57 +241,3 @@ func TestReferenceInfoAddOverload(t *testing.T) { t.Error("repeated AddOverload() did not produce equal references") } } - -func TestReferenceInfoToReferenceExprError(t *testing.T) { - out, err := ast.ReferenceInfoToReferenceExpr( - ast.NewIdentReference("SECOND", types.Duration{Duration: time.Duration(1) * time.Second})) - if err == nil { - t.Errorf("ReferenceInfoToReferenceExpr() got %v, wanted error", out) - } -} - -func TestReferenceExprToReferenceInfoError(t *testing.T) { - out, err := ast.ReferenceExprToReferenceInfo(&exprpb.Reference{Value: &exprpb.Constant{}}) - if err == nil { - t.Errorf("ReferenceExprToReferenceInfo() got %v, wanted error", out) - } -} - -func TestConvertVal(t *testing.T) { - tests := []ref.Val{ - types.True, - types.Bytes("bytes"), - types.Double(3.2), - types.Int(-1), - types.NullValue, - types.String("string"), - types.Uint(27), - } - for _, tst := range tests { - c, err := ast.ValToConstant(tst) - if err != nil { - t.Errorf("ValToConstant(%v) failed: %v", tst, err) - } - v, err := ast.ConstantToVal(c) - if err != nil { - t.Errorf("ValToConstant(%v) failed: %v", c, err) - } - if tst.Equal(v) != types.True { - t.Errorf("roundtrip from %v to %v and back did not produce equal results, got %v, wanted %v", tst, c, v, tst) - } - } -} - -func TestValToConstantError(t *testing.T) { - out, err := ast.ValToConstant(types.Duration{Duration: time.Duration(10)}) - if err == nil { - t.Errorf("ValToConstant() got %v, wanted error", out) - } -} - -func TestConstantToValError(t *testing.T) { - out, err := ast.ConstantToVal(&exprpb.Constant{}) - if err == nil { - t.Errorf("ConstantToVal() got %v, wanted error", out) - } -} diff --git a/common/ast/conversion.go b/common/ast/conversion.go new file mode 100644 index 000000000..d8516b187 --- /dev/null +++ b/common/ast/conversion.go @@ -0,0 +1,615 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache 
License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import ( + "fmt" + + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + + structpb "google.golang.org/protobuf/types/known/structpb" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +// ToProto converts an AST to a CheckedExpr protobouf. +func ToProto(ast *AST) (*exprpb.CheckedExpr, error) { + refMap := make(map[int64]*exprpb.Reference, len(ast.ReferenceMap())) + for id, ref := range ast.ReferenceMap() { + r, err := ReferenceInfoToProto(ref) + if err != nil { + return nil, err + } + refMap[id] = r + } + typeMap := make(map[int64]*exprpb.Type, len(ast.TypeMap())) + for id, typ := range ast.TypeMap() { + t, err := types.TypeToExprType(typ) + if err != nil { + return nil, err + } + typeMap[id] = t + } + e, err := ExprToProto(ast.Expr()) + if err != nil { + return nil, err + } + info, err := SourceInfoToProto(ast.SourceInfo()) + if err != nil { + return nil, err + } + return &exprpb.CheckedExpr{ + Expr: e, + SourceInfo: info, + ReferenceMap: refMap, + TypeMap: typeMap, + }, nil +} + +// ToAST converts a CheckedExpr protobuf to an AST instance. 
+func ToAST(checked *exprpb.CheckedExpr) (*AST, error) { + refMap := make(map[int64]*ReferenceInfo, len(checked.GetReferenceMap())) + for id, ref := range checked.GetReferenceMap() { + r, err := ProtoToReferenceInfo(ref) + if err != nil { + return nil, err + } + refMap[id] = r + } + typeMap := make(map[int64]*types.Type, len(checked.GetTypeMap())) + for id, typ := range checked.GetTypeMap() { + t, err := types.ExprTypeToType(typ) + if err != nil { + return nil, err + } + typeMap[id] = t + } + info, err := ProtoToSourceInfo(checked.GetSourceInfo()) + if err != nil { + return nil, err + } + root, err := ProtoToExpr(checked.GetExpr()) + if err != nil { + return nil, err + } + ast := NewCheckedAST(NewAST(root, info), typeMap, refMap) + return ast, nil +} + +// ProtoToExpr converts a protobuf Expr value to an ast.Expr value. +func ProtoToExpr(e *exprpb.Expr) (Expr, error) { + factory := NewExprFactory() + return exprInternal(factory, e) +} + +func exprInternal(factory ExprFactory, e *exprpb.Expr) (Expr, error) { + id := e.GetId() + switch e.GetExprKind().(type) { + case *exprpb.Expr_CallExpr: + return exprCall(factory, id, e.GetCallExpr()) + case *exprpb.Expr_ComprehensionExpr: + return exprComprehension(factory, id, e.GetComprehensionExpr()) + case *exprpb.Expr_ConstExpr: + return exprLiteral(factory, id, e.GetConstExpr()) + case *exprpb.Expr_IdentExpr: + return exprIdent(factory, id, e.GetIdentExpr()) + case *exprpb.Expr_ListExpr: + return exprList(factory, id, e.GetListExpr()) + case *exprpb.Expr_SelectExpr: + return exprSelect(factory, id, e.GetSelectExpr()) + case *exprpb.Expr_StructExpr: + s := e.GetStructExpr() + if s.GetMessageName() != "" { + return exprStruct(factory, id, s) + } + return exprMap(factory, id, s) + } + return factory.NewUnspecifiedExpr(id), nil +} + +func exprCall(factory ExprFactory, id int64, call *exprpb.Expr_Call) (Expr, error) { + var err error + args := make([]Expr, len(call.GetArgs())) + for i, a := range call.GetArgs() { + args[i], err = 
exprInternal(factory, a) + if err != nil { + return nil, err + } + } + if call.GetTarget() == nil { + return factory.NewCall(id, call.GetFunction(), args...), nil + } + + target, err := exprInternal(factory, call.GetTarget()) + if err != nil { + return nil, err + } + return factory.NewMemberCall(id, call.GetFunction(), target, args...), nil +} + +func exprComprehension(factory ExprFactory, id int64, comp *exprpb.Expr_Comprehension) (Expr, error) { + iterRange, err := exprInternal(factory, comp.GetIterRange()) + if err != nil { + return nil, err + } + accuInit, err := exprInternal(factory, comp.GetAccuInit()) + if err != nil { + return nil, err + } + loopCond, err := exprInternal(factory, comp.GetLoopCondition()) + if err != nil { + return nil, err + } + loopStep, err := exprInternal(factory, comp.GetLoopStep()) + if err != nil { + return nil, err + } + result, err := exprInternal(factory, comp.GetResult()) + if err != nil { + return nil, err + } + return factory.NewComprehension(id, + iterRange, + comp.GetIterVar(), + comp.GetAccuVar(), + accuInit, + loopCond, + loopStep, + result), nil +} + +func exprLiteral(factory ExprFactory, id int64, c *exprpb.Constant) (Expr, error) { + val, err := ConstantToVal(c) + if err != nil { + return nil, err + } + return factory.NewLiteral(id, val), nil +} + +func exprIdent(factory ExprFactory, id int64, i *exprpb.Expr_Ident) (Expr, error) { + return factory.NewIdent(id, i.GetName()), nil +} + +func exprList(factory ExprFactory, id int64, l *exprpb.Expr_CreateList) (Expr, error) { + elems := make([]Expr, len(l.GetElements())) + for i, e := range l.GetElements() { + elem, err := exprInternal(factory, e) + if err != nil { + return nil, err + } + elems[i] = elem + } + return factory.NewList(id, elems, l.GetOptionalIndices()), nil +} + +func exprMap(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) { + entries := make([]EntryExpr, len(s.GetEntries())) + var err error + for i, entry := range s.GetEntries() { + 
entries[i], err = exprMapEntry(factory, entry.GetId(), entry) + if err != nil { + return nil, err + } + } + return factory.NewMap(id, entries), nil +} + +func exprMapEntry(factory ExprFactory, id int64, e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) { + k, err := exprInternal(factory, e.GetMapKey()) + if err != nil { + return nil, err + } + v, err := exprInternal(factory, e.GetValue()) + if err != nil { + return nil, err + } + return factory.NewMapEntry(id, k, v, e.GetOptionalEntry()), nil +} + +func exprSelect(factory ExprFactory, id int64, s *exprpb.Expr_Select) (Expr, error) { + op, err := exprInternal(factory, s.GetOperand()) + if err != nil { + return nil, err + } + if s.GetTestOnly() { + return factory.NewPresenceTest(id, op, s.GetField()), nil + } + return factory.NewSelect(id, op, s.GetField()), nil +} + +func exprStruct(factory ExprFactory, id int64, s *exprpb.Expr_CreateStruct) (Expr, error) { + fields := make([]EntryExpr, len(s.GetEntries())) + var err error + for i, field := range s.GetEntries() { + fields[i], err = exprStructField(factory, field.GetId(), field) + if err != nil { + return nil, err + } + } + return factory.NewStruct(id, s.GetMessageName(), fields), nil +} + +func exprStructField(factory ExprFactory, id int64, f *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) { + v, err := exprInternal(factory, f.GetValue()) + if err != nil { + return nil, err + } + return factory.NewStructField(id, f.GetFieldKey(), v, f.GetOptionalEntry()), nil +} + +// ExprToProto serializes an ast.Expr value to a protobuf Expr representation. 
+func ExprToProto(e Expr) (*exprpb.Expr, error) { + if e == nil { + return &exprpb.Expr{}, nil + } + switch e.Kind() { + case CallKind: + return protoCall(e.ID(), e.AsCall()) + case ComprehensionKind: + return protoComprehension(e.ID(), e.AsComprehension()) + case IdentKind: + return protoIdent(e.ID(), e.AsIdent()) + case ListKind: + return protoList(e.ID(), e.AsList()) + case LiteralKind: + return protoLiteral(e.ID(), e.AsLiteral()) + case MapKind: + return protoMap(e.ID(), e.AsMap()) + case SelectKind: + return protoSelect(e.ID(), e.AsSelect()) + case StructKind: + return protoStruct(e.ID(), e.AsStruct()) + case UnspecifiedExprKind: + return &exprpb.Expr{}, nil + } + return nil, fmt.Errorf("unsupported expr kind: %v", e) +} + +// EntryExprToProto converts an ast.EntryExpr to a protobuf CreateStruct entry +func EntryExprToProto(e EntryExpr) (*exprpb.Expr_CreateStruct_Entry, error) { + switch e.Kind() { + case MapEntryKind: + return protoMapEntry(e.ID(), e.AsMapEntry()) + case StructFieldKind: + return protoStructField(e.ID(), e.AsStructField()) + case UnspecifiedEntryExprKind: + return &exprpb.Expr_CreateStruct_Entry{}, nil + } + return nil, fmt.Errorf("unsupported expr entry kind: %v", e) +} + +func protoCall(id int64, call CallExpr) (*exprpb.Expr, error) { + var err error + var target *exprpb.Expr + if call.IsMemberFunction() { + target, err = ExprToProto(call.Target()) + if err != nil { + return nil, err + } + } + callArgs := call.Args() + args := make([]*exprpb.Expr, len(callArgs)) + for i, a := range callArgs { + args[i], err = ExprToProto(a) + if err != nil { + return nil, err + } + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_CallExpr{ + CallExpr: &exprpb.Expr_Call{ + Function: call.FunctionName(), + Target: target, + Args: args, + }, + }, + }, nil +} + +func protoComprehension(id int64, comp ComprehensionExpr) (*exprpb.Expr, error) { + iterRange, err := ExprToProto(comp.IterRange()) + if err != nil { + return nil, err + } + accuInit, err := 
ExprToProto(comp.AccuInit()) + if err != nil { + return nil, err + } + loopCond, err := ExprToProto(comp.LoopCondition()) + if err != nil { + return nil, err + } + loopStep, err := ExprToProto(comp.LoopStep()) + if err != nil { + return nil, err + } + result, err := ExprToProto(comp.Result()) + if err != nil { + return nil, err + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_ComprehensionExpr{ + ComprehensionExpr: &exprpb.Expr_Comprehension{ + IterVar: comp.IterVar(), + IterRange: iterRange, + AccuVar: comp.AccuVar(), + AccuInit: accuInit, + LoopCondition: loopCond, + LoopStep: loopStep, + Result: result, + }, + }, + }, nil +} + +func protoIdent(id int64, name string) (*exprpb.Expr, error) { + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_IdentExpr{ + IdentExpr: &exprpb.Expr_Ident{ + Name: name, + }, + }, + }, nil +} + +func protoList(id int64, list ListExpr) (*exprpb.Expr, error) { + var err error + elems := make([]*exprpb.Expr, list.Size()) + for i, e := range list.Elements() { + elems[i], err = ExprToProto(e) + if err != nil { + return nil, err + } + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_ListExpr{ + ListExpr: &exprpb.Expr_CreateList{ + Elements: elems, + OptionalIndices: list.OptionalIndices(), + }, + }, + }, nil +} + +func protoLiteral(id int64, val ref.Val) (*exprpb.Expr, error) { + c, err := ValToConstant(val) + if err != nil { + return nil, err + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_ConstExpr{ + ConstExpr: c, + }, + }, nil +} + +func protoMap(id int64, m MapExpr) (*exprpb.Expr, error) { + entries := make([]*exprpb.Expr_CreateStruct_Entry, len(m.Entries())) + var err error + for i, e := range m.Entries() { + entries[i], err = EntryExprToProto(e) + if err != nil { + return nil, err + } + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_StructExpr{ + StructExpr: &exprpb.Expr_CreateStruct{ + Entries: entries, + }, + }, + }, nil +} + +func protoMapEntry(id int64, e MapEntry) 
(*exprpb.Expr_CreateStruct_Entry, error) { + k, err := ExprToProto(e.Key()) + if err != nil { + return nil, err + } + v, err := ExprToProto(e.Value()) + if err != nil { + return nil, err + } + return &exprpb.Expr_CreateStruct_Entry{ + Id: id, + KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{ + MapKey: k, + }, + Value: v, + OptionalEntry: e.IsOptional(), + }, nil +} + +func protoSelect(id int64, s SelectExpr) (*exprpb.Expr, error) { + op, err := ExprToProto(s.Operand()) + if err != nil { + return nil, err + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_SelectExpr{ + SelectExpr: &exprpb.Expr_Select{ + Operand: op, + Field: s.FieldName(), + TestOnly: s.IsTestOnly(), + }, + }, + }, nil +} + +func protoStruct(id int64, s StructExpr) (*exprpb.Expr, error) { + entries := make([]*exprpb.Expr_CreateStruct_Entry, len(s.Fields())) + var err error + for i, e := range s.Fields() { + entries[i], err = EntryExprToProto(e) + if err != nil { + return nil, err + } + } + return &exprpb.Expr{ + Id: id, + ExprKind: &exprpb.Expr_StructExpr{ + StructExpr: &exprpb.Expr_CreateStruct{ + MessageName: s.TypeName(), + Entries: entries, + }, + }, + }, nil +} + +func protoStructField(id int64, f StructField) (*exprpb.Expr_CreateStruct_Entry, error) { + v, err := ExprToProto(f.Value()) + if err != nil { + return nil, err + } + return &exprpb.Expr_CreateStruct_Entry{ + Id: id, + KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{ + FieldKey: f.Name(), + }, + Value: v, + OptionalEntry: f.IsOptional(), + }, nil +} + +// SourceInfoToProto serializes an ast.SourceInfo value to a protobuf SourceInfo object. 
+func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) { + if info == nil { + return &exprpb.SourceInfo{}, nil + } + sourceInfo := &exprpb.SourceInfo{ + SyntaxVersion: info.SyntaxVersion(), + Location: info.Description(), + LineOffsets: info.LineOffsets(), + Positions: make(map[int64]int32, len(info.OffsetRanges())), + MacroCalls: make(map[int64]*exprpb.Expr, len(info.MacroCalls())), + } + for id, offset := range info.OffsetRanges() { + sourceInfo.Positions[id] = offset.Start + } + for id, e := range info.MacroCalls() { + call, err := ExprToProto(e) + if err != nil { + return nil, err + } + sourceInfo.MacroCalls[id] = call + } + return sourceInfo, nil +} + +// ProtoToSourceInfo deserializes +func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) { + sourceInfo := &SourceInfo{ + syntax: info.GetSyntaxVersion(), + desc: info.GetLocation(), + lines: info.GetLineOffsets(), + offsetRanges: make(map[int64]OffsetRange, len(info.GetPositions())), + macroCalls: make(map[int64]Expr, len(info.GetMacroCalls())), + } + for id, offset := range info.GetPositions() { + sourceInfo.SetOffsetRange(id, OffsetRange{Start: offset, Stop: offset}) + } + for id, e := range info.GetMacroCalls() { + call, err := ProtoToExpr(e) + if err != nil { + return nil, err + } + sourceInfo.SetMacroCall(id, call) + } + return sourceInfo, nil +} + +// ReferenceInfoToProto converts a ReferenceInfo instance to a protobuf Reference suitable for serialization. +func ReferenceInfoToProto(info *ReferenceInfo) (*exprpb.Reference, error) { + c, err := ValToConstant(info.Value) + if err != nil { + return nil, err + } + return &exprpb.Reference{ + Name: info.Name, + OverloadId: info.OverloadIDs, + Value: c, + }, nil +} + +// ProtoToReferenceInfo converts a protobuf Reference into a CEL-native ReferenceInfo instance. 
+func ProtoToReferenceInfo(ref *exprpb.Reference) (*ReferenceInfo, error) { + v, err := ConstantToVal(ref.GetValue()) + if err != nil { + return nil, err + } + return &ReferenceInfo{ + Name: ref.GetName(), + OverloadIDs: ref.GetOverloadId(), + Value: v, + }, nil +} + +// ValToConstant converts a CEL-native ref.Val to a protobuf Constant. +// +// Only simple scalar types are supported by this method. +func ValToConstant(v ref.Val) (*exprpb.Constant, error) { + if v == nil { + return nil, nil + } + switch v.Type() { + case types.BoolType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: v.Value().(bool)}}, nil + case types.BytesType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: v.Value().([]byte)}}, nil + case types.DoubleType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: v.Value().(float64)}}, nil + case types.IntType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: v.Value().(int64)}}, nil + case types.NullType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: structpb.NullValue_NULL_VALUE}}, nil + case types.StringType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: v.Value().(string)}}, nil + case types.UintType: + return &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: v.Value().(uint64)}}, nil + } + return nil, fmt.Errorf("unsupported constant kind: %v", v.Type()) +} + +// ConstantToVal converts a protobuf Constant to a CEL-native ref.Val. 
+func ConstantToVal(c *exprpb.Constant) (ref.Val, error) { + if c == nil { + return nil, nil + } + switch c.GetConstantKind().(type) { + case *exprpb.Constant_BoolValue: + return types.Bool(c.GetBoolValue()), nil + case *exprpb.Constant_BytesValue: + return types.Bytes(c.GetBytesValue()), nil + case *exprpb.Constant_DoubleValue: + return types.Double(c.GetDoubleValue()), nil + case *exprpb.Constant_Int64Value: + return types.Int(c.GetInt64Value()), nil + case *exprpb.Constant_NullValue: + return types.NullValue, nil + case *exprpb.Constant_StringValue: + return types.String(c.GetStringValue()), nil + case *exprpb.Constant_Uint64Value: + return types.Uint(c.GetUint64Value()), nil + } + return nil, fmt.Errorf("unsupported constant kind: %v", c.GetConstantKind()) +} diff --git a/common/ast/conversion_test.go b/common/ast/conversion_test.go new file mode 100644 index 000000000..2ed5e096b --- /dev/null +++ b/common/ast/conversion_test.go @@ -0,0 +1,298 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ast_test + +import ( + "fmt" + "reflect" + "testing" + "time" + + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + + chkdecls "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/parser" + + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +func TestConvertAST(t *testing.T) { + tests := []struct { + goAST *ast.AST + pbAST *exprpb.CheckedExpr + }{ + { + goAST: ast.NewCheckedAST(ast.NewAST(nil, nil), + map[int64]*types.Type{ + 1: types.BoolType, + 2: types.DynType, + }, + map[int64]*ast.ReferenceInfo{ + 1: ast.NewFunctionReference(overloads.LogicalNot), + 2: ast.NewIdentReference("TRUE", types.True), + }, + ), + pbAST: &exprpb.CheckedExpr{ + Expr: &exprpb.Expr{}, + SourceInfo: &exprpb.SourceInfo{}, + TypeMap: map[int64]*exprpb.Type{ + 1: chkdecls.Bool, + 2: chkdecls.Dyn, + }, + ReferenceMap: map[int64]*exprpb.Reference{ + 1: {OverloadId: []string{overloads.LogicalNot}}, + 2: { + Name: "TRUE", + Value: &exprpb.Constant{ + ConstantKind: &exprpb.Constant_BoolValue{BoolValue: true}, + }, + }, + }, + }, + }, + } + + for i, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + goAST := tc.goAST + pbAST := tc.pbAST + checkedAST, err := ast.ToAST(pbAST) + if err != nil { + t.Fatalf("ProtoToAST() failed: %v", err) + } + if !reflect.DeepEqual(checkedAST.ReferenceMap(), goAST.ReferenceMap()) || + !reflect.DeepEqual(checkedAST.TypeMap(), goAST.TypeMap()) { + t.Errorf("conversion to AST did not produce identical results: got %v, wanted %v", checkedAST, goAST) + } + if !checkedAST.ReferenceMap()[1].Equals(goAST.ReferenceMap()[1]) || + !checkedAST.ReferenceMap()[2].Equals(goAST.ReferenceMap()[2]) { + t.Error("converted reference info values not equal") + } + 
checkedExpr, err := ast.ToProto(goAST) + if err != nil { + t.Fatalf("ASTToProto() failed: %v", err) + } + if !proto.Equal(checkedExpr, pbAST) { + t.Errorf("conversion to protobuf did not produce identical results: got %v, wanted %v", checkedExpr, pbAST) + } + }) + } +} + +func TestConvertExpr(t *testing.T) { + fac := ast.NewExprFactory() + tests := []struct { + expr string + wantExpr ast.Expr + macroCalls map[int64]ast.Expr + }{ + { + expr: `true`, + wantExpr: fac.NewLiteral(1, types.True), + }, + { + expr: `a`, + wantExpr: fac.NewIdent(1, "a"), + }, + { + expr: `a.b`, + wantExpr: fac.NewSelect(2, fac.NewIdent(1, "a"), "b"), + }, + { + expr: `has(a.b)`, + wantExpr: fac.NewPresenceTest(4, fac.NewIdent(2, "a"), "b"), + macroCalls: map[int64]ast.Expr{ + 4: fac.NewCall(0, "has", fac.NewSelect(3, fac.NewIdent(2, "a"), "b")), + }, + }, + { + expr: `!a`, + wantExpr: fac.NewCall(1, "!_", fac.NewIdent(2, "a")), + }, + { + expr: `a.size()`, + wantExpr: fac.NewMemberCall(2, "size", fac.NewIdent(1, "a")), + }, + { + expr: `[a]`, + wantExpr: fac.NewList(1, []ast.Expr{fac.NewIdent(2, "a")}, []int32{}), + }, + { + expr: `[?a]`, + wantExpr: fac.NewList(1, []ast.Expr{fac.NewIdent(2, "a")}, []int32{0}), + }, + { + expr: `{'string': 42}`, + wantExpr: fac.NewMap(1, []ast.EntryExpr{ + fac.NewMapEntry(2, + fac.NewLiteral(3, types.String("string")), + fac.NewLiteral(4, types.Int(42)), + false), + }), + }, + { + expr: `{?'string': a.?b}`, + wantExpr: fac.NewMap(1, []ast.EntryExpr{ + fac.NewMapEntry(2, + fac.NewLiteral(3, types.String("string")), + fac.NewCall(6, "_?._", fac.NewIdent(4, "a"), fac.NewLiteral(5, types.String("b"))), + true), + }), + }, + { + expr: `custom.StructType{uint_field: 42u}`, + wantExpr: fac.NewStruct(1, + "custom.StructType", + []ast.EntryExpr{ + fac.NewStructField(2, + "uint_field", + fac.NewLiteral(3, types.Uint(42)), + false), + }), + }, + { + expr: `[].exists(i, i)`, + wantExpr: fac.NewComprehension(12, + fac.NewList(1, []ast.Expr{}, []int32{}), + "i", + 
"__result__", + fac.NewLiteral(5, types.False), + fac.NewCall(8, "@not_strictly_false", fac.NewCall(7, "!_", fac.NewAccuIdent(6))), + fac.NewCall(10, "_||_", fac.NewAccuIdent(9), fac.NewIdent(4, "i")), + fac.NewAccuIdent(11), + ), + macroCalls: map[int64]ast.Expr{ + 12: fac.NewMemberCall(0, "exists", + fac.NewList(1, []ast.Expr{}, []int32{}), + fac.NewIdent(3, "i"), + fac.NewIdent(4, "i"), + ), + }, + }, + } + p, err := parser.NewParser( + parser.Macros(parser.AllMacros...), + parser.EnableOptionalSyntax(true), + parser.PopulateMacroCalls(true)) + if err != nil { + t.Fatalf("parser.NewParser() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + parsed, errs := p.Parse(common.NewTextSource(tc.expr)) + if len(errs.GetErrors()) != 0 { + t.Fatalf("Parse() failed: %s", errs.ToDisplayString()) + } + wantPBExpr, err := ast.ExprToProto(tc.wantExpr) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + if !proto.Equal(parsed.GetExpr(), wantPBExpr) { + t.Errorf("got %v\n, wanted %v", + prototext.Format(parsed.GetExpr()), prototext.Format(wantPBExpr)) + } + gotExpr, err := ast.ProtoToExpr(parsed.GetExpr()) + if err != nil { + t.Fatalf("ast.ProtoToExpr() failed: %v", err) + } + if !reflect.DeepEqual(gotExpr, tc.wantExpr) { + t.Errorf("got %v, wanted %v", gotExpr, tc.wantExpr) + } + info, err := ast.ProtoToSourceInfo(parsed.GetSourceInfo()) + if err != nil { + t.Fatalf("ast.ProtoToSourceInfo() failed: %v", err) + } + for id, wantCall := range tc.macroCalls { + call, found := info.GetMacroCall(id) + if !found { + t.Fatalf("GetMacroCall(%d) returned not found", id) + } + if !reflect.DeepEqual(call, wantCall) { + t.Errorf("macro call got %v, wanted %v", call, wantCall) + } + } + pbInfo, err := ast.SourceInfoToProto(info) + if err != nil { + t.Fatalf("ast.SourceInfoToProto() failed: %v", err) + } + if !proto.Equal(parsed.GetSourceInfo(), pbInfo) { + t.Errorf("source info roundtrip got %v, wanted %v", + 
prototext.Format(pbInfo), prototext.Format(parsed.GetSourceInfo())) + } + }) + } +} + +func TestReferenceInfoToProtoError(t *testing.T) { + out, err := ast.ReferenceInfoToProto( + ast.NewIdentReference("SECOND", types.Duration{Duration: time.Duration(1) * time.Second})) + if err == nil { + t.Errorf("ReferenceInfoToProto() got %v, wanted error", out) + } +} + +func TestProtoToReferenceInfoError(t *testing.T) { + out, err := ast.ProtoToReferenceInfo(&exprpb.Reference{Value: &exprpb.Constant{}}) + if err == nil { + t.Errorf("ProtoToReferenceInfo() got %v, wanted error", out) + } +} + +func TestConvertVal(t *testing.T) { + tests := []ref.Val{ + types.True, + types.Bytes("bytes"), + types.Double(3.2), + types.Int(-1), + types.NullValue, + types.String("string"), + types.Uint(27), + } + for _, tst := range tests { + c, err := ast.ValToConstant(tst) + if err != nil { + t.Errorf("ValToConstant(%v) failed: %v", tst, err) + } + v, err := ast.ConstantToVal(c) + if err != nil { + t.Errorf("ValToConstant(%v) failed: %v", c, err) + } + if tst.Equal(v) != types.True { + t.Errorf("roundtrip from %v to %v and back did not produce equal results, got %v, wanted %v", tst, c, v, tst) + } + } +} + +func TestValToConstantError(t *testing.T) { + out, err := ast.ValToConstant(types.Duration{Duration: time.Duration(10)}) + if err == nil { + t.Errorf("ValToConstant() got %v, wanted error", out) + } +} + +func TestConstantToValError(t *testing.T) { + out, err := ast.ConstantToVal(&exprpb.Constant{}) + if err == nil { + t.Errorf("ConstantToVal() got %v, wanted error", out) + } +} diff --git a/common/ast/expr.go b/common/ast/expr.go index 489c177b3..1fad19fb1 100644 --- a/common/ast/expr.go +++ b/common/ast/expr.go @@ -15,168 +15,61 @@ package ast import ( - "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // ExprKind represents the expression node kind. 
type ExprKind int const ( - // UnspecifiedKind represents an unset expression with no specified properties. - UnspecifiedKind ExprKind = iota + // UnspecifiedExprKind represents an unset expression with no specified properties. + UnspecifiedExprKind ExprKind = iota - // LiteralKind represents a primitive scalar literal. - LiteralKind + // CallKind represents a function call. + CallKind + + // ComprehensionKind represents a comprehension expression generated by a macro. + ComprehensionKind // IdentKind represents a simple variable, constant, or type identifier. IdentKind - // SelectKind represents a field selection expression. - SelectKind - - // CallKind represents a function call. - CallKind - // ListKind represents a list literal expression. ListKind + // LiteralKind represents a primitive scalar literal. + LiteralKind + // MapKind represents a map literal expression. MapKind + // SelectKind represents a field selection expression. + SelectKind + // StructKind represents a struct literal expression. StructKind - - // ComprehensionKind represents a comprehension expression generated by a macro. - ComprehensionKind ) -// NavigateCheckedAST converts a CheckedAST to a NavigableExpr -func NavigateCheckedAST(ast *CheckedAST) NavigableExpr { - return newNavigableExpr(nil, ast.Expr, ast.TypeMap) -} - -// ExprMatcher takes a NavigableExpr in and indicates whether the value is a match. -// -// This function type should be use with the `Match` and `MatchList` calls. -type ExprMatcher func(NavigableExpr) bool - -// ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr -// is comprised of all constant values, such as a simple literal or even list and map literal. -func ConstantValueMatcher() ExprMatcher { - return matchIsConstantValue -} - -// KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches -// the specified `kind`. 
-func KindMatcher(kind ExprKind) ExprMatcher { - return func(e NavigableExpr) bool { - return e.Kind() == kind - } -} - -// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose -// function name is equal to `funcName` regardless of whether it is a member or a global function. -func FunctionMatcher(funcName string) ExprMatcher { - return func(e NavigableExpr) bool { - if e.Kind() != CallKind { - return false - } - return e.AsCall().FunctionName() == funcName - } -} - -// AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list. -// -// Such a result would work well with subsequent MatchList calls. -func AllMatcher() ExprMatcher { - return func(NavigableExpr) bool { - return true - } -} - -// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values of the -// descendants which match. -func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr { - return matchListInternal([]NavigableExpr{expr}, matcher, true) -} - -// MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a -// subset of NavigableExpr values which match. -func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr { - visit := make([]NavigableExpr, len(exprs)) - copy(visit, exprs) - return matchListInternal(visit, matcher, false) -} - -func matchListInternal(visit []NavigableExpr, matcher ExprMatcher, visitDescendants bool) []NavigableExpr { - var matched []NavigableExpr - for len(visit) != 0 { - e := visit[0] - if matcher(e) { - matched = append(matched, e) - } - if visitDescendants { - visit = append(visit[1:], e.Children()...) 
- } else { - visit = visit[1:] - } - } - return matched -} - -func matchIsConstantValue(e NavigableExpr) bool { - if e.Kind() == LiteralKind { - return true - } - if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind { - for _, child := range e.Children() { - if !matchIsConstantValue(child) { - return false - } - } - return true - } - return false -} - -// NavigableExpr represents the base navigable expression value. +// Expr represents the base expression node in a CEL abstract syntax tree. // -// Depending on the `Kind()` value, the NavigableExpr may be converted to a concrete expression types +// Depending on the `Kind()` value, the Expr may be converted to a concrete expression types // as indicated by the `As` methods. -// -// NavigableExpr values and their concrete expression types should be nil-safe. Conversion of an expr -// to the wrong kind should produce a nil value. -type NavigableExpr interface { +type Expr interface { // ID of the expression as it appears in the AST ID() int64 // Kind of the expression node. See ExprKind for the valid enum values. Kind() ExprKind - // Type of the expression node. - Type() *types.Type - - // Parent returns the parent expression node, if one exists. - Parent() (NavigableExpr, bool) - - // Children returns a list of child expression nodes. - Children() []NavigableExpr - - // ToExpr adapts this NavigableExpr to a protobuf representation. - ToExpr() *exprpb.Expr - - // AsCall adapts the expr into a NavigableCallExpr + // AsCall adapts the expr into a CallExpr // // The Kind() must be equal to a CallKind for the conversion to be well-defined. - AsCall() NavigableCallExpr + AsCall() CallExpr - // AsComprehension adapts the expr into a NavigableComprehensionExpr. + // AsComprehension adapts the expr into a ComprehensionExpr. // // The Kind() must be equal to a ComprehensionKind for the conversion to be well-defined. 
- AsComprehension() NavigableComprehensionExpr + AsComprehension() ComprehensionExpr // AsIdent adapts the expr into an identifier string. // @@ -188,32 +81,85 @@ type NavigableExpr interface { // The Kind() must be equal to a LiteralKind for the conversion to be well-defined. AsLiteral() ref.Val - // AsList adapts the expr into a NavigableListExpr. + // AsList adapts the expr into a ListExpr. // // The Kind() must be equal to a ListKind for the conversion to be well-defined. - AsList() NavigableListExpr + AsList() ListExpr - // AsMap adapts the expr into a NavigableMapExpr. + // AsMap adapts the expr into a MapExpr. // // The Kind() must be equal to a MapKind for the conversion to be well-defined. - AsMap() NavigableMapExpr + AsMap() MapExpr - // AsSelect adapts the expr into a NavigableSelectExpr. + // AsSelect adapts the expr into a SelectExpr. // // The Kind() must be equal to a SelectKind for the conversion to be well-defined. - AsSelect() NavigableSelectExpr + AsSelect() SelectExpr - // AsStruct adapts the expr into a NavigableStructExpr. + // AsStruct adapts the expr into a StructExpr. // // The Kind() must be equal to a StructKind for the conversion to be well-defined. - AsStruct() NavigableStructExpr + AsStruct() StructExpr - // marker interface method - isNavigable() + // RenumberIDs performs an in-place update of the expression and all of its descendents numeric ids. + RenumberIDs(IDGenerator) + + // SetKindCase replaces the contents of the current expression with the contents of the other. + // + // The SetKindCase takes ownership of any expression instances references within the input Expr. + // A shallow copy is made of the Expr value itself, but not a deep one. + // + // This method should only be used during AST rewrites using temporary Expr values. + SetKindCase(Expr) + + // isExpr is a marker interface. + isExpr() +} + +// EntryExprKind represents the possible EntryExpr kinds. 
+type EntryExprKind int + +const ( + // UnspecifiedEntryExprKind indicates that the entry expr is not set. + UnspecifiedEntryExprKind EntryExprKind = iota + + // MapEntryKind indicates that the entry is a MapEntry type with key and value expressions. + MapEntryKind + + // StructFieldKind indicates that the entry is a StructField with a field name and initializer + // expression. + StructFieldKind +) + +// EntryExpr represents the base entry expression in a CEL map or struct literal. +type EntryExpr interface { + // ID of the entry as it appears in the AST. + ID() int64 + + // Kind of the entry expression node. See EntryExprKind for valid enum values. + Kind() EntryExprKind + + // AsMapEntry casts the EntryExpr to a MapEntry. + // + // The Kind() must be equal to MapEntryKind for the conversion to be well-defined. + AsMapEntry() MapEntry + + // AsStructField casts the EntryExpr to a StructField + // + // The Kind() must be equal to StructFieldKind for the conversion to be well-defined. + AsStructField() StructField + + // RenumberIDs performs an in-place update of the expression and all of its descendents numeric ids. + RenumberIDs(IDGenerator) + + isEntryExpr() } -// NavigableCallExpr defines an interface for inspecting a function call and its arugments. -type NavigableCallExpr interface { +// IDGenerator produces monotonically increasing ids suitable for tagging expression nodes. +type IDGenerator func() int64 + +// CallExpr defines an interface for inspecting a function call and its arugments. +type CallExpr interface { // FunctionName returns the name of the function. FunctionName() string @@ -221,22 +167,19 @@ type NavigableCallExpr interface { IsMemberFunction() bool // Target returns the target of the expression if one is present. - Target() NavigableExpr + Target() Expr // Args returns the list of call arguments, excluding the target. - Args() []NavigableExpr - - // ReturnType returns the result type of the call. 
- ReturnType() *types.Type + Args() []Expr // marker interface method - isNavigable() + isExpr() } -// NavigableListExpr defines an interface for inspecting a list literal expression. -type NavigableListExpr interface { +// ListExpr defines an interface for inspecting a list literal expression. +type ListExpr interface { // Elements returns the list elements as navigable expressions. - Elements() []NavigableExpr + Elements() []Expr // OptionalIndicies returns the list of optional indices in the list literal. OptionalIndices() []int32 @@ -245,13 +188,13 @@ type NavigableListExpr interface { Size() int // marker interface method - isNavigable() + isExpr() } -// NavigableSelectExpr defines an interface for inspecting a select expression. -type NavigableSelectExpr interface { +// SelectExpr defines an interface for inspecting a select expression. +type SelectExpr interface { // Operand returns the selection operand expression. - Operand() NavigableExpr + Operand() Expr // FieldName returns the field name being selected from the operand. FieldName() string @@ -260,67 +203,67 @@ type NavigableSelectExpr interface { IsTestOnly() bool // marker interface method - isNavigable() + isExpr() } -// NavigableMapExpr defines an interface for inspecting a map expression. -type NavigableMapExpr interface { - // Entries returns the map key value pairs as NavigableEntry values. - Entries() []NavigableEntry +// MapExpr defines an interface for inspecting a map expression. +type MapExpr interface { + // Entries returns the map key value pairs as EntryExpr values. + Entries() []EntryExpr // Size returns the number of entries in the map. Size() int // marker interface method - isNavigable() + isExpr() } -// NavigableEntry defines an interface for inspecting a map entry. -type NavigableEntry interface { +// MapEntry defines an interface for inspecting a map entry. +type MapEntry interface { // Key returns the map entry key expression. 
- Key() NavigableExpr + Key() Expr // Value returns the map entry value expression. - Value() NavigableExpr + Value() Expr // IsOptional returns whether the entry is optional. IsOptional() bool // marker interface method - isNavigable() + isEntryExpr() } -// NavigableStructExpr defines an interfaces for inspecting a struct and its field initializers. -type NavigableStructExpr interface { +// StructExpr defines an interfaces for inspecting a struct and its field initializers. +type StructExpr interface { // TypeName returns the struct type name. TypeName() string - // Fields returns the set of field initializers in the struct expression as NavigableField values. - Fields() []NavigableField + // Fields returns the set of field initializers in the struct expression as EntryExpr values. + Fields() []EntryExpr // marker interface method - isNavigable() + isExpr() } -// NavigableField defines an interface for inspecting a struct field initialization. -type NavigableField interface { - // FieldName returns the name of the field. - FieldName() string +// StructField defines an interface for inspecting a struct field initialization. +type StructField interface { + // Name returns the name of the field. + Name() string // Value returns the field initialization expression. - Value() NavigableExpr + Value() Expr // IsOptional returns whether the field is optional. IsOptional() bool // marker interface method - isNavigable() + isEntryExpr() } -// NavigableComprehensionExpr defines an interface for inspecting a comprehension expression. -type NavigableComprehensionExpr interface { +// ComprehensionExpr defines an interface for inspecting a comprehension expression. +type ComprehensionExpr interface { // IterRange returns the iteration range expression. - IterRange() NavigableExpr + IterRange() Expr // IterVar returns the iteration variable name. 
IterVar() string @@ -329,388 +272,572 @@ type NavigableComprehensionExpr interface { AccuVar() string // AccuInit returns the accumulation variable initialization expression. - AccuInit() NavigableExpr + AccuInit() Expr // LoopCondition returns the loop condition expression. - LoopCondition() NavigableExpr + LoopCondition() Expr // LoopStep returns the loop step expression. - LoopStep() NavigableExpr + LoopStep() Expr // Result returns the comprehension result expression. - Result() NavigableExpr + Result() Expr // marker interface method - isNavigable() + isExpr() } -func newNavigableExpr(parent NavigableExpr, expr *exprpb.Expr, typeMap map[int64]*types.Type) NavigableExpr { - kind, factory := kindOf(expr) - nav := &navigableExprImpl{ - parent: parent, - kind: kind, - expr: expr, - typeMap: typeMap, - createChildren: factory, +var _ Expr = &expr{} + +type expr struct { + id int64 + exprKindCase +} + +type exprKindCase interface { + Kind() ExprKind + + renumberIDs(IDGenerator) + + isExpr() +} + +func (e *expr) ID() int64 { + if e == nil { + return 0 } - return nav + return e.id } -type navigableExprImpl struct { - parent NavigableExpr - kind ExprKind - expr *exprpb.Expr - typeMap map[int64]*types.Type - createChildren childFactory +func (e *expr) Kind() ExprKind { + if e == nil || e.exprKindCase == nil { + return UnspecifiedExprKind + } + return e.exprKindCase.Kind() } -func (nav *navigableExprImpl) ID() int64 { - return nav.ToExpr().GetId() +func (e *expr) AsCall() CallExpr { + if e.Kind() != CallKind { + return nilCall + } + return e.exprKindCase.(CallExpr) } -func (nav *navigableExprImpl) Kind() ExprKind { - return nav.kind +func (e *expr) AsComprehension() ComprehensionExpr { + if e.Kind() != ComprehensionKind { + return nilCompre + } + return e.exprKindCase.(ComprehensionExpr) } -func (nav *navigableExprImpl) Type() *types.Type { - if t, found := nav.typeMap[nav.ID()]; found { - return t +func (e *expr) AsIdent() string { + if e.Kind() != IdentKind { + return 
"" } - return types.DynType + return string(e.exprKindCase.(baseIdentExpr)) } -func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) { - if nav.parent != nil { - return nav.parent, true +func (e *expr) AsLiteral() ref.Val { + if e.Kind() != LiteralKind { + return nil } - return nil, false + return e.exprKindCase.(*baseLiteral).Val } -func (nav *navigableExprImpl) Children() []NavigableExpr { - return nav.createChildren(nav) +func (e *expr) AsList() ListExpr { + if e.Kind() != ListKind { + return nilList + } + return e.exprKindCase.(ListExpr) } -func (nav *navigableExprImpl) ToExpr() *exprpb.Expr { - return nav.expr +func (e *expr) AsMap() MapExpr { + if e.Kind() != MapKind { + return nilMap + } + return e.exprKindCase.(MapExpr) } -func (nav *navigableExprImpl) AsCall() NavigableCallExpr { - return navigableCallImpl{navigableExprImpl: nav} +func (e *expr) AsSelect() SelectExpr { + if e.Kind() != SelectKind { + return nilSel + } + return e.exprKindCase.(SelectExpr) } -func (nav *navigableExprImpl) AsComprehension() NavigableComprehensionExpr { - return navigableComprehensionImpl{navigableExprImpl: nav} +func (e *expr) AsStruct() StructExpr { + if e.Kind() != StructKind { + return nilStruct + } + return e.exprKindCase.(StructExpr) } -func (nav *navigableExprImpl) AsIdent() string { - return nav.ToExpr().GetIdentExpr().GetName() +func (e *expr) SetKindCase(other Expr) { + if e == nil { + return + } + if other == nil { + e.exprKindCase = nil + return + } + switch other.Kind() { + case CallKind: + c := other.AsCall() + e.exprKindCase = &baseCallExpr{ + function: c.FunctionName(), + target: c.Target(), + args: c.Args(), + } + case ComprehensionKind: + c := other.AsComprehension() + e.exprKindCase = &baseComprehensionExpr{ + iterRange: c.IterRange(), + iterVar: c.IterVar(), + accuVar: c.AccuVar(), + accuInit: c.AccuInit(), + loopCond: c.LoopCondition(), + loopStep: c.LoopStep(), + result: c.Result(), + } + case IdentKind: + e.exprKindCase = 
baseIdentExpr(other.AsIdent()) + case ListKind: + l := other.AsList() + e.exprKindCase = &baseListExpr{ + elements: l.Elements(), + optIndices: l.OptionalIndices(), + } + case LiteralKind: + e.exprKindCase = &baseLiteral{Val: other.AsLiteral()} + case MapKind: + e.exprKindCase = &baseMapExpr{ + entries: other.AsMap().Entries(), + } + case SelectKind: + s := other.AsSelect() + e.exprKindCase = &baseSelectExpr{ + operand: s.Operand(), + field: s.FieldName(), + testOnly: s.IsTestOnly(), + } + case StructKind: + s := other.AsStruct() + e.exprKindCase = &baseStructExpr{ + typeName: s.TypeName(), + fields: s.Fields(), + } + case UnspecifiedExprKind: + e.exprKindCase = nil + } } -func (nav *navigableExprImpl) AsLiteral() ref.Val { - if nav.Kind() != LiteralKind { - return nil +func (e *expr) RenumberIDs(idGen IDGenerator) { + if e.Kind() == UnspecifiedExprKind { + return } - val, err := ConstantToVal(nav.ToExpr().GetConstExpr()) - if err != nil { - panic(err) + e.id = idGen() + e.exprKindCase.renumberIDs(idGen) +} + +type baseCallExpr struct { + function string + target Expr + args []Expr + isMember bool +} + +func (*baseCallExpr) Kind() ExprKind { + return CallKind +} + +func (e *baseCallExpr) FunctionName() string { + if e == nil { + return "" } - return val + return e.function } -func (nav *navigableExprImpl) AsList() NavigableListExpr { - return navigableListImpl{navigableExprImpl: nav} +func (e *baseCallExpr) IsMemberFunction() bool { + if e == nil { + return false + } + return e.isMember } -func (nav *navigableExprImpl) AsMap() NavigableMapExpr { - return navigableMapImpl{navigableExprImpl: nav} +func (e *baseCallExpr) Target() Expr { + if e == nil || !e.IsMemberFunction() { + return nilExpr + } + return e.target } -func (nav *navigableExprImpl) AsSelect() NavigableSelectExpr { - return navigableSelectImpl{navigableExprImpl: nav} +func (e *baseCallExpr) Args() []Expr { + if e == nil { + return []Expr{} + } + return e.args } -func (nav *navigableExprImpl) AsStruct() 
NavigableStructExpr { - return navigableStructImpl{navigableExprImpl: nav} +func (e *baseCallExpr) renumberIDs(idGen IDGenerator) { + if e.IsMemberFunction() { + e.Target().RenumberIDs(idGen) + } + for _, arg := range e.Args() { + arg.RenumberIDs(idGen) + } } -func (nav *navigableExprImpl) createChild(e *exprpb.Expr) NavigableExpr { - return newNavigableExpr(nav, e, nav.typeMap) +func (*baseCallExpr) isExpr() {} + +var _ ComprehensionExpr = &baseComprehensionExpr{} + +type baseComprehensionExpr struct { + iterRange Expr + iterVar string + accuVar string + accuInit Expr + loopCond Expr + loopStep Expr + result Expr } -func (nav *navigableExprImpl) isNavigable() {} +func (*baseComprehensionExpr) Kind() ExprKind { + return ComprehensionKind +} -type navigableCallImpl struct { - *navigableExprImpl +func (e *baseComprehensionExpr) IterRange() Expr { + if e == nil { + return nilExpr + } + return e.iterRange } -func (call navigableCallImpl) FunctionName() string { - return call.ToExpr().GetCallExpr().GetFunction() +func (e *baseComprehensionExpr) IterVar() string { + return e.iterVar } -func (call navigableCallImpl) IsMemberFunction() bool { - return call.ToExpr().GetCallExpr().GetTarget() != nil +func (e *baseComprehensionExpr) AccuVar() string { + return e.accuVar } -func (call navigableCallImpl) Target() NavigableExpr { - t := call.ToExpr().GetCallExpr().GetTarget() - if t != nil { - return call.createChild(t) +func (e *baseComprehensionExpr) AccuInit() Expr { + if e == nil { + return nilExpr } - return nil + return e.accuInit } -func (call navigableCallImpl) Args() []NavigableExpr { - args := call.ToExpr().GetCallExpr().GetArgs() - navArgs := make([]NavigableExpr, len(args)) - for i, a := range args { - navArgs[i] = call.createChild(a) +func (e *baseComprehensionExpr) LoopCondition() Expr { + if e == nil { + return nilExpr } - return navArgs + return e.loopCond } -func (call navigableCallImpl) ReturnType() *types.Type { - return call.Type() +func (e 
*baseComprehensionExpr) LoopStep() Expr { + if e == nil { + return nilExpr + } + return e.loopStep } -type navigableComprehensionImpl struct { - *navigableExprImpl +func (e *baseComprehensionExpr) Result() Expr { + if e == nil { + return nilExpr + } + return e.result } -func (comp navigableComprehensionImpl) IterRange() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetIterRange()) +func (e *baseComprehensionExpr) renumberIDs(idGen IDGenerator) { + e.IterRange().RenumberIDs(idGen) + e.AccuInit().RenumberIDs(idGen) + e.LoopCondition().RenumberIDs(idGen) + e.LoopStep().RenumberIDs(idGen) + e.Result().RenumberIDs(idGen) } -func (comp navigableComprehensionImpl) IterVar() string { - return comp.ToExpr().GetComprehensionExpr().GetIterVar() +func (*baseComprehensionExpr) isExpr() {} + +var _ exprKindCase = baseIdentExpr("") + +type baseIdentExpr string + +func (baseIdentExpr) Kind() ExprKind { + return IdentKind } -func (comp navigableComprehensionImpl) AccuVar() string { - return comp.ToExpr().GetComprehensionExpr().GetAccuVar() +func (e baseIdentExpr) renumberIDs(IDGenerator) {} + +func (baseIdentExpr) isExpr() {} + +var _ exprKindCase = &baseLiteral{} +var _ ref.Val = &baseLiteral{} + +type baseLiteral struct { + ref.Val } -func (comp navigableComprehensionImpl) AccuInit() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetAccuInit()) +func (*baseLiteral) Kind() ExprKind { + return LiteralKind } -func (comp navigableComprehensionImpl) LoopCondition() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopCondition()) +func (l *baseLiteral) renumberIDs(IDGenerator) {} + +func (*baseLiteral) isExpr() {} + +var _ ListExpr = &baseListExpr{} + +type baseListExpr struct { + elements []Expr + optIndices []int32 } -func (comp navigableComprehensionImpl) LoopStep() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetLoopStep()) +func (*baseListExpr) 
Kind() ExprKind { + return ListKind } -func (comp navigableComprehensionImpl) Result() NavigableExpr { - return comp.createChild(comp.ToExpr().GetComprehensionExpr().GetResult()) +func (e *baseListExpr) Elements() []Expr { + if e == nil { + return []Expr{} + } + return e.elements } -type navigableListImpl struct { - *navigableExprImpl +func (e *baseListExpr) OptionalIndices() []int32 { + if e == nil { + return []int32{} + } + return e.optIndices } -func (l navigableListImpl) Elements() []NavigableExpr { - return l.Children() +func (e *baseListExpr) Size() int { + return len(e.Elements()) } -func (l navigableListImpl) OptionalIndices() []int32 { - return l.ToExpr().GetListExpr().GetOptionalIndices() +func (e *baseListExpr) renumberIDs(idGen IDGenerator) { + for _, elem := range e.Elements() { + elem.RenumberIDs(idGen) + } } -func (l navigableListImpl) Size() int { - return len(l.ToExpr().GetListExpr().GetElements()) +func (*baseListExpr) isExpr() {} + +type baseMapExpr struct { + entries []EntryExpr } -type navigableMapImpl struct { - *navigableExprImpl +func (*baseMapExpr) Kind() ExprKind { + return MapKind } -func (m navigableMapImpl) Entries() []NavigableEntry { - mapExpr := m.ToExpr().GetStructExpr() - entries := make([]NavigableEntry, len(mapExpr.GetEntries())) - for i, e := range mapExpr.GetEntries() { - entries[i] = navigableEntryImpl{ - key: m.createChild(e.GetMapKey()), - val: m.createChild(e.GetValue()), - isOpt: e.GetOptionalEntry(), - } +func (e *baseMapExpr) Entries() []EntryExpr { + if e == nil { + return []EntryExpr{} } - return entries + return e.entries } -func (m navigableMapImpl) Size() int { - return len(m.ToExpr().GetStructExpr().GetEntries()) +func (e *baseMapExpr) Size() int { + return len(e.Entries()) } -type navigableEntryImpl struct { - key NavigableExpr - val NavigableExpr - isOpt bool +func (e *baseMapExpr) renumberIDs(idGen IDGenerator) { + for _, entry := range e.Entries() { + entry.RenumberIDs(idGen) + } } -func (e navigableEntryImpl) 
Key() NavigableExpr { - return e.key -} +func (*baseMapExpr) isExpr() {} -func (e navigableEntryImpl) Value() NavigableExpr { - return e.val +type baseSelectExpr struct { + operand Expr + field string + testOnly bool } -func (e navigableEntryImpl) IsOptional() bool { - return e.isOpt +func (*baseSelectExpr) Kind() ExprKind { + return SelectKind } -func (e navigableEntryImpl) isNavigable() {} +func (e *baseSelectExpr) Operand() Expr { + if e == nil || e.operand == nil { + return nilExpr + } + return e.operand +} -type navigableSelectImpl struct { - *navigableExprImpl +func (e *baseSelectExpr) FieldName() string { + if e == nil { + return "" + } + return e.field } -func (sel navigableSelectImpl) FieldName() string { - return sel.ToExpr().GetSelectExpr().GetField() +func (e *baseSelectExpr) IsTestOnly() bool { + if e == nil { + return false + } + return e.testOnly } -func (sel navigableSelectImpl) IsTestOnly() bool { - return sel.ToExpr().GetSelectExpr().GetTestOnly() +func (e *baseSelectExpr) renumberIDs(idGen IDGenerator) { + e.Operand().RenumberIDs(idGen) } -func (sel navigableSelectImpl) Operand() NavigableExpr { - return sel.createChild(sel.ToExpr().GetSelectExpr().GetOperand()) +func (*baseSelectExpr) isExpr() {} + +type baseStructExpr struct { + typeName string + fields []EntryExpr } -type navigableStructImpl struct { - *navigableExprImpl +func (*baseStructExpr) Kind() ExprKind { + return StructKind } -func (s navigableStructImpl) TypeName() string { - return s.ToExpr().GetStructExpr().GetMessageName() +func (e *baseStructExpr) TypeName() string { + if e == nil { + return "" + } + return e.typeName } -func (s navigableStructImpl) Fields() []NavigableField { - fieldInits := s.ToExpr().GetStructExpr().GetEntries() - fields := make([]NavigableField, len(fieldInits)) - for i, f := range fieldInits { - fields[i] = navigableFieldImpl{ - name: f.GetFieldKey(), - val: s.createChild(f.GetValue()), - isOpt: f.GetOptionalEntry(), - } +func (e *baseStructExpr) Fields() 
[]EntryExpr { + if e == nil { + return []EntryExpr{} } - return fields + return e.fields } -type navigableFieldImpl struct { - name string - val NavigableExpr - isOpt bool +func (e *baseStructExpr) renumberIDs(idGen IDGenerator) { + for _, f := range e.Fields() { + f.RenumberIDs(idGen) + } } -func (f navigableFieldImpl) FieldName() string { - return f.name +func (*baseStructExpr) isExpr() {} + +type entryExprKindCase interface { + Kind() EntryExprKind + + renumberIDs(IDGenerator) + + isEntryExpr() } -func (f navigableFieldImpl) Value() NavigableExpr { - return f.val +var _ EntryExpr = &entryExpr{} + +type entryExpr struct { + id int64 + entryExprKindCase } -func (f navigableFieldImpl) IsOptional() bool { - return f.isOpt +func (e *entryExpr) ID() int64 { + return e.id } -func (f navigableFieldImpl) isNavigable() {} +func (e *entryExpr) AsMapEntry() MapEntry { + if e.Kind() != MapEntryKind { + return nilMapEntry + } + return e.entryExprKindCase.(MapEntry) +} -func kindOf(expr *exprpb.Expr) (ExprKind, childFactory) { - switch expr.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - return LiteralKind, noopFactory - case *exprpb.Expr_IdentExpr: - return IdentKind, noopFactory - case *exprpb.Expr_SelectExpr: - return SelectKind, selectFactory - case *exprpb.Expr_CallExpr: - return CallKind, callArgFactory - case *exprpb.Expr_ListExpr: - return ListKind, listElemFactory - case *exprpb.Expr_StructExpr: - if expr.GetStructExpr().GetMessageName() != "" { - return StructKind, structEntryFactory - } - return MapKind, mapEntryFactory - case *exprpb.Expr_ComprehensionExpr: - return ComprehensionKind, comprehensionFactory - default: - return UnspecifiedKind, noopFactory +func (e *entryExpr) AsStructField() StructField { + if e.Kind() != StructFieldKind { + return nilStructField } + return e.entryExprKindCase.(StructField) } -type childFactory func(*navigableExprImpl) []NavigableExpr +func (e *entryExpr) RenumberIDs(idGen IDGenerator) { + e.id = idGen() + 
e.entryExprKindCase.renumberIDs(idGen) +} -func noopFactory(*navigableExprImpl) []NavigableExpr { - return nil +type baseMapEntry struct { + key Expr + value Expr + isOptional bool } -func selectFactory(nav *navigableExprImpl) []NavigableExpr { - return []NavigableExpr{ - nav.createChild(nav.ToExpr().GetSelectExpr().GetOperand()), - } +func (e *baseMapEntry) Kind() EntryExprKind { + return MapEntryKind } -func callArgFactory(nav *navigableExprImpl) []NavigableExpr { - call := nav.ToExpr().GetCallExpr() - argCount := len(call.GetArgs()) - if call.GetTarget() != nil { - argCount++ +func (e *baseMapEntry) Key() Expr { + if e == nil { + return nilExpr } - navExprs := make([]NavigableExpr, argCount) - i := 0 - if call.GetTarget() != nil { - navExprs[i] = nav.createChild(call.GetTarget()) - i++ - } - for _, arg := range call.GetArgs() { - navExprs[i] = nav.createChild(arg) - i++ + return e.key +} + +func (e *baseMapEntry) Value() Expr { + if e == nil { + return nilExpr } - return navExprs + return e.value } -func listElemFactory(nav *navigableExprImpl) []NavigableExpr { - l := nav.ToExpr().GetListExpr() - navExprs := make([]NavigableExpr, len(l.GetElements())) - for i, e := range l.GetElements() { - navExprs[i] = nav.createChild(e) +func (e *baseMapEntry) IsOptional() bool { + if e == nil { + return false } - return navExprs + return e.isOptional } -func structEntryFactory(nav *navigableExprImpl) []NavigableExpr { - s := nav.ToExpr().GetStructExpr() - entries := make([]NavigableExpr, len(s.GetEntries())) - for i, e := range s.GetEntries() { +func (e *baseMapEntry) renumberIDs(idGen IDGenerator) { + e.Key().RenumberIDs(idGen) + e.Value().RenumberIDs(idGen) +} - entries[i] = nav.createChild(e.GetValue()) +func (*baseMapEntry) isEntryExpr() {} + +type baseStructField struct { + field string + value Expr + isOptional bool +} + +func (f *baseStructField) Kind() EntryExprKind { + return StructFieldKind +} + +func (f *baseStructField) Name() string { + if f == nil { + return "" 
} - return entries + return f.field } -func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr { - s := nav.ToExpr().GetStructExpr() - entries := make([]NavigableExpr, len(s.GetEntries())*2) - j := 0 - for _, e := range s.GetEntries() { - entries[j] = nav.createChild(e.GetMapKey()) - entries[j+1] = nav.createChild(e.GetValue()) - j += 2 +func (f *baseStructField) Value() Expr { + if f == nil { + return nilExpr } - return entries + return f.value } -func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr { - compre := nav.ToExpr().GetComprehensionExpr() - return []NavigableExpr{ - nav.createChild(compre.GetIterRange()), - nav.createChild(compre.GetAccuInit()), - nav.createChild(compre.GetLoopCondition()), - nav.createChild(compre.GetLoopStep()), - nav.createChild(compre.GetResult()), +func (f *baseStructField) IsOptional() bool { + if f == nil { + return false } + return f.isOptional } + +func (f *baseStructField) renumberIDs(idGen IDGenerator) { + f.Value().RenumberIDs(idGen) +} + +func (*baseStructField) isEntryExpr() {} + +var ( + nilExpr *expr = nil + nilCall *baseCallExpr = nil + nilCompre *baseComprehensionExpr = nil + nilList *baseListExpr = nil + nilMap *baseMapExpr = nil + nilMapEntry *baseMapEntry = nil + nilSel *baseSelectExpr = nil + nilStruct *baseStructExpr = nil + nilStructField *baseStructField = nil +) diff --git a/common/ast/expr_test.go b/common/ast/expr_test.go index 05fc73010..5b38464db 100644 --- a/common/ast/expr_test.go +++ b/common/ast/expr_test.go @@ -15,442 +15,467 @@ package ast_test import ( + "fmt" + "reflect" "testing" - "google.golang.org/protobuf/proto" - - "github.com/google/cel-go/checker" - "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" - "github.com/google/cel-go/common/containers" - "github.com/google/cel-go/common/decls" - "github.com/google/cel-go/common/stdlib" "github.com/google/cel-go/common/types" - "github.com/google/cel-go/parser" - - proto3pb 
"github.com/google/cel-go/test/proto3pb" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) -func TestNavigateExpr(t *testing.T) { - tests := []struct { - expr string - descedantCount int - callCount int - }{ - { - expr: `'a' == 'b'`, - descedantCount: 3, - callCount: 1, - }, - { - expr: `'a'.size()`, - descedantCount: 2, - callCount: 1, - }, - { - expr: `[1, 2, 3]`, - descedantCount: 4, - callCount: 0, - }, - { - expr: `[1, 2, 3][0]`, - descedantCount: 6, - callCount: 1, - }, - { - expr: `{1u: 'hello'}`, - descedantCount: 3, - callCount: 0, - }, - { - expr: `{'hello': 'world'}.hello`, - descedantCount: 4, - callCount: 0, - }, - { - expr: `type(1) == int`, - descedantCount: 4, - callCount: 2, - }, - { - expr: `google.expr.proto3.test.TestAllTypes{single_int32: 1}`, - descedantCount: 2, - callCount: 0, - }, - { - expr: `[true].exists(i, i)`, - descedantCount: 11, // 2 for iter range, 1 for accu init, 4 for loop condition, 3 for loop step, 1 for result - callCount: 3, // @not_strictly_false(!result), accu_init || i - }, - } - - for _, tst := range tests { - tc := tst - t.Run(tc.expr, func(t *testing.T) { - checked := mustTypeCheck(t, tc.expr) - nav := ast.NavigateCheckedAST(checked) - descendants := ast.MatchDescendants(nav, ast.AllMatcher()) - if len(descendants) != tc.descedantCount { - t.Errorf("ast.MatchDescendants(%v) got %d descendants, wanted %d", checked.Expr, len(descendants), tc.descedantCount) - } - calls := ast.MatchSubset(descendants, ast.KindMatcher(ast.CallKind)) - if len(calls) != tc.callCount { - t.Errorf("ast.MatchSubset(%v) got %d calls, wanted %d", checked.Expr, len(calls), tc.callCount) +func TestSetKindCase(t *testing.T) { + fac := ast.NewExprFactory() + tests := []ast.Expr{ + fac.NewCall(1, "_==_", fac.NewLiteral(2, types.True), fac.NewLiteral(3, types.False)), + fac.NewComprehension(12, + fac.NewList(1, []ast.Expr{}, []int32{}), + "i", + "__result__", + fac.NewLiteral(5, types.False), + fac.NewCall(8, "@not_strictly_false", 
fac.NewCall(7, "!_", fac.NewAccuIdent(6))), + fac.NewCall(10, "_||_", fac.NewAccuIdent(9), fac.NewIdent(4, "i")), + fac.NewAccuIdent(11), + ), + fac.NewIdent(1, "a"), + fac.NewLiteral(1, types.Bytes("hello")), + fac.NewList(1, []ast.Expr{fac.NewIdent(2, "a"), fac.NewIdent(3, "b")}, []int32{}), + fac.NewMap(1, []ast.EntryExpr{ + fac.NewMapEntry(2, + fac.NewLiteral(3, types.String("string")), + fac.NewCall(6, "_?._", fac.NewIdent(4, "a"), fac.NewLiteral(5, types.String("b"))), + true), + }), + fac.NewSelect(2, fac.NewIdent(1, "a"), "b"), + fac.NewStruct(1, + "custom.StructType", + []ast.EntryExpr{ + fac.NewStructField(2, + "uint_field", + fac.NewLiteral(3, types.Uint(42)), + false), + }), + } + expr := nilTestExpr(t) + for i, tst := range tests { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + expr.SetKindCase(tst) + switch expr.Kind() { + case ast.CallKind: + if !reflect.DeepEqual(expr.AsCall(), tst.AsCall()) { + t.Errorf("got %v, wanted %v", expr.AsCall(), tst.AsCall()) + } + case ast.ComprehensionKind: + if !reflect.DeepEqual(expr.AsComprehension(), tst.AsComprehension()) { + t.Errorf("got %v, wanted %v", expr.AsComprehension(), tst.AsComprehension()) + } + case ast.IdentKind: + if !reflect.DeepEqual(expr.AsIdent(), tst.AsIdent()) { + t.Errorf("got %v, wanted %v", expr.AsIdent(), tst.AsIdent()) + } + case ast.LiteralKind: + if !reflect.DeepEqual(expr.AsLiteral(), tst.AsLiteral()) { + t.Errorf("got %v, wanted %v", expr.AsLiteral(), tst.AsLiteral()) + } + case ast.ListKind: + if !reflect.DeepEqual(expr.AsList(), tst.AsList()) { + t.Errorf("got %v, wanted %v", expr.AsList(), tst.AsList()) + } + case ast.MapKind: + case ast.SelectKind: + case ast.StructKind: + default: + t.Errorf("unable to determine kind case: %v", tst) } }) } } -func TestNavigateExprNilSafety(t *testing.T) { - tests := []struct { - name string - e ast.NavigableExpr - }{ - { - name: "nil expr", - e: ast.NavigateCheckedAST(&ast.CheckedAST{TypeMap: map[int64]*types.Type{}}), - }, - { - name: 
"empty expr", - e: ast.NavigateCheckedAST(&ast.CheckedAST{Expr: &exprpb.Expr{}, TypeMap: map[int64]*types.Type{}}), - }, - } - for _, tst := range tests { - tc := tst - t.Run(tc.name, func(t *testing.T) { - e := tc.e - if e.Kind() != ast.UnspecifiedKind { - t.Errorf("Kind() got %v, wanted unspecified kind", e.Kind()) - } - if e.ID() != 0 { - t.Errorf("ID() got %d, wanted 0", e.ID()) - } - if e.Type() != types.DynType { - t.Errorf("Type() got %v, wanted types.DynType", e.Type()) - } - if p, found := e.Parent(); found { - t.Errorf("Parent() got %v, waned not found", p) - } - if len(e.Children()) != 0 { - t.Errorf("Children() got %v, wanted none", e.Children()) - } - if e.AsLiteral() != nil { - t.Errorf("AsLiteral() got %v, wanted nil", e.AsLiteral()) - } - }) - } -} - -func TestNavigableCallExpr_Member(t *testing.T) { - checked := mustTypeCheck(t, `'hello'.size()`) - expr := ast.NavigateCheckedAST(checked) +func TestCall(t *testing.T) { + fac := ast.NewExprFactory() + expr := fac.NewCall(1, "size", fac.NewLiteral(2, types.String("hello"))) if expr.Kind() != ast.CallKind { - t.Errorf("Kind() got %v, wanted CallKind", expr.Kind()) + t.Fatalf("NewCall() produced non-call expression: %v", expr.Kind()) } call := expr.AsCall() if call.FunctionName() != "size" { - t.Errorf("FunctionName() got %s, wanted size", call.FunctionName()) - } - if call.Target() == nil { - t.Fatalf("Target() got nil, wanted non-nil") + t.Errorf("Function() got %s, wanted 'size''", call.FunctionName()) } - if len(call.Args()) != 0 { - t.Errorf("Args() got %v, wanted 0", call.Args()) - } - target := call.Target() - if target.Kind() != ast.LiteralKind { - t.Errorf("Kind() got %v, wanted literal", target.Kind()) - } - if target.AsLiteral().Equal(types.String("hello")) != types.True { - t.Errorf("AsLiteral() got %v, wanted 'hello'", target.AsLiteral()) - } - if call.ReturnType() != expr.Type() { - t.Errorf("ReturnType() got %v, wanted %v", call.ReturnType(), expr.Type()) - } - if call.ReturnType() != 
types.IntType { - t.Errorf("ReturnType() got %v, wanted int", call.ReturnType()) + if len(call.Args()) != 1 { + t.Errorf("Args() got %v, wanted one", call.Args()) } - if p, found := target.Parent(); !found || p != expr { - t.Errorf("Parent() got %v, wanted %v", p, expr) + if call.IsMemberFunction() { + t.Error("IsMemberFunction() got true, wanted false") } - sizeFn := ast.MatchDescendants(expr, ast.FunctionMatcher("size")) - if len(sizeFn) != 1 { - t.Errorf("ast.MatchDescendants() size function returned %v, wanted 1", sizeFn) + if call.Target().ID() != 0 { + t.Errorf("empty Target() got %d, wanted 0", call.Target().ID()) } - constantValues := ast.MatchDescendants(expr, ast.ConstantValueMatcher()) - if len(constantValues) != 1 { - t.Fatalf("ast.MatchDescendants() constant values returned %v, wanted 1 value", constantValues) + expr.RenumberIDs(testIDGen(100)) + if expr.ID() != 101 { + t.Errorf("expr.ID() got %d, wanted 101 after RenumberIDs", expr.ID()) } - if constantValues[0].AsLiteral().Equal(types.String("hello")) != types.True { - t.Errorf("constantValues[0] got %v, wanted 'hello'", constantValues[0]) + if expr.AsCall().Args()[0].ID() != 102 { + t.Errorf("Args()[0].ID() got %d, wanted 102 after RenumberIDs", expr.AsCall().Args()[0].ID()) } } -func TestNavigableCallExpr_Global(t *testing.T) { - checked := mustTypeCheck(t, `size('hello')`) - expr := ast.NavigateCheckedAST(checked) +func TestMemberCall(t *testing.T) { + fac := ast.NewExprFactory() + expr := fac.NewMemberCall(1, "size", fac.NewLiteral(2, types.String("hello"))) if expr.Kind() != ast.CallKind { - t.Errorf("Kind() got %v, wanted CallKind", expr.Kind()) + t.Fatalf("NewCall() produced non-call expression: %v", expr.Kind()) } call := expr.AsCall() if call.FunctionName() != "size" { - t.Errorf("FunctionName() got %s, wanted size", call.FunctionName()) + t.Errorf("Function() got %s, wanted 'size''", call.FunctionName()) } - if call.Target() != nil { - t.Fatalf("Target() got non-nil, wanted nil") + if 
len(call.Args()) != 0 { + t.Errorf("Args() got %v, wanted zero", call.Args()) } - if len(call.Args()) != 1 { - t.Errorf("Args() got %v, wanted 1", call.Args()) + if !call.IsMemberFunction() { + t.Error("IsMemberFunction() got false, wanted true") } - arg := call.Args()[0] - if arg.Kind() != ast.LiteralKind { - t.Errorf("Kind() got %v, wanted literal", arg.Kind()) + if call.Target().ID() != 2 { + t.Errorf("Target() got %d, wanted 2", call.Target().ID()) } - if arg.AsLiteral().Equal(types.String("hello")) != types.True { - t.Errorf("AsLiteral() got %v, wanted 'hello'", arg.AsLiteral()) + expr.RenumberIDs(testIDGen(100)) + if expr.ID() != 101 { + t.Errorf("expr.ID() got %d, wanted 101 after RenumberIDs", expr.ID()) } - if call.ReturnType() != expr.Type() { - t.Errorf("ReturnType() got %v, wanted %v", call.ReturnType(), expr.Type()) + if expr.AsCall().Target().ID() != 102 { + t.Errorf("Target().ID() got %d, wanted 102 after RenumberIDs", expr.AsCall().Target().ID()) } - if call.ReturnType() != types.IntType { - t.Errorf("ReturnType() got %v, wanted int", call.ReturnType()) +} + +func TestCallNil(t *testing.T) { + expr := nilTestExpr(t) + call := expr.AsCall() + if call.FunctionName() != "" { + t.Errorf("FunctionName() got %s, wanted empty value", call.FunctionName()) } - if p, found := arg.Parent(); !found || p != expr { - t.Errorf("Parent() got %v, wanted %v", p, expr) + if len(call.Args()) != 0 { + t.Errorf("Args() got %v, wanted zero", call.Args()) } - sizeFn := ast.MatchDescendants(expr, ast.FunctionMatcher("size")) - if len(sizeFn) != 1 { - t.Errorf("ast.MatchDescendants() size function returned %v, wanted 1", sizeFn) + if call.IsMemberFunction() { + t.Error("IsMemberFunction() got true, wanted false") } - constantValues := ast.MatchDescendants(expr, ast.ConstantValueMatcher()) - if len(constantValues) != 1 { - t.Fatalf("ast.MatchDescendants() constant values returned %v, wanted 1 value", constantValues) + if call.Target().ID() != 0 { + t.Errorf("empty Target() 
got %d, wanted 0", call.Target().ID()) } - if constantValues[0].AsLiteral().Equal(types.String("hello")) != types.True { - t.Errorf("constantValues[0] got %v, wanted 'hello'", constantValues[0]) + expr.RenumberIDs(testIDGen(100)) + if expr.ID() != 1 { + t.Errorf("Renumbering an unspecified expression mutated the value: %v", expr) } } -func TestNavigableListExpr(t *testing.T) { - checked := mustTypeCheck(t, `[[1], [2]]`) - expr := ast.NavigateCheckedAST(checked) - if expr.Kind() != ast.ListKind { - t.Errorf("Kind() got %v, wanted ListKind", expr.Kind()) +func TestComprehension(t *testing.T) { + fac := ast.NewExprFactory() + expr := fac.NewComprehension(1, + fac.NewList(2, []ast.Expr{ + fac.NewLiteral(3, types.Int(1)), + fac.NewLiteral(4, types.Int(2)), + }, []int32{}), + "i", + "__result__", + fac.NewLiteral(5, types.False), + fac.NewLiteral(6, types.True), + fac.NewCall(7, "_||_", + fac.NewAccuIdent(8), + fac.NewCall(9, "<", fac.NewIdent(10, "i"), fac.NewLiteral(11, types.Int(3))), + ), + fac.NewAccuIdent(12), + ) + comp := expr.AsComprehension() + if comp.IterRange().Kind() != ast.ListKind { + t.Errorf("IterRange() got %v, wanted list", comp.IterRange().Kind()) } - list := expr.AsList() - if list.Size() != 2 { - t.Errorf("Size() got %d, wanted 2", list.Size()) + if comp.AccuInit().Kind() != ast.LiteralKind { + t.Errorf("AccuInit() got %v, wanted literal", comp.AccuInit().Kind()) } - if len(list.OptionalIndices()) != 0 { - t.Errorf("OptionalIndicies() returned %v, wanted none", list.OptionalIndices()) + if comp.LoopCondition().Kind() != ast.LiteralKind { + t.Errorf("LoopCondition() got %v, wanted literal", comp.LoopCondition().Kind()) + } + if comp.LoopStep().Kind() != ast.CallKind { + t.Errorf("LoopStep() got %v, wanted call", comp.LoopStep().Kind()) } - if len(list.Elements()) != 2 { - t.Errorf("Elements() returned %v, wanted 2", list.Elements()) + if comp.Result().Kind() != ast.IdentKind { + t.Errorf("Result() got %v, wanted identifier", comp.Result().Kind()) + 
} + expr.RenumberIDs(testIDGen(100)) + if expr.ID() != 101 { + t.Errorf("Renumbering an unspecified expression mutated the value: %v", expr) + } + comp = expr.AsComprehension() + if comp.IterRange().ID() != 102 { + t.Errorf("IterRange() got %v, wanted list", comp.IterRange().ID()) } - constantValues := ast.MatchDescendants(expr, ast.ConstantValueMatcher()) - if len(constantValues) != 5 { - t.Errorf("ast.MatchDescendants() constant values returned %v, wanted 5", constantValues) + if comp.AccuInit().ID() != 105 { + t.Errorf("AccuInit() got %v, wanted literal", comp.AccuInit().ID()) } - constantLists := ast.MatchSubset(constantValues, ast.KindMatcher(ast.ListKind)) - if len(constantLists) != 3 { - t.Errorf("ast.MatchSubset() constant lists returned %v, wanted 3", constantLists) + if comp.LoopCondition().ID() != 106 { + t.Errorf("LoopCondition() got %v, wanted literal", comp.LoopCondition().ID()) } - literals := ast.MatchSubset(constantValues, ast.KindMatcher(ast.LiteralKind)) - if len(literals) != 2 { - t.Errorf("ast.MatchSubset() literals returned %v, wanted 2", literals) + if comp.LoopStep().ID() != 107 { + t.Errorf("LoopStep() got %v, wanted call", comp.LoopStep().ID()) + } + if comp.Result().ID() != 112 { + t.Errorf("Result() got %v, wanted identifier", comp.Result().ID()) } } -func TestNavigableMapExpr(t *testing.T) { - checked := mustTypeCheck(t, `{'hello': 1}`) - expr := ast.NavigateCheckedAST(checked) - if expr.Kind() != ast.MapKind { - t.Errorf("Kind() got %v, wanted MapKind", expr.Kind()) +func TestComprehensionNil(t *testing.T) { + expr := nilTestExpr(t) + comp := expr.AsComprehension() + if comp.IterRange().Kind() != ast.UnspecifiedExprKind { + t.Errorf("IterRange() got %v, wanted unspecified", comp.IterRange().Kind()) } - m := expr.AsMap() - if m.Size() != 1 { - t.Errorf("Size() got %d, wanted 1", m.Size()) + if comp.AccuInit().Kind() != ast.UnspecifiedExprKind { + t.Errorf("AccuInit() got %v, wanted unspecified", comp.AccuInit().Kind()) } - if 
len(m.Entries()) != 1 { - t.Errorf("Entries() returned %v, wanted 1", m.Entries()) + if comp.LoopCondition().Kind() != ast.UnspecifiedExprKind { + t.Errorf("LoopCondition() got %v, wanted unspecified", comp.LoopCondition().Kind()) } - entry := m.Entries()[0] - if entry.IsOptional() { - t.Error("IsOptional() returned true, wanted false") + if comp.LoopStep().Kind() != ast.UnspecifiedExprKind { + t.Errorf("LoopStep() got %v, wanted unspecified", comp.LoopStep().Kind()) } - if entry.Key().AsLiteral().Equal(types.String("hello")) != types.True { - t.Errorf("Key() returned %v, wanted 'hello'", entry.Key().AsLiteral()) + if comp.Result().Kind() != ast.UnspecifiedExprKind { + t.Errorf("Result() got %v, wanted unspecified", comp.Result().Kind()) } - if entry.Value().AsLiteral().Equal(types.Int(1)) != types.True { - t.Errorf("Value() returned %v, wanted 1", entry.Value().AsLiteral()) + expr.RenumberIDs(testIDGen(100)) + if expr.ID() != 1 { + t.Errorf("Renumbering an unspecified expression mutated the value: %v", expr) } - descendants := ast.MatchDescendants(expr, ast.AllMatcher()) - if len(descendants) != 3 { - t.Errorf("ast.MatchDescendants() returned %v, wanted 3", descendants) +} + +func TestIdent(t *testing.T) { + fac := ast.NewExprFactory() + id := fac.NewIdent(1, "a") + if id.AsIdent() != "a" { + t.Errorf("id.AsIdent() got %s, wanted 'a'", id.AsIdent()) } - literals := ast.MatchSubset(descendants, ast.KindMatcher(ast.LiteralKind)) - if len(literals) != 2 { - t.Errorf("ast.MatchSubset() literals returned %v, wanted 2", literals) + id.RenumberIDs(testIDGen(50)) + if id.ID() != 51 { + t.Errorf("id.ID() got %d, wanted 51", id.ID()) } } -func TestNavigableStructExpr(t *testing.T) { - checked := mustTypeCheck(t, `google.expr.proto3.test.TestAllTypes{single_int32: 1}`) - expr := ast.NavigateCheckedAST(checked) - if expr.Kind() != ast.StructKind { - t.Errorf("Kind() got %v, wanted StructKind", expr.Kind()) +func TestIdentNil(t *testing.T) { + expr := nilTestExpr(t) + if 
expr.AsIdent() != "" { + t.Errorf("AsIdent() got %s, wanted ''", expr.AsIdent()) } - s := expr.AsStruct() - if s.TypeName() != "google.expr.proto3.test.TestAllTypes" { - t.Errorf("TypeName() got %s, wanted TestAllTypes", s.TypeName()) +} + +func TestList(t *testing.T) { + fac := ast.NewExprFactory() + expr := fac.NewList(20, []ast.Expr{fac.NewLiteral(21, types.OptionalOf(types.True))}, []int32{0}) + list := expr.AsList() + if list.Size() != 1 { + t.Errorf("list.Size() got %s, wanted 'a'", expr.AsIdent()) } - if len(s.Fields()) != 1 { - t.Errorf("Fields() returned %v, wanted 1", s.Fields()) + if list.Elements()[0].Kind() != ast.LiteralKind { + t.Errorf("list.Elements()[0] got %v, wanted true", list.Elements()[0]) } - field := s.Fields()[0] - if field.IsOptional() { - t.Error("IsOptional() returned true, wanted false") + if list.OptionalIndices()[0] != 0 { + t.Errorf("list.OptionalIndices()[0] got %d, wanted 0", list.OptionalIndices()[0]) } - if field.FieldName() != "single_int32" { - t.Errorf("FieldName() returned %s, wanted 'single_int32'", field.FieldName()) + expr.RenumberIDs(testIDGen(50)) + if expr.ID() != 51 { + t.Errorf("expr.ID() got %d, wanted 51", expr.ID()) } - if field.Value().AsLiteral().Equal(types.Int(1)) != types.True { - t.Errorf("Value() returned %v, wanted 1", field.Value().AsLiteral()) + if list.Elements()[0].ID() != 52 { + t.Errorf("list.Elements()[0].ID() got %d, wanted 52", list.Elements()[0].ID()) } - descendants := ast.MatchDescendants(expr, ast.AllMatcher()) - if len(descendants) != 2 { - t.Errorf("ast.MatchDescendants() returned %v, wanted 2", descendants) +} + +func TestListNil(t *testing.T) { + expr := nilTestExpr(t) + list := expr.AsList() + if list.Size() != 0 { + t.Errorf("nil list.Size() got %d, wanted 0", list.Size()) } - literals := ast.MatchSubset(descendants, ast.KindMatcher(ast.LiteralKind)) - if len(literals) != 1 { - t.Errorf("ast.MatchSubset() literals returned %v, wanted 1", literals) + if len(list.Elements()) != 0 { + 
t.Errorf("nil list.Elements() got %d, wanted 0", list.Elements()) } - if literals[0].AsLiteral().Equal(types.Int(1)) != types.True { - t.Errorf("Value().AsLiteral() got %v, wanted 1", literals[0].AsLiteral()) + if len(list.OptionalIndices()) != 0 { + t.Errorf("nil list.OptionalIndices() got %d, wanted 0", list.OptionalIndices()) } } -func TestNavigableComprehensionExpr(t *testing.T) { - checked := mustTypeCheck(t, `[true].exists(i, i)`) - expr := ast.NavigateCheckedAST(checked) - if expr.Kind() != ast.ComprehensionKind { - t.Errorf("Kind() got %v, wanted ComprehensionKind", expr.Kind()) +func TestLiteralNil(t *testing.T) { + expr := nilTestExpr(t) + if expr.AsLiteral() != nil { + t.Errorf("AsLiteral() got %v, wanted nil", expr.AsLiteral()) } - comp := expr.AsComprehension() - iterRange := comp.IterRange() - if len(ast.MatchSubset([]ast.NavigableExpr{iterRange}, ast.ConstantValueMatcher())) != 1 { - t.Errorf("IterRange() returned a non-constant list") +} + +func TestMap(t *testing.T) { + fac := ast.NewExprFactory() + expr := fac.NewMap(1, + []ast.EntryExpr{ + fac.NewMapEntry(2, fac.NewIdent(3, "a"), fac.NewLiteral(4, types.True), true), + fac.NewMapEntry(5, fac.NewIdent(6, "b"), fac.NewLiteral(7, types.False), false), + }) + m := expr.AsMap() + if m.Size() != 2 { + t.Errorf("map.Size() got %d, wanted 2", m.Size()) } - if comp.IterVar() != "i" { - t.Errorf("IterVar() got %s, wanted 'i'", comp.IterVar()) + expr.RenumberIDs(testIDGen(50)) + if expr.ID() != 51 { + t.Errorf("expr.ID() got %d, wanted 51", expr.ID()) } - if comp.AccuVar() != "__result__" { - t.Errorf("AccuVar() got %s, wanted '__result__'", comp.AccuVar()) + if m.Entries()[0].ID() != 52 { + t.Errorf("m.Entries()[0].ID() got %d, wanted 52", m.Entries()[0].ID()) } - if comp.AccuInit().AsLiteral() != types.False { - t.Errorf("AccuInit() returned %v, wanted false", comp.AccuInit().AsLiteral()) + entry := m.Entries()[0].AsMapEntry() + if key := entry.Key(); key.ID() != 53 { + t.Errorf("key.ID() got %d, wanted 
53", key.ID()) } - if comp.Result().Kind() != ast.IdentKind { - t.Errorf("Result() returned %v, wanted ident", comp.Result()) + if val := entry.Value(); val.ID() != 54 { + t.Errorf("val.ID() got %d, wanted 54", val.ID()) } - if comp.LoopCondition().Kind() != ast.CallKind { - t.Errorf("LoopCondition() returned %v, wanted call", comp.LoopCondition()) + if !entry.IsOptional() { + t.Error("entry.IsOptional() got false, wanted true") } - if comp.LoopStep().Kind() != ast.CallKind { - t.Errorf("LoopStep() returned %v, wanted call", comp.LoopStep()) +} + +func TestMapNil(t *testing.T) { + expr := nilTestExpr(t) + m := expr.AsMap() + if m.Size() != 0 { + t.Errorf("nil map.Size() got %d, wanted 0", m.Size()) } - if comp.Result().AsIdent() != "__result__" { - t.Errorf("AsIdent() returned %v, wanted __result__", comp.Result().AsIdent()) + if len(m.Entries()) != 0 { + t.Errorf("nil map.Entries() got %d, wanted 0", m.Entries()) } } -func TestNavigableSelectExpr(t *testing.T) { - checked := mustTypeCheck(t, `msg.single_int32`) - expr := ast.NavigateCheckedAST(checked) - if expr.Kind() != ast.SelectKind { - t.Errorf("Kind() got %v, wanted SelectKind", expr.Kind()) +func TestMapEntryNil(t *testing.T) { + fac := ast.NewExprFactory() + field := fac.NewStructField(1, + "hello", fac.NewLiteral(3, types.String("world")), false) + + // intentionally convert the struct field to the wrong type + entry := field.AsMapEntry() + if entry.Key().Kind() != ast.UnspecifiedExprKind { + t.Errorf("entry.Key() got %s, wanted ''", entry.Key()) } + if entry.Value().Kind() != ast.UnspecifiedExprKind { + t.Errorf("entry.Value() got %v, wanted unspecified", entry.Value()) + } + if entry.IsOptional() { + t.Errorf("entry.IsOptional() got %v, wanted ''", entry.IsOptional()) + } +} + +func TestSelect(t *testing.T) { + fac := ast.NewExprFactory() + expr := fac.NewSelect(2, fac.NewIdent(1, "operand"), "field") sel := expr.AsSelect() - if sel.FieldName() != "single_int32" { - t.Errorf("FieldName() got %s, wanted 
single_int32", sel.FieldName()) + if sel.FieldName() != "field" { + t.Errorf("sel.FieldName() got %s, wanted 'field'", sel.FieldName()) } - if sel.Operand().Kind() != ast.IdentKind { - t.Fatalf("Operand() kind got %v, wanted ident", sel.Operand().Kind()) + if sel.Operand().AsIdent() != "operand" { + t.Errorf("sel.Operand().AsIdent() got %s, wanted 'operand'", sel.Operand().AsIdent()) } - if sel.Operand().AsIdent() != "msg" { - t.Errorf("Operand() got %v, wanted ident 'msg'", sel.Operand()) + expr.RenumberIDs(testIDGen(20)) + if sel.Operand().ID() != 22 { + t.Errorf("Operand().ID() got %d, wanted 22", sel.Operand().ID()) + } +} + +func TestSelectNil(t *testing.T) { + expr := nilTestExpr(t) + sel := expr.AsSelect() + if sel.FieldName() != "" { + t.Errorf("sel.FieldName() got %s, wanted ''", sel.FieldName()) } if sel.IsTestOnly() { - t.Error("IsTestOnly() got true, wanted false") + t.Error("sel.IsTestOnly() got true, wanted false") + } + if sel.Operand().Kind() != ast.UnspecifiedExprKind { + t.Errorf("sel.Operand() got %v, wanted unspecified", sel.Operand()) } } -func TestNavigableSelectExpr_TestOnly(t *testing.T) { - checked := mustTypeCheck(t, `has(msg.single_int32)`) - expr := ast.NavigateCheckedAST(checked) - if expr.Kind() != ast.SelectKind { - t.Errorf("Kind() got %v, wanted SelectKind", expr.Kind()) +func TestStruct(t *testing.T) { + fac := ast.NewExprFactory() + expr := fac.NewStruct(1, + "custom.StructType", + []ast.EntryExpr{ + fac.NewStructField(2, "a", fac.NewLiteral(3, types.True), true), + fac.NewStructField(4, "b", fac.NewLiteral(5, types.False), false), + }) + s := expr.AsStruct() + if s.TypeName() != "custom.StructType" { + t.Errorf("s.TypeName() got %q, wanted 'custom.StructType'", s.TypeName()) } - sel := expr.AsSelect() - if sel.FieldName() != "single_int32" { - t.Errorf("FieldName() got %s, wanted single_int32", sel.FieldName()) + if len(s.Fields()) != 2 { + t.Errorf("s.Fields() got %d, wanted 2", len(s.Fields())) + } + 
expr.RenumberIDs(testIDGen(50)) + if expr.ID() != 51 { + t.Errorf("expr.ID() got %d, wanted 51", expr.ID()) } - if sel.Operand().Kind() != ast.IdentKind { - t.Fatalf("Operand() kind got %v, wanted ident", sel.Operand().Kind()) + if s.Fields()[0].ID() != 52 { + t.Errorf("s.Fields()[0].ID() got %d, wanted 52", s.Fields()[0].ID()) } - if sel.Operand().AsIdent() != "msg" { - t.Errorf("Operand() got %v, wanted ident 'msg'", sel.Operand()) + field := s.Fields()[0].AsStructField() + if val := field.Value(); val.ID() != 53 { + t.Errorf("val.ID() got %d, wanted 53", val.ID()) } - if !sel.IsTestOnly() { - t.Error("IsTestOnly() got false, wanted true") + if field.Name() != "a" { + t.Errorf("field.Name() got %s, wanted 'a'", field.Name()) + } + if !field.IsOptional() { + t.Error("field.IsOptional() got false, wanted true") } } -func mustTypeCheck(t testing.TB, expr string) *ast.CheckedAST { - t.Helper() - p, err := parser.NewParser(parser.Macros(parser.AllMacros...), parser.EnableOptionalSyntax(true)) - if err != nil { - t.Fatalf("parser.NewParser() failed: %v", err) - } - exprSrc := common.NewTextSource(expr) - parsed, iss := p.Parse(exprSrc) - if len(iss.GetErrors()) != 0 { - t.Fatalf("Parse(%s) failed: %s", expr, iss.ToDisplayString()) - } - reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) - env := newTestEnv(t, containers.DefaultContainer, reg) - checked, iss := checker.Check(parsed, exprSrc, env) - if len(iss.GetErrors()) != 0 { - t.Fatalf("Check(%s) failed: %s", expr, iss.ToDisplayString()) - } - return checked +func TestStructNil(t *testing.T) { + expr := nilTestExpr(t) + s := expr.AsStruct() + if s.TypeName() != "" { + t.Errorf("nil struct.TypeName() got %s, wanted ''", s.TypeName()) + } + if len(s.Fields()) != 0 { + t.Errorf("nil struct.Fields() got %d, wanted 0", s.Fields()) + } } -func newTestRegistry(t testing.TB, msgs ...proto.Message) *types.Registry { - t.Helper() - reg, err := types.NewRegistry(msgs...) 
- if err != nil { - t.Fatalf("types.NewRegistry(%v) failed: %v", msgs, err) +func TestStructFieldNil(t *testing.T) { + fac := ast.NewExprFactory() + entry := fac.NewMapEntry(1, + fac.NewLiteral(2, types.String("hello")), + fac.NewLiteral(3, types.String("world")), + false) + + // intentionally convert the map entry to the wrong type + field := entry.AsStructField() + if field.Name() != "" { + t.Errorf("field.FieldName() got %s, wanted ''", field.Name()) + } + if field.Value().Kind() != ast.UnspecifiedExprKind { + t.Errorf("field.Value() got %v, wanted unspecified", field.Value()) + } + if field.IsOptional() { + t.Errorf("field.IsOptional() got %v, wanted ''", field.IsOptional()) } - return reg } -func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry) *checker.Env { +func nilTestExpr(t testing.TB) ast.Expr { t.Helper() - env, err := checker.NewEnv(cont, reg, checker.CrossTypeNumericComparisons(true)) - if err != nil { - t.Fatalf("checker.NewEnv(%v, %v) failed: %v", cont, reg, err) - } - err = env.AddIdents(stdlib.Types()...) - if err != nil { - t.Fatalf("env.Add(stdlib.Types()...) failed: %v", err) + fac := ast.NewExprFactory() + expr := fac.NewLiteral(1, types.NullValue) + expr.SetKindCase(nil) + if expr.Kind() != ast.UnspecifiedExprKind { + t.Fatalf("SetKindCase(nil) did not produce an unspecified expr kind: %v", expr.Kind()) } - err = env.AddFunctions(stdlib.Functions()...) - if err != nil { - t.Fatalf("env.Add(stdlib.Functions()...) 
failed: %v", err) + return expr +} + +func testIDGen(seed int64) ast.IDGenerator { + return func() int64 { + seed++ + return seed } - env.AddIdents(decls.NewVariable("msg", types.NewObjectType("google.expr.proto3.test.TestAllTypes"))) - return env } diff --git a/common/ast/factory.go b/common/ast/factory.go new file mode 100644 index 000000000..9e7c57869 --- /dev/null +++ b/common/ast/factory.go @@ -0,0 +1,193 @@ +package ast + +import "github.com/google/cel-go/common/types/ref" + +// ExprFactory interfaces defines a set of methods necessary for building native expression values. +type ExprFactory interface { + // NewCall creates an Expr value representing a global function call. + NewCall(id int64, function string, args ...Expr) Expr + + // NewComprehension creates an Expr value representing a comprehension over a value range. + NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCondition, loopStep, result Expr) Expr + + // NewMemberCall creates an Expr value representing a member function call. + NewMemberCall(id int64, function string, receiver Expr, args ...Expr) Expr + + // NewIdent creates an Expr value representing an identifier. + NewIdent(id int64, name string) Expr + + // NewAccuIdent creates an Expr value representing an accumulator identifier within a + //comprehension. + NewAccuIdent(id int64) Expr + + // NewLiteral creates an Expr value representing a literal value, such as a string or integer. + NewLiteral(id int64, value ref.Val) Expr + + // NewList creates an Expr value representing a list literal expression with optional indices. + // + // Optional indicies will typically be empty unless the CEL optional types are enabled. 
+ NewList(id int64, elems []Expr, optIndices []int32) Expr + + // NewMap creates an Expr value representing a map literal expression + NewMap(id int64, entries []EntryExpr) Expr + + // NewMapEntry creates a MapEntry with a given key, value, and a flag indicating whether + // the key is optionally set. + NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr + + // NewPresenceTest creates an Expr representing a field presence test on an operand expression. + NewPresenceTest(id int64, operand Expr, field string) Expr + + // NewSelect creates an Expr representing a field selection on an operand expression. + NewSelect(id int64, operand Expr, field string) Expr + + // NewStruct creates an Expr value representing a struct literal with a given type name and a + // set of field initializers. + NewStruct(id int64, typeName string, fields []EntryExpr) Expr + + // NewStructField creates a StructField with a given field name, value, and a flag indicating + // whether the field is optionally set. + NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr + + // NewUnspecifiedExpr creates an empty expression node. + NewUnspecifiedExpr(id int64) Expr + + isExprFactory() +} + +type baseExprFactory struct{} + +// NewExprFactory creates an ExprFactory instance. 
+func NewExprFactory() ExprFactory { + return &baseExprFactory{} +} + +func (fac *baseExprFactory) NewCall(id int64, function string, args ...Expr) Expr { + if len(args) == 0 { + args = []Expr{} + } + return fac.newExpr( + id, + &baseCallExpr{ + function: function, + target: nilExpr, + args: args, + isMember: false, + }) +} + +func (fac *baseExprFactory) NewMemberCall(id int64, function string, target Expr, args ...Expr) Expr { + if len(args) == 0 { + args = []Expr{} + } + return fac.newExpr( + id, + &baseCallExpr{ + function: function, + target: target, + args: args, + isMember: true, + }) +} + +func (fac *baseExprFactory) NewComprehension(id int64, iterRange Expr, iterVar, accuVar string, accuInit, loopCond, loopStep, result Expr) Expr { + return fac.newExpr( + id, + &baseComprehensionExpr{ + iterRange: iterRange, + iterVar: iterVar, + accuVar: accuVar, + accuInit: accuInit, + loopCond: loopCond, + loopStep: loopStep, + result: result, + }) +} + +func (fac *baseExprFactory) NewIdent(id int64, name string) Expr { + return fac.newExpr(id, baseIdentExpr(name)) +} + +func (fac *baseExprFactory) NewAccuIdent(id int64) Expr { + return fac.NewIdent(id, "__result__") +} + +func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr { + return fac.newExpr(id, &baseLiteral{Val: value}) +} + +func (fac *baseExprFactory) NewList(id int64, elems []Expr, optIndices []int32) Expr { + return fac.newExpr(id, &baseListExpr{elements: elems, optIndices: optIndices}) +} + +func (fac *baseExprFactory) NewMap(id int64, entries []EntryExpr) Expr { + return fac.newExpr(id, &baseMapExpr{entries: entries}) +} + +func (fac *baseExprFactory) NewMapEntry(id int64, key, value Expr, isOptional bool) EntryExpr { + return fac.newEntryExpr( + id, + &baseMapEntry{ + key: key, + value: value, + isOptional: isOptional, + }) +} + +func (fac *baseExprFactory) NewPresenceTest(id int64, operand Expr, field string) Expr { + return fac.newExpr( + id, + &baseSelectExpr{ + operand: operand, + field: 
field, + testOnly: true, + }) +} + +func (fac *baseExprFactory) NewSelect(id int64, operand Expr, field string) Expr { + return fac.newExpr( + id, + &baseSelectExpr{ + operand: operand, + field: field, + }) +} + +func (fac *baseExprFactory) NewStruct(id int64, typeName string, fields []EntryExpr) Expr { + return fac.newExpr( + id, + &baseStructExpr{ + typeName: typeName, + fields: fields, + }) +} + +func (fac *baseExprFactory) NewStructField(id int64, field string, value Expr, isOptional bool) EntryExpr { + return fac.newEntryExpr( + id, + &baseStructField{ + field: field, + value: value, + isOptional: isOptional, + }) +} + +func (fac *baseExprFactory) NewUnspecifiedExpr(id int64) Expr { + return fac.newExpr(id, nil) +} + +func (*baseExprFactory) isExprFactory() {} + +func (fac *baseExprFactory) newExpr(id int64, e exprKindCase) Expr { + return &expr{ + id: id, + exprKindCase: e, + } +} + +func (fac *baseExprFactory) newEntryExpr(id int64, e entryExprKindCase) EntryExpr { + return &entryExpr{ + id: id, + entryExprKindCase: e, + } +} diff --git a/common/ast/navigable.go b/common/ast/navigable.go new file mode 100644 index 000000000..27b0fcbb6 --- /dev/null +++ b/common/ast/navigable.go @@ -0,0 +1,489 @@ +package ast + +import ( + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" +) + +// NavigableExpr represents the base navigable expression value with methods to inspect the +// parent and child expressions. +type NavigableExpr interface { + Expr + + // Type of the expression. + // + // If the expression is type-checked, the type check metadata is returned. If the expression + // has not been type-checked, the types.DynType value is returned. + Type() *types.Type + + // Parent returns the parent expression node, if one exists. + Parent() (NavigableExpr, bool) + + // Children returns a list of child expression nodes. 
+ Children() []NavigableExpr +} + +// NavigateAST converts an AST to a NavigableExpr +func NavigateAST(ast *AST) NavigableExpr { + return newNavigableExpr(ast, nil, ast.Expr()) +} + +// ExprMatcher takes a NavigableExpr in and indicates whether the value is a match. +// +// This function type should be use with the `Match` and `MatchList` calls. +type ExprMatcher func(NavigableExpr) bool + +// ConstantValueMatcher returns an ExprMatcher which will return true if the input NavigableExpr +// is comprised of all constant values, such as a simple literal or even list and map literal. +func ConstantValueMatcher() ExprMatcher { + return matchIsConstantValue +} + +// KindMatcher returns an ExprMatcher which will return true if the input NavigableExpr.Kind() matches +// the specified `kind`. +func KindMatcher(kind ExprKind) ExprMatcher { + return func(e NavigableExpr) bool { + return e.Kind() == kind + } +} + +// FunctionMatcher returns an ExprMatcher which will match NavigableExpr nodes of CallKind type whose +// function name is equal to `funcName`. +func FunctionMatcher(funcName string) ExprMatcher { + return func(e NavigableExpr) bool { + if e.Kind() != CallKind { + return false + } + return e.AsCall().FunctionName() == funcName + } +} + +// AllMatcher returns true for all descendants of a NavigableExpr, effectively flattening them into a list. +// +// Such a result would work well with subsequent MatchList calls. +func AllMatcher() ExprMatcher { + return func(NavigableExpr) bool { + return true + } +} + +// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values of the +// descendants which match. +func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr { + return matchListInternal([]NavigableExpr{expr}, matcher, true) +} + +// MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a +// subset of NavigableExpr values which match. 
+func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr { + visit := make([]NavigableExpr, len(exprs)) + copy(visit, exprs) + return matchListInternal(visit, matcher, false) +} + +func matchListInternal(visit []NavigableExpr, matcher ExprMatcher, visitDescendants bool) []NavigableExpr { + var matched []NavigableExpr + for len(visit) != 0 { + e := visit[0] + if matcher(e) { + matched = append(matched, e) + } + if visitDescendants { + visit = append(visit[1:], e.Children()...) + } else { + visit = visit[1:] + } + } + return matched +} + +func matchIsConstantValue(e NavigableExpr) bool { + if e.Kind() == LiteralKind { + return true + } + if e.Kind() == StructKind || e.Kind() == MapKind || e.Kind() == ListKind { + for _, child := range e.Children() { + if !matchIsConstantValue(child) { + return false + } + } + return true + } + return false +} + +func newNavigableExpr(ast *AST, parent NavigableExpr, expr Expr) NavigableExpr { + nav := &navigableExprImpl{ + Expr: expr, + ast: ast, + parent: parent, + createChildren: getChildFactory(expr), + } + return nav +} + +type navigableExprImpl struct { + Expr + ast *AST + parent NavigableExpr + createChildren childFactory +} + +func (nav *navigableExprImpl) Parent() (NavigableExpr, bool) { + if nav.parent != nil { + return nav.parent, true + } + return nil, false +} + +func (nav *navigableExprImpl) ID() int64 { + return nav.Expr.ID() +} + +func (nav *navigableExprImpl) Kind() ExprKind { + return nav.Expr.Kind() +} + +func (nav *navigableExprImpl) Type() *types.Type { + return nav.ast.GetType(nav.ID()) +} + +func (nav *navigableExprImpl) Children() []NavigableExpr { + return nav.createChildren(nav) +} + +func (nav *navigableExprImpl) AsCall() CallExpr { + return navigableCallImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsComprehension() ComprehensionExpr { + return navigableComprehensionImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsIdent() string { + return 
nav.Expr.AsIdent() +} + +func (nav *navigableExprImpl) AsList() ListExpr { + return navigableListImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsLiteral() ref.Val { + return nav.Expr.AsLiteral() +} + +func (nav *navigableExprImpl) AsMap() MapExpr { + return navigableMapImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsSelect() SelectExpr { + return navigableSelectImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) AsStruct() StructExpr { + return navigableStructImpl{navigableExprImpl: nav} +} + +func (nav *navigableExprImpl) createChild(e Expr) NavigableExpr { + return newNavigableExpr(nav.ast, nav, e) +} + +func (nav *navigableExprImpl) isExpr() {} + +type navigableCallImpl struct { + *navigableExprImpl +} + +func (call navigableCallImpl) FunctionName() string { + return call.Expr.AsCall().FunctionName() +} + +func (call navigableCallImpl) IsMemberFunction() bool { + return call.Expr.AsCall().IsMemberFunction() +} + +func (call navigableCallImpl) Target() Expr { + t := call.Expr.AsCall().Target() + if t != nil { + return call.createChild(t) + } + return nil +} + +func (call navigableCallImpl) Args() []Expr { + args := call.Expr.AsCall().Args() + navArgs := make([]Expr, len(args)) + for i, a := range args { + navArgs[i] = call.createChild(a) + } + return navArgs +} + +type navigableComprehensionImpl struct { + *navigableExprImpl +} + +func (comp navigableComprehensionImpl) IterRange() Expr { + return comp.createChild(comp.Expr.AsComprehension().IterRange()) +} + +func (comp navigableComprehensionImpl) IterVar() string { + return comp.Expr.AsComprehension().IterVar() +} + +func (comp navigableComprehensionImpl) AccuVar() string { + return comp.Expr.AsComprehension().AccuVar() +} + +func (comp navigableComprehensionImpl) AccuInit() Expr { + return comp.createChild(comp.Expr.AsComprehension().AccuInit()) +} + +func (comp navigableComprehensionImpl) LoopCondition() Expr { + return 
comp.createChild(comp.Expr.AsComprehension().LoopCondition()) +} + +func (comp navigableComprehensionImpl) LoopStep() Expr { + return comp.createChild(comp.Expr.AsComprehension().LoopStep()) +} + +func (comp navigableComprehensionImpl) Result() Expr { + return comp.createChild(comp.Expr.AsComprehension().Result()) +} + +type navigableListImpl struct { + *navigableExprImpl +} + +func (l navigableListImpl) Elements() []Expr { + pbElems := l.Expr.AsList().Elements() + elems := make([]Expr, len(pbElems)) + for i := 0; i < len(pbElems); i++ { + elems[i] = l.createChild(pbElems[i]) + } + return elems +} + +func (l navigableListImpl) OptionalIndices() []int32 { + return l.Expr.AsList().OptionalIndices() +} + +func (l navigableListImpl) Size() int { + return l.Expr.AsList().Size() +} + +type navigableMapImpl struct { + *navigableExprImpl +} + +func (m navigableMapImpl) Entries() []EntryExpr { + mapExpr := m.Expr.AsMap() + entries := make([]EntryExpr, len(mapExpr.Entries())) + for i, e := range mapExpr.Entries() { + entry := e.AsMapEntry() + entries[i] = &entryExpr{ + id: e.ID(), + entryExprKindCase: navigableEntryImpl{ + key: m.createChild(entry.Key()), + val: m.createChild(entry.Value()), + isOpt: entry.IsOptional(), + }, + } + } + return entries +} + +func (m navigableMapImpl) Size() int { + return m.Expr.AsMap().Size() +} + +type navigableEntryImpl struct { + key NavigableExpr + val NavigableExpr + isOpt bool +} + +func (e navigableEntryImpl) Kind() EntryExprKind { + return MapEntryKind +} + +func (e navigableEntryImpl) Key() Expr { + return e.key +} + +func (e navigableEntryImpl) Value() Expr { + return e.val +} + +func (e navigableEntryImpl) IsOptional() bool { + return e.isOpt +} + +func (e navigableEntryImpl) renumberIDs(IDGenerator) {} + +func (e navigableEntryImpl) isEntryExpr() {} + +type navigableSelectImpl struct { + *navigableExprImpl +} + +func (sel navigableSelectImpl) FieldName() string { + return sel.Expr.AsSelect().FieldName() +} + +func (sel 
navigableSelectImpl) IsTestOnly() bool { + return sel.Expr.AsSelect().IsTestOnly() +} + +func (sel navigableSelectImpl) Operand() Expr { + return sel.createChild(sel.Expr.AsSelect().Operand()) +} + +type navigableStructImpl struct { + *navigableExprImpl +} + +func (s navigableStructImpl) TypeName() string { + return s.Expr.AsStruct().TypeName() +} + +func (s navigableStructImpl) Fields() []EntryExpr { + fieldInits := s.Expr.AsStruct().Fields() + fields := make([]EntryExpr, len(fieldInits)) + for i, f := range fieldInits { + field := f.AsStructField() + fields[i] = &entryExpr{ + id: f.ID(), + entryExprKindCase: navigableFieldImpl{ + name: field.Name(), + val: s.createChild(field.Value()), + isOpt: field.IsOptional(), + }, + } + } + return fields +} + +type navigableFieldImpl struct { + name string + val NavigableExpr + isOpt bool +} + +func (f navigableFieldImpl) Kind() EntryExprKind { + return StructFieldKind +} + +func (f navigableFieldImpl) Name() string { + return f.name +} + +func (f navigableFieldImpl) Value() Expr { + return f.val +} + +func (f navigableFieldImpl) IsOptional() bool { + return f.isOpt +} + +func (f navigableFieldImpl) renumberIDs(IDGenerator) {} + +func (f navigableFieldImpl) isEntryExpr() {} + +func getChildFactory(expr Expr) childFactory { + if expr == nil { + return noopFactory + } + switch expr.Kind() { + case LiteralKind: + return noopFactory + case IdentKind: + return noopFactory + case SelectKind: + return selectFactory + case CallKind: + return callArgFactory + case ListKind: + return listElemFactory + case MapKind: + return mapEntryFactory + case StructKind: + return structEntryFactory + case ComprehensionKind: + return comprehensionFactory + default: + return noopFactory + } +} + +type childFactory func(*navigableExprImpl) []NavigableExpr + +func noopFactory(*navigableExprImpl) []NavigableExpr { + return nil +} + +func selectFactory(nav *navigableExprImpl) []NavigableExpr { + return 
[]NavigableExpr{nav.createChild(nav.AsSelect().Operand())} +} + +func callArgFactory(nav *navigableExprImpl) []NavigableExpr { + call := nav.Expr.AsCall() + argCount := len(call.Args()) + if call.IsMemberFunction() { + argCount++ + } + navExprs := make([]NavigableExpr, argCount) + i := 0 + if call.IsMemberFunction() { + navExprs[i] = nav.createChild(call.Target()) + i++ + } + for _, arg := range call.Args() { + navExprs[i] = nav.createChild(arg) + i++ + } + return navExprs +} + +func listElemFactory(nav *navigableExprImpl) []NavigableExpr { + l := nav.Expr.AsList() + navExprs := make([]NavigableExpr, len(l.Elements())) + for i, e := range l.Elements() { + navExprs[i] = nav.createChild(e) + } + return navExprs +} + +func structEntryFactory(nav *navigableExprImpl) []NavigableExpr { + s := nav.Expr.AsStruct() + entries := make([]NavigableExpr, len(s.Fields())) + for i, e := range s.Fields() { + f := e.AsStructField() + entries[i] = nav.createChild(f.Value()) + } + return entries +} + +func mapEntryFactory(nav *navigableExprImpl) []NavigableExpr { + m := nav.Expr.AsMap() + entries := make([]NavigableExpr, len(m.Entries())*2) + j := 0 + for _, e := range m.Entries() { + mapEntry := e.AsMapEntry() + entries[j] = nav.createChild(mapEntry.Key()) + entries[j+1] = nav.createChild(mapEntry.Value()) + j += 2 + } + return entries +} + +func comprehensionFactory(nav *navigableExprImpl) []NavigableExpr { + compre := nav.Expr.AsComprehension() + return []NavigableExpr{ + nav.createChild(compre.IterRange()), + nav.createChild(compre.AccuInit()), + nav.createChild(compre.LoopCondition()), + nav.createChild(compre.LoopStep()), + nav.createChild(compre.Result()), + } +} diff --git a/common/ast/navigable_test.go b/common/ast/navigable_test.go new file mode 100644 index 000000000..0145aadf8 --- /dev/null +++ b/common/ast/navigable_test.go @@ -0,0 +1,446 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file 
except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast_test + +import ( + "testing" + + "google.golang.org/protobuf/proto" + + "github.com/google/cel-go/checker" + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/decls" + "github.com/google/cel-go/common/stdlib" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/parser" + + proto3pb "github.com/google/cel-go/test/proto3pb" +) + +func TestNavigateExpr(t *testing.T) { + tests := []struct { + expr string + descendantCount int + callCount int + }{ + { + expr: `'a' == 'b'`, + descendantCount: 3, + callCount: 1, + }, + { + expr: `'a'.size()`, + descendantCount: 2, + callCount: 1, + }, + { + expr: `[1, 2, 3]`, + descendantCount: 4, + callCount: 0, + }, + { + expr: `[1, 2, 3][0]`, + descendantCount: 6, + callCount: 1, + }, + { + expr: `{1u: 'hello'}`, + descendantCount: 3, + callCount: 0, + }, + { + expr: `{'hello': 'world'}.hello`, + descendantCount: 4, + callCount: 0, + }, + { + expr: `type(1) == int`, + descendantCount: 4, + callCount: 2, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{single_int32: 1}`, + descendantCount: 2, + callCount: 0, + }, + { + expr: `[true].exists(i, i)`, + descendantCount: 11, // 2 for iter range, 1 for accu init, 4 for loop condition, 3 for loop step, 1 for result + callCount: 3, // @not_strictly_false(!result), accu_init || i + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked := mustTypeCheck(t, 
tc.expr) + nav := ast.NavigateAST(checked) + descendants := ast.MatchDescendants(nav, ast.AllMatcher()) + if len(descendants) != tc.descendantCount { + t.Errorf("ast.MatchDescendants(%v) got %d descendants, wanted %d", checked.Expr(), len(descendants), tc.descendantCount) + } + calls := ast.MatchSubset(descendants, ast.KindMatcher(ast.CallKind)) + if len(calls) != tc.callCount { + t.Errorf("ast.MatchSubset(%v) got %d calls, wanted %d", checked.Expr(), len(calls), tc.callCount) + } + }) + } +} + +func TestNavigateExprNilSafety(t *testing.T) { + tests := []struct { + name string + e ast.NavigableExpr + }{ + { + name: "nil expr", + e: ast.NavigateAST(ast.NewAST(nil, nil)), + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + e := tc.e + if e.ID() != 0 { + t.Errorf("ID() got %d, wanted 0", e.ID()) + } + if e.Kind() != ast.UnspecifiedExprKind { + t.Errorf("Kind() got %v, wanted unspecified kind", e.Kind()) + } + if e.Type() != types.DynType { + t.Errorf("Type() got %v, wanted types.DynType", e.Type()) + } + if p, found := e.Parent(); found { + t.Errorf("Parent() got %v, wanted not found", p) + } + if len(e.Children()) != 0 { + t.Errorf("Children() got %v, wanted none", e.Children()) + } + if e.AsLiteral() != nil { + t.Errorf("AsLiteral() got %v, wanted nil", e.AsLiteral()) + } + if e.AsCall() == nil { + t.Errorf("AsCall() got nil, wanted non-nil for safe traversal") + } + if e.AsComprehension() == nil { + t.Errorf("AsComprehension() got nil, wanted non-nil for safe traversal") + } + }) + } +} + +func TestNavigableCallExprMember(t *testing.T) { + checked := mustTypeCheck(t, `'hello'.size()`) + expr := ast.NavigateAST(checked) + if expr.Kind() != ast.CallKind { + t.Errorf("Kind() got %v, wanted CallKind", expr.Kind()) + } + call := expr.AsCall() + if call.FunctionName() != "size" { + t.Errorf("FunctionName() got %s, wanted size", call.FunctionName()) + } + if call.Target() == nil { + t.Fatalf("Target() got nil, wanted non-nil") + } + 
if len(call.Args()) != 0 { + t.Errorf("Args() got %v, wanted 0", call.Args()) + } + target := call.Target() + if target.Kind() != ast.LiteralKind { + t.Errorf("Kind() got %v, wanted literal", target.Kind()) + } + if target.AsLiteral().Equal(types.String("hello")) != types.True { + t.Errorf("AsLiteral() got %v, wanted 'hello'", target.AsLiteral()) + } + if p, found := target.(ast.NavigableExpr).Parent(); !found || p != expr { + t.Errorf("Parent() got %v, wanted %v", p, expr) + } + sizeFn := ast.MatchDescendants(expr, ast.FunctionMatcher("size")) + if len(sizeFn) != 1 { + t.Errorf("ast.MatchDescendants() size function returned %v, wanted 1", sizeFn) + } + constantValues := ast.MatchDescendants(expr, ast.ConstantValueMatcher()) + if len(constantValues) != 1 { + t.Fatalf("ast.MatchDescendants() constant values returned %v, wanted 1 value", constantValues) + } + if constantValues[0].AsLiteral().Equal(types.String("hello")) != types.True { + t.Errorf("constantValues[0] got %v, wanted 'hello'", constantValues[0]) + } +} + +func TestNavigableCallExprGlobal(t *testing.T) { + checked := mustTypeCheck(t, `size('hello')`) + expr := ast.NavigateAST(checked) + if expr.Kind() != ast.CallKind { + t.Errorf("Kind() got %v, wanted CallKind", expr.Kind()) + } + call := expr.AsCall() + if call.FunctionName() != "size" { + t.Errorf("FunctionName() got %s, wanted size", call.FunctionName()) + } + if call.IsMemberFunction() { + t.Fatal("IsMemberFunction() returned true, wanted false") + } + if len(call.Args()) != 1 { + t.Errorf("Args() got %v, wanted 1", call.Args()) + } + arg := call.Args()[0] + if arg.Kind() != ast.LiteralKind { + t.Errorf("Kind() got %v, wanted literal", arg.Kind()) + } + if arg.AsLiteral().Equal(types.String("hello")) != types.True { + t.Errorf("AsLiteral() got %v, wanted 'hello'", arg.AsLiteral()) + } + if p, found := arg.(ast.NavigableExpr).Parent(); !found || p != expr { + t.Errorf("Parent() got %v, wanted %v", p, expr) + } + sizeFn := ast.MatchDescendants(expr, 
ast.FunctionMatcher("size")) + if len(sizeFn) != 1 { + t.Errorf("ast.MatchDescendants() size function returned %v, wanted 1", sizeFn) + } + constantValues := ast.MatchDescendants(expr, ast.ConstantValueMatcher()) + if len(constantValues) != 1 { + t.Fatalf("ast.MatchDescendants() constant values returned %v, wanted 1 value", constantValues) + } + if constantValues[0].AsLiteral().Equal(types.String("hello")) != types.True { + t.Errorf("constantValues[0] got %v, wanted 'hello'", constantValues[0]) + } +} + +func TestNavigableListExpr(t *testing.T) { + checked := mustTypeCheck(t, `[[1], [2]]`) + expr := ast.NavigateAST(checked) + if expr.Kind() != ast.ListKind { + t.Errorf("Kind() got %v, wanted ListKind", expr.Kind()) + } + list := expr.AsList() + if list.Size() != 2 { + t.Errorf("Size() got %d, wanted 2", list.Size()) + } + if len(list.OptionalIndices()) != 0 { + t.Errorf("OptionalIndicies() returned %v, wanted none", list.OptionalIndices()) + } + if len(list.Elements()) != 2 { + t.Errorf("Elements() returned %v, wanted 2", list.Elements()) + } + constantValues := ast.MatchDescendants(expr, ast.ConstantValueMatcher()) + if len(constantValues) != 5 { + t.Errorf("ast.MatchDescendants() constant values returned %v, wanted 5", constantValues) + } + constantLists := ast.MatchSubset(constantValues, ast.KindMatcher(ast.ListKind)) + if len(constantLists) != 3 { + t.Errorf("ast.MatchSubset() constant lists returned %v, wanted 3", constantLists) + } + literals := ast.MatchSubset(constantValues, ast.KindMatcher(ast.LiteralKind)) + if len(literals) != 2 { + t.Errorf("ast.MatchSubset() literals returned %v, wanted 2", literals) + } +} + +func TestNavigableMapExpr(t *testing.T) { + checked := mustTypeCheck(t, `{'hello': 1}`) + expr := ast.NavigateAST(checked) + if expr.Kind() != ast.MapKind { + t.Errorf("Kind() got %v, wanted MapKind", expr.Kind()) + } + m := expr.AsMap() + if m.Size() != 1 { + t.Errorf("Size() got %d, wanted 1", m.Size()) + } + if len(m.Entries()) != 1 { + 
t.Errorf("Entries() returned %v, wanted 1", m.Entries()) + } + e := m.Entries()[0] + entry := e.AsMapEntry() + if entry.IsOptional() { + t.Error("IsOptional() returned true, wanted false") + } + if entry.Key().AsLiteral().Equal(types.String("hello")) != types.True { + t.Errorf("Key() returned %v, wanted 'hello'", entry.Key().AsLiteral()) + } + if entry.Value().AsLiteral().Equal(types.Int(1)) != types.True { + t.Errorf("Value() returned %v, wanted 1", entry.Value().AsLiteral()) + } + descendants := ast.MatchDescendants(expr, ast.AllMatcher()) + if len(descendants) != 3 { + t.Errorf("ast.MatchDescendants() returned %v, wanted 3", descendants) + } + literals := ast.MatchSubset(descendants, ast.KindMatcher(ast.LiteralKind)) + if len(literals) != 2 { + t.Errorf("ast.MatchSubset() literals returned %v, wanted 2", literals) + } +} + +func TestNavigableStructExpr(t *testing.T) { + checked := mustTypeCheck(t, `google.expr.proto3.test.TestAllTypes{single_int32: 1}`) + expr := ast.NavigateAST(checked) + if expr.Kind() != ast.StructKind { + t.Errorf("Kind() got %v, wanted StructKind", expr.Kind()) + } + s := expr.AsStruct() + if s.TypeName() != "google.expr.proto3.test.TestAllTypes" { + t.Errorf("TypeName() got %s, wanted TestAllTypes", s.TypeName()) + } + if len(s.Fields()) != 1 { + t.Errorf("Fields() returned %v, wanted 1", s.Fields()) + } + f := s.Fields()[0] + field := f.AsStructField() + if field.IsOptional() { + t.Error("IsOptional() returned true, wanted false") + } + if field.Name() != "single_int32" { + t.Errorf("Name() returned %s, wanted 'single_int32'", field.Name()) + } + if field.Value().AsLiteral().Equal(types.Int(1)) != types.True { + t.Errorf("Value() returned %v, wanted 1", field.Value().AsLiteral()) + } + descendants := ast.MatchDescendants(expr, ast.AllMatcher()) + if len(descendants) != 2 { + t.Errorf("ast.MatchDescendants() returned %v, wanted 2", descendants) + } + literals := ast.MatchSubset(descendants, ast.KindMatcher(ast.LiteralKind)) + if 
len(literals) != 1 { + t.Errorf("ast.MatchSubset() literals returned %v, wanted 1", literals) + } + if literals[0].AsLiteral().Equal(types.Int(1)) != types.True { + t.Errorf("Value().AsLiteral() got %v, wanted 1", literals[0].AsLiteral()) + } +} + +func TestNavigableComprehensionExpr(t *testing.T) { + checked := mustTypeCheck(t, `[true].exists(i, i)`) + expr := ast.NavigateAST(checked) + if expr.Kind() != ast.ComprehensionKind { + t.Errorf("Kind() got %v, wanted ComprehensionKind", expr.Kind()) + } + comp := expr.AsComprehension() + iterRange := comp.IterRange() + if len(ast.MatchSubset([]ast.NavigableExpr{iterRange.(ast.NavigableExpr)}, ast.ConstantValueMatcher())) != 1 { + t.Errorf("IterRange() returned a non-constant list") + } + if comp.IterVar() != "i" { + t.Errorf("IterVar() got %s, wanted 'i'", comp.IterVar()) + } + if comp.AccuVar() != "__result__" { + t.Errorf("AccuVar() got %s, wanted '__result__'", comp.AccuVar()) + } + if comp.AccuInit().AsLiteral() != types.False { + t.Errorf("AccuInit() returned %v, wanted false", comp.AccuInit().AsLiteral()) + } + if comp.Result().Kind() != ast.IdentKind { + t.Errorf("Result() returned %v, wanted ident", comp.Result()) + } + if comp.LoopCondition().Kind() != ast.CallKind { + t.Errorf("LoopCondition() returned %v, wanted call", comp.LoopCondition()) + } + if comp.LoopStep().Kind() != ast.CallKind { + t.Errorf("LoopStep() returned %v, wanted call", comp.LoopStep()) + } + if comp.Result().AsIdent() != "__result__" { + t.Errorf("AsIdent() returned %v, wanted __result__", comp.Result().AsIdent()) + } +} + +func TestNavigableSelectExpr(t *testing.T) { + checked := mustTypeCheck(t, `msg.single_int32`) + expr := ast.NavigateAST(checked) + if expr.Kind() != ast.SelectKind { + t.Errorf("Kind() got %v, wanted SelectKind", expr.Kind()) + } + sel := expr.AsSelect() + if sel.FieldName() != "single_int32" { + t.Errorf("FieldName() got %s, wanted single_int32", sel.FieldName()) + } + if sel.Operand().Kind() != ast.IdentKind { + 
t.Fatalf("Operand() kind got %v, wanted ident", sel.Operand().Kind()) + } + if sel.Operand().AsIdent() != "msg" { + t.Errorf("Operand() got %v, wanted ident 'msg'", sel.Operand()) + } + if sel.IsTestOnly() { + t.Error("IsTestOnly() got true, wanted false") + } +} + +func TestNavigableSelectExpr_TestOnly(t *testing.T) { + checked := mustTypeCheck(t, `has(msg.single_int32)`) + expr := ast.NavigateAST(checked) + if expr.Kind() != ast.SelectKind { + t.Errorf("Kind() got %v, wanted SelectKind", expr.Kind()) + } + sel := expr.AsSelect() + if sel.FieldName() != "single_int32" { + t.Errorf("FieldName() got %s, wanted single_int32", sel.FieldName()) + } + if sel.Operand().Kind() != ast.IdentKind { + t.Fatalf("Operand() kind got %v, wanted ident", sel.Operand().Kind()) + } + if sel.Operand().AsIdent() != "msg" { + t.Errorf("Operand() got %v, wanted ident 'msg'", sel.Operand()) + } + if !sel.IsTestOnly() { + t.Error("IsTestOnly() got false, wanted true") + } +} + +func mustTypeCheck(t testing.TB, expr string) *ast.AST { + t.Helper() + p, err := parser.NewParser(parser.Macros(parser.AllMacros...), parser.EnableOptionalSyntax(true)) + if err != nil { + t.Fatalf("parser.NewParser() failed: %v", err) + } + exprSrc := common.NewTextSource(expr) + parsed, iss := p.Parse(exprSrc) + if len(iss.GetErrors()) != 0 { + t.Fatalf("Parse(%s) failed: %s", expr, iss.ToDisplayString()) + } + reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) + env := newTestEnv(t, containers.DefaultContainer, reg) + checked, iss := checker.Check(parsed, exprSrc, env) + if len(iss.GetErrors()) != 0 { + t.Fatalf("Check(%s) failed: %s", expr, iss.ToDisplayString()) + } + return checked +} + +func newTestRegistry(t testing.TB, msgs ...proto.Message) *types.Registry { + t.Helper() + reg, err := types.NewRegistry(msgs...) 
+ if err != nil { + t.Fatalf("types.NewRegistry(%v) failed: %v", msgs, err) + } + return reg +} + +func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry) *checker.Env { + t.Helper() + env, err := checker.NewEnv(cont, reg, checker.CrossTypeNumericComparisons(true)) + if err != nil { + t.Fatalf("checker.NewEnv(%v, %v) failed: %v", cont, reg, err) + } + err = env.AddIdents(stdlib.Types()...) + if err != nil { + t.Fatalf("env.Add(stdlib.Types()...) failed: %v", err) + } + err = env.AddFunctions(stdlib.Functions()...) + if err != nil { + t.Fatalf("env.Add(stdlib.Functions()...) failed: %v", err) + } + env.AddIdents(decls.NewVariable("msg", types.NewObjectType("google.expr.proto3.test.TestAllTypes"))) + return env +} diff --git a/common/types/provider.go b/common/types/provider.go index e80b4622e..e9bda552c 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -154,7 +154,7 @@ func (p *Registry) EnumValue(enumName string) ref.Val { return Int(enumVal.Value()) } -// FieldFieldType returns the field type for a checked type value. Returns false if +// FindFieldType returns the field type for a checked type value. Returns false if // the field could not be found. // // Deprecated: use FindStructFieldType @@ -173,7 +173,7 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType, GetFrom: field.GetFrom}, true } -// FieldStructFieldType returns the field type for a checked type value. Returns +// FindStructFieldType returns the field type for a checked type value. Returns // false if the field could not be found. 
func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) { msgType, found := p.pbdb.DescribeType(structType) diff --git a/ext/formatting.go b/ext/formatting.go index 178316959..2f35b996c 100644 --- a/ext/formatting.go +++ b/ext/formatting.go @@ -28,6 +28,7 @@ import ( "github.com/google/cel-go/cel" "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" @@ -419,16 +420,16 @@ func (stringFormatValidator) Configure(config cel.MutableValidatorConfig) error // Validate parses all literal format strings and type checks the format clause against the argument // at the corresponding ordinal within the list literal argument to the function, if one is specified. -func (stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.CheckedAST, iss *cel.Issues) { - root := ast.NavigateCheckedAST(a) - formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs) +func (stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *ast.AST, iss *cel.Issues) { + root := ast.NavigateAST(a) + formatCallExprs := ast.MatchDescendants(root, matchConstantFormatStringWithListLiteralArgs(a)) for _, e := range formatCallExprs { call := e.AsCall() formatStr := call.Target().AsLiteral().Value().(string) args := call.Args()[0].AsList().Elements() formatCheck := &stringFormatChecker{ - args: args, - typeMap: a.TypeMap, + args: args, + ast: a, } // use a placeholder locale, since locale doesn't affect syntax _, err := parseFormatString(formatStr, formatCheck, formatCheck, "en_US") @@ -460,33 +461,47 @@ func getErrorExprID(id int64, err error) int64 { // matchConstantFormatStringWithListLiteralArgs matches all valid expression nodes for string // format checking. 
-func matchConstantFormatStringWithListLiteralArgs(e ast.NavigableExpr) bool { - if e.Kind() != ast.CallKind { - return false - } - call := e.AsCall() - // TODO: expose the OverloadID() method on the NavigableCallExpr to refine this check. - if !call.IsMemberFunction() || call.FunctionName() != "format" { - return false - } - formatString := call.Target() - if formatString.Kind() != ast.LiteralKind && formatString.Type() != cel.StringType { - return false - } - args := call.Args() - if len(args) != 1 { - return false +func matchConstantFormatStringWithListLiteralArgs(a *ast.AST) ast.ExprMatcher { + return func(e ast.NavigableExpr) bool { + if e.Kind() != ast.CallKind { + return false + } + call := e.AsCall() + if !call.IsMemberFunction() || call.FunctionName() != "format" { + return false + } + overloadIDs := a.GetOverloadIDs(e.ID()) + if len(overloadIDs) != 0 { + found := false + for _, overload := range overloadIDs { + if overload == overloads.ExtFormatString { + found = true + break + } + } + if !found { + return false + } + } + formatString := call.Target() + if formatString.Kind() != ast.LiteralKind && formatString.AsLiteral().Type() != cel.StringType { + return false + } + args := call.Args() + if len(args) != 1 { + return false + } + formatArgs := args[0] + return formatArgs.Kind() == ast.ListKind } - formatArgs := args[0] - return formatArgs.Kind() == ast.ListKind } // stringFormatChecker implements the formatStringInterpolater interface type stringFormatChecker struct { - args []ast.NavigableExpr + args []ast.Expr argsRequested int currArgIndex int64 - typeMap map[int64]*cel.Type + ast *ast.AST } func (c *stringFormatChecker) String(arg ref.Val, locale string) (string, error) { @@ -572,11 +587,7 @@ func (c *stringFormatChecker) Size() int64 { } func (c *stringFormatChecker) typeOf(id int64) *cel.Type { - t, found := c.typeMap[id] - if found { - return t - } - return cel.DynType + return c.ast.GetType(id) } func (c *stringFormatChecker) verifyTypeOneOf(id 
int64, validTypes ...*cel.Type) bool { @@ -593,7 +604,7 @@ func (c *stringFormatChecker) verifyTypeOneOf(id int64, validTypes ...*cel.Type) return false } -func (c *stringFormatChecker) verifyString(sub ast.NavigableExpr) (bool, int64) { +func (c *stringFormatChecker) verifyString(sub ast.Expr) (bool, int64) { paramA := cel.TypeParamType("A") paramB := cel.TypeParamType("B") subVerified := c.verifyTypeOneOf(sub.ID(), @@ -603,18 +614,33 @@ func (c *stringFormatChecker) verifyString(sub ast.NavigableExpr) (bool, int64) if !subVerified { return false, sub.ID() } - if sub.Kind() != ast.ListKind && sub.Kind() != ast.MapKind { + switch sub.Kind() { + case ast.ListKind: + for _, e := range sub.AsList().Elements() { + // recursively verify if we're dealing with a list/map + verified, id := c.verifyString(e) + if !verified { + return false, id + } + } return true, sub.ID() - } - members := sub.Children() - for _, m := range members { - // recursively verify if we're dealing with a list/map - verified, id := c.verifyString(m) - if !verified { - return false, id + case ast.MapKind: + for _, e := range sub.AsMap().Entries() { + // recursively verify if we're dealing with a list/map + entry := e.AsMapEntry() + verified, id := c.verifyString(entry.Key()) + if !verified { + return false, id + } + verified, id = c.verifyString(entry.Value()) + if !verified { + return false, id + } } + return true, sub.ID() + default: + return true, sub.ID() } - return true, sub.ID() } // helper routines for reporting common errors during string formatting static validation and diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 00fc74732..0aca74d88 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -22,19 +22,13 @@ import ( "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // Interpreter generates a new 
Interpretable from a checked or unchecked expression. type Interpreter interface { // NewInterpretable creates an Interpretable from a checked expression and an // optional list of InterpretableDecorator values. - NewInterpretable(checked *ast.CheckedAST, decorators ...InterpretableDecorator) (Interpretable, error) - - // NewUncheckedInterpretable returns an Interpretable from a parsed expression - // and an optional list of InterpretableDecorator values. - NewUncheckedInterpretable(expr *exprpb.Expr, decorators ...InterpretableDecorator) (Interpretable, error) + NewInterpretable(exprAST *ast.AST, decorators ...InterpretableDecorator) (Interpretable, error) } // EvalObserver is a functional interface that accepts an expression id and an observed value. @@ -177,7 +171,7 @@ func NewInterpreter(dispatcher Dispatcher, // NewIntepretable implements the Interpreter interface method. func (i *exprInterpreter) NewInterpretable( - checked *ast.CheckedAST, + checked *ast.AST, decorators ...InterpretableDecorator) (Interpretable, error) { p := newPlanner( i.dispatcher, @@ -187,19 +181,5 @@ func (i *exprInterpreter) NewInterpretable( i.container, checked, decorators...) - return p.Plan(checked.Expr) -} - -// NewUncheckedIntepretable implements the Interpreter interface method. -func (i *exprInterpreter) NewUncheckedInterpretable( - expr *exprpb.Expr, - decorators ...InterpretableDecorator) (Interpretable, error) { - p := newUncheckedPlanner( - i.dispatcher, - i.provider, - i.adapter, - i.attrFactory, - i.container, - decorators...) 
- return p.Plan(expr) + return p.Plan(checked.Expr()) } diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 26e50c2fd..73269bf07 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -29,6 +29,7 @@ import ( "github.com/google/cel-go/checker" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/functions" @@ -1606,36 +1607,33 @@ func TestInterpreter_ProtoAttributeOpt(t *testing.T) { } func TestInterpreter_LogicalAndMissingType(t *testing.T) { - src := common.NewTextSource(`a && TestProto{c: true}.c`) - parsed, errors := parser.Parse(src) - if len(errors.GetErrors()) != 0 { - t.Errorf(errors.ToDisplayString()) + parsed := testMustParse(t, `a && TestProto{c: true}.c`) + ex, err := ast.ProtoToExpr(parsed.GetExpr()) + if err != nil { + t.Fatalf("ast.ProtoToExpr() failed: %v", err) } - reg := newTestRegistry(t) cont := containers.DefaultContainer attrs := NewAttributeFactory(cont, reg, reg) intr := newStandardInterpreter(t, cont, reg, reg, attrs) - i, err := intr.NewUncheckedInterpretable(parsed.GetExpr()) + i, err := intr.NewInterpretable(ast.NewAST(ex, nil)) if err == nil { t.Errorf("Got '%v', wanted error", i) } } func TestInterpreter_ExhaustiveConditionalExpr(t *testing.T) { - src := common.NewTextSource(`a ? b < 1.0 : c == ['hello']`) - parsed, errors := parser.Parse(src) - if len(errors.GetErrors()) != 0 { - t.Errorf(errors.ToDisplayString()) + parsed := testMustParse(t, `a ? 
b < 1.0 : c == ['hello']`) + ex, err := ast.ProtoToExpr(parsed.GetExpr()) + if err != nil { + t.Fatalf("ast.ProtoToExpr() failed: %v", err) } - state := NewEvalState() cont := containers.DefaultContainer reg := newTestRegistry(t, &exprpb.ParsedExpr{}) attrs := NewAttributeFactory(cont, reg, reg) intr := newStandardInterpreter(t, cont, reg, reg, attrs) - interpretable, _ := intr.NewUncheckedInterpretable( - parsed.GetExpr(), + interpretable, _ := intr.NewInterpretable(ast.NewAST(ex, nil), ExhaustiveEval(), Observe(EvalStateObserver(state))) vars, _ := NewActivation(map[string]any{ "a": types.True, @@ -1712,19 +1710,17 @@ func (ca *contextActivation) ResolveName(name string) (any, bool) { func TestInterpreter_ExhaustiveLogicalOrEquals(t *testing.T) { // a || b == "b" // Operator "==" is at Expr 4, should be evaluated though "a" is true - src := common.NewTextSource(`a || b == "b"`) - parsed, errors := parser.Parse(src) - if len(errors.GetErrors()) != 0 { - t.Errorf(errors.ToDisplayString()) + parsed := testMustParse(t, `a || b == "b"`) + ex, err := ast.ProtoToExpr(parsed.GetExpr()) + if err != nil { + t.Fatalf("ast.ProtoToExpr() failed: %v", err) } - state := NewEvalState() reg := newTestRegistry(t, &exprpb.Expr{}) cont := testContainer("test") attrs := NewAttributeFactory(cont, reg, reg) interp := newStandardInterpreter(t, cont, reg, reg, attrs) - i, _ := interp.NewUncheckedInterpretable( - parsed.GetExpr(), + i, _ := interp.NewInterpretable(ast.NewAST(ex, nil), ExhaustiveEval(), Observe(EvalStateObserver(state))) vars, _ := NewActivation(map[string]any{ "a": true, @@ -1754,11 +1750,7 @@ func TestInterpreter_SetProto2PrimitiveFields(t *testing.T) { single_string: "hello world", single_bool: true }`) - parsed, errors := parser.Parse(src) - if len(errors.GetErrors()) != 0 { - t.Errorf(errors.ToDisplayString()) - } - + parsed := testMustParse(t, src) cont := testContainer("google.expr.proto2.test") reg := newTestRegistry(t, &proto2pb.TestAllTypes{}) env := 
newTestEnv(t, cont, reg) @@ -1795,25 +1787,21 @@ func TestInterpreter_SetProto2PrimitiveFields(t *testing.T) { "input": reg.NativeToValue(input), }) result := eval.Eval(vars) - got, ok := result.(ref.Val).Value().(bool) + got, ok := result.Value().(bool) if !ok { t.Fatalf("Got '%v', wanted 'true'.", result) } expected := true if !reflect.DeepEqual(got, expected) { t.Errorf("Could not build object properly. Got '%v', wanted '%v'", - result.(ref.Val).Value(), + result.Value(), expected) } } func TestInterpreter_MissingIdentInSelect(t *testing.T) { src := common.NewTextSource(`a.b.c`) - parsed, errors := parser.Parse(src) - if len(errors.GetErrors()) != 0 { - t.Fatalf(errors.ToDisplayString()) - } - + parsed := testMustParse(t, src) cont := testContainer("test") reg := newTestRegistry(t) env := newTestEnv(t, cont, reg) @@ -1870,10 +1858,7 @@ func TestInterpreter_TypeConversionOpt(t *testing.T) { } for _, tc := range tests { src := common.NewTextSource(tc.in) - parsed, errors := parser.Parse(src) - if len(errors.GetErrors()) != 0 { - t.Fatalf("parser.Parse(%q) failed: %v", tc.in, errors.ToDisplayString()) - } + parsed := testMustParse(t, src) cont := containers.DefaultContainer reg := newTestRegistry(t) env := newTestEnv(t, cont, reg) @@ -1917,55 +1902,22 @@ func TestInterpreter_TypeConversionOpt(t *testing.T) { } func TestInterpreter_PlanOptionalElements(t *testing.T) { + fac := ast.NewExprFactory() // [?a] manipulated so the optional index is negative. - badOptionalA := &exprpb.Expr{ - Id: 1, - ExprKind: &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - OptionalIndices: []int32{-1}, - Elements: []*exprpb.Expr{ - { - Id: 2, - ExprKind: &exprpb.Expr_IdentExpr{ - IdentExpr: &exprpb.Expr_Ident{ - Name: "a", - }, - }, - }, - }, - }, - }, - } + badOptionalA := fac.NewList(1, []ast.Expr{fac.NewIdent(2, "a")}, []int32{-1}) // [?b] manipulated so the optional index is out of range. 
- badOptionalB := &exprpb.Expr{ - Id: 1, - ExprKind: &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - OptionalIndices: []int32{24}, - Elements: []*exprpb.Expr{ - { - Id: 2, - ExprKind: &exprpb.Expr_IdentExpr{ - IdentExpr: &exprpb.Expr_Ident{ - Name: "b", - }, - }, - }, - }, - }, - }, - } + badOptionalB := fac.NewList(1, []ast.Expr{fac.NewIdent(2, "b")}, []int32{24}) cont := containers.DefaultContainer reg := newTestRegistry(t) attrs := NewAttributeFactory(cont, reg, reg) interp := newStandardInterpreter(t, cont, reg, reg, attrs) - _, err := interp.NewUncheckedInterpretable(badOptionalA, Optimize()) + _, err := interp.NewInterpretable(ast.NewAST(badOptionalA, nil), Optimize()) if err == nil { - t.Fatal("interp.NewUncheckedInterpretable() should have failed with negative optional index: -1") + t.Fatal("interp.NewInterpretable() should have failed with negative optional index: -1") } - _, err = interp.NewUncheckedInterpretable(badOptionalB, Optimize()) + _, err = interp.NewInterpretable(ast.NewAST(badOptionalB, nil), Optimize()) if err == nil { - t.Fatal("interp.NewUncheckedInterpretable() should have failed with out of range optional index: 24") + t.Fatal("interp.NewInterpretable() should have failed with out of range optional index: 24") } } @@ -1974,7 +1926,7 @@ func testContainer(name string) *containers.Container { return cont } -func program(ctx testing.TB, tst *testCase, opts ...InterpretableDecorator) (Interpretable, Activation, error) { +func program(t testing.TB, tst *testCase, opts ...InterpretableDecorator) (Interpretable, Activation, error) { // Configure the package. cont := containers.DefaultContainer if tst.container != "" { @@ -1991,11 +1943,11 @@ func program(ctx testing.TB, tst *testCase, opts ...InterpretableDecorator) (Int } var reg *types.Registry var env *checker.Env - reg = newTestRegistry(ctx) + reg = newTestRegistry(t) if tst.types != nil { - reg = newTestRegistry(ctx, tst.types...) + reg = newTestRegistry(t, tst.types...) 
} - env = newTestEnv(ctx, cont, reg) + env = newTestEnv(t, cont, reg) attrs := NewAttributeFactory(cont, reg, reg) if tst.attrs != nil { attrs = tst.attrs @@ -2011,7 +1963,7 @@ func program(ctx testing.TB, tst *testCase, opts ...InterpretableDecorator) (Int if tst.in != nil { vars, err = NewActivation(tst.in) if err != nil { - ctx.Fatalf("NewActivation(%v) failed: %v", tst.in, err) + t.Fatalf("NewActivation(%v) failed: %v", tst.in, err) } } // Adapt the test output, if needed. @@ -2020,13 +1972,13 @@ func program(ctx testing.TB, tst *testCase, opts ...InterpretableDecorator) (Int } disp := NewDispatcher() - addFunctionBindings(ctx, disp) + addFunctionBindings(t, disp) if tst.funcs != nil { err = env.AddFunctions(tst.funcs...) if err != nil { return nil, nil, fmt.Errorf("env.Add(%v) failed: %v", tst.funcs, err) } - disp.Add(funcBindings(ctx, tst.funcs...)...) + disp.Add(funcBindings(t, tst.funcs...)...) } interp := NewInterpreter(disp, cont, reg, reg, attrs) @@ -2045,9 +1997,12 @@ func program(ctx testing.TB, tst *testCase, opts ...InterpretableDecorator) (Int return nil, nil, errors.New(errs.ToDisplayString()) } if tst.unchecked { + ex, err := ast.ProtoToExpr(parsed.GetExpr()) + if err != nil { + t.Fatalf("ast.ProtoToExpr() failed: %v", err) + } // Build the program plan. - prg, err := interp.NewUncheckedInterpretable( - parsed.GetExpr(), opts...) + prg, err := interp.NewInterpretable(ast.NewAST(ex, nil), opts...) 
if err != nil { return nil, nil, err } @@ -2090,6 +2045,28 @@ func isFieldQual(q Qualifier, fieldName string) bool { return f.Name == fieldName } +func testMustParse(t testing.TB, data any) *exprpb.ParsedExpr { + t.Helper() + p, err := parser.NewParser() + if err != nil { + t.Fatalf("parser.NewParser() failed: %v", err) + } + var src common.Source + switch d := data.(type) { + case string: + src = common.NewTextSource(d) + case common.Source: + src = d + default: + t.Fatalf("testMustParse() got invalid parse data: %v", data) + } + parsed, errors := p.Parse(src) + if len(errors.GetErrors()) != 0 { + t.Fatalf("Parse(%q) failed: %v", src.Content(), errors.ToDisplayString()) + } + return parsed +} + func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry) *checker.Env { t.Helper() env, err := checker.NewEnv(cont, reg, checker.CrossTypeNumericComparisons(true)) diff --git a/interpreter/planner.go b/interpreter/planner.go index 757cd080e..cf371f95d 100644 --- a/interpreter/planner.go +++ b/interpreter/planner.go @@ -23,15 +23,12 @@ import ( "github.com/google/cel-go/common/functions" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" - "github.com/google/cel-go/common/types/ref" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // interpretablePlanner creates an Interpretable evaluation plan from a proto Expr value. type interpretablePlanner interface { // Plan generates an Interpretable value (or error) from the input proto Expr. 
- Plan(expr *exprpb.Expr) (Interpretable, error) + Plan(expr ast.Expr) (Interpretable, error) } // newPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider, @@ -43,28 +40,7 @@ func newPlanner(disp Dispatcher, adapter types.Adapter, attrFactory AttributeFactory, cont *containers.Container, - checked *ast.CheckedAST, - decorators ...InterpretableDecorator) interpretablePlanner { - return &planner{ - disp: disp, - provider: provider, - adapter: adapter, - attrFactory: attrFactory, - container: cont, - refMap: checked.ReferenceMap, - typeMap: checked.TypeMap, - decorators: decorators, - } -} - -// newUncheckedPlanner creates an interpretablePlanner which references a Dispatcher, TypeProvider, -// TypeAdapter, and Container to resolve functions and types at plan time. Namespaces present in -// Select expressions are resolved lazily at evaluation time. -func newUncheckedPlanner(disp Dispatcher, - provider types.Provider, - adapter types.Adapter, - attrFactory AttributeFactory, - cont *containers.Container, + exprAST *ast.AST, decorators ...InterpretableDecorator) interpretablePlanner { return &planner{ disp: disp, @@ -72,8 +48,8 @@ func newUncheckedPlanner(disp Dispatcher, adapter: adapter, attrFactory: attrFactory, container: cont, - refMap: make(map[int64]*ast.ReferenceInfo), - typeMap: make(map[int64]*types.Type), + refMap: exprAST.ReferenceMap(), + typeMap: exprAST.TypeMap(), decorators: decorators, } } @@ -95,22 +71,24 @@ type planner struct { // useful for layering functionality into the evaluation that is not natively understood by CEL, // such as state-tracking, expression re-write, and possibly efficient thread-safe memoization of // repeated expressions. 
-func (p *planner) Plan(expr *exprpb.Expr) (Interpretable, error) { - switch expr.GetExprKind().(type) { - case *exprpb.Expr_CallExpr: +func (p *planner) Plan(expr ast.Expr) (Interpretable, error) { + switch expr.Kind() { + case ast.CallKind: return p.decorate(p.planCall(expr)) - case *exprpb.Expr_IdentExpr: + case ast.IdentKind: return p.decorate(p.planIdent(expr)) - case *exprpb.Expr_SelectExpr: + case ast.LiteralKind: + return p.decorate(p.planConst(expr)) + case ast.SelectKind: return p.decorate(p.planSelect(expr)) - case *exprpb.Expr_ListExpr: + case ast.ListKind: return p.decorate(p.planCreateList(expr)) - case *exprpb.Expr_StructExpr: + case ast.MapKind: + return p.decorate(p.planCreateMap(expr)) + case ast.StructKind: return p.decorate(p.planCreateStruct(expr)) - case *exprpb.Expr_ComprehensionExpr: + case ast.ComprehensionKind: return p.decorate(p.planComprehension(expr)) - case *exprpb.Expr_ConstExpr: - return p.decorate(p.planConst(expr)) } return nil, fmt.Errorf("unsupported expr: %v", expr) } @@ -132,16 +110,16 @@ func (p *planner) decorate(i Interpretable, err error) (Interpretable, error) { } // planIdent creates an Interpretable that resolves an identifier from an Activation. -func (p *planner) planIdent(expr *exprpb.Expr) (Interpretable, error) { +func (p *planner) planIdent(expr ast.Expr) (Interpretable, error) { // Establish whether the identifier is in the reference map. - if identRef, found := p.refMap[expr.GetId()]; found { - return p.planCheckedIdent(expr.GetId(), identRef) + if identRef, found := p.refMap[expr.ID()]; found { + return p.planCheckedIdent(expr.ID(), identRef) } // Create the possible attribute list for the unresolved reference. 
- ident := expr.GetIdentExpr() + ident := expr.AsIdent() return &evalAttr{ adapter: p.adapter, - attr: p.attrFactory.MaybeAttribute(expr.GetId(), ident.Name), + attr: p.attrFactory.MaybeAttribute(expr.ID(), ident), }, nil } @@ -174,20 +152,20 @@ func (p *planner) planCheckedIdent(id int64, identRef *ast.ReferenceInfo) (Inter // a) selects a field from a map or proto. // b) creates a field presence test for a select within a has() macro. // c) resolves the select expression to a namespaced identifier. -func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) { +func (p *planner) planSelect(expr ast.Expr) (Interpretable, error) { // If the Select id appears in the reference map from the CheckedExpr proto then it is either // a namespaced identifier or enum value. - if identRef, found := p.refMap[expr.GetId()]; found { - return p.planCheckedIdent(expr.GetId(), identRef) + if identRef, found := p.refMap[expr.ID()]; found { + return p.planCheckedIdent(expr.ID(), identRef) } - sel := expr.GetSelectExpr() + sel := expr.AsSelect() // Plan the operand evaluation. - op, err := p.Plan(sel.GetOperand()) + op, err := p.Plan(sel.Operand()) if err != nil { return nil, err } - opType := p.typeMap[sel.GetOperand().GetId()] + opType := p.typeMap[sel.Operand().ID()] // If the Select was marked TestOnly, this is a presence test. // @@ -211,14 +189,14 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) { } // Build a qualifier for the attribute. - qual, err := p.attrFactory.NewQualifier(opType, expr.GetId(), sel.GetField(), false) + qual, err := p.attrFactory.NewQualifier(opType, expr.ID(), sel.FieldName(), false) if err != nil { return nil, err } // Modify the attribute to be test-only. 
- if sel.GetTestOnly() { + if sel.IsTestOnly() { attr = &evalTestOnly{ - id: expr.GetId(), + id: expr.ID(), InterpretableAttribute: attr, } } @@ -230,10 +208,10 @@ func (p *planner) planSelect(expr *exprpb.Expr) (Interpretable, error) { // planCall creates a callable Interpretable while specializing for common functions and invocation // patterns. Specifically, conditional operators &&, ||, ?:, and (in)equality functions result in // optimized Interpretable values. -func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) { - call := expr.GetCallExpr() +func (p *planner) planCall(expr ast.Expr) (Interpretable, error) { + call := expr.AsCall() target, fnName, oName := p.resolveFunction(expr) - argCount := len(call.GetArgs()) + argCount := len(call.Args()) var offset int if target != nil { argCount++ @@ -248,7 +226,7 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) { } args[0] = arg } - for i, argExpr := range call.GetArgs() { + for i, argExpr := range call.Args() { arg, err := p.Plan(argExpr) if err != nil { return nil, err @@ -307,7 +285,7 @@ func (p *planner) planCall(expr *exprpb.Expr) (Interpretable, error) { } // planCallZero generates a zero-arity callable Interpretable. -func (p *planner) planCallZero(expr *exprpb.Expr, +func (p *planner) planCallZero(expr ast.Expr, function string, overload string, impl *functions.Overload) (Interpretable, error) { @@ -315,7 +293,7 @@ func (p *planner) planCallZero(expr *exprpb.Expr, return nil, fmt.Errorf("no such overload: %s()", function) } return &evalZeroArity{ - id: expr.GetId(), + id: expr.ID(), function: function, overload: overload, impl: impl.Function, @@ -323,7 +301,7 @@ func (p *planner) planCallZero(expr *exprpb.Expr, } // planCallUnary generates a unary callable Interpretable. 
-func (p *planner) planCallUnary(expr *exprpb.Expr, +func (p *planner) planCallUnary(expr ast.Expr, function string, overload string, impl *functions.Overload, @@ -340,7 +318,7 @@ func (p *planner) planCallUnary(expr *exprpb.Expr, nonStrict = impl.NonStrict } return &evalUnary{ - id: expr.GetId(), + id: expr.ID(), function: function, overload: overload, arg: args[0], @@ -351,7 +329,7 @@ func (p *planner) planCallUnary(expr *exprpb.Expr, } // planCallBinary generates a binary callable Interpretable. -func (p *planner) planCallBinary(expr *exprpb.Expr, +func (p *planner) planCallBinary(expr ast.Expr, function string, overload string, impl *functions.Overload, @@ -368,7 +346,7 @@ func (p *planner) planCallBinary(expr *exprpb.Expr, nonStrict = impl.NonStrict } return &evalBinary{ - id: expr.GetId(), + id: expr.ID(), function: function, overload: overload, lhs: args[0], @@ -380,7 +358,7 @@ func (p *planner) planCallBinary(expr *exprpb.Expr, } // planCallVarArgs generates a variable argument callable Interpretable. -func (p *planner) planCallVarArgs(expr *exprpb.Expr, +func (p *planner) planCallVarArgs(expr ast.Expr, function string, overload string, impl *functions.Overload, @@ -397,7 +375,7 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr, nonStrict = impl.NonStrict } return &evalVarArgs{ - id: expr.GetId(), + id: expr.ID(), function: function, overload: overload, args: args, @@ -408,41 +386,41 @@ func (p *planner) planCallVarArgs(expr *exprpb.Expr, } // planCallEqual generates an equals (==) Interpretable. -func (p *planner) planCallEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalEq{ - id: expr.GetId(), + id: expr.ID(), lhs: args[0], rhs: args[1], }, nil } // planCallNotEqual generates a not equals (!=) Interpretable. 
-func (p *planner) planCallNotEqual(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallNotEqual(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalNe{ - id: expr.GetId(), + id: expr.ID(), lhs: args[0], rhs: args[1], }, nil } // planCallLogicalAnd generates a logical and (&&) Interpretable. -func (p *planner) planCallLogicalAnd(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallLogicalAnd(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalAnd{ - id: expr.GetId(), + id: expr.ID(), terms: args, }, nil } // planCallLogicalOr generates a logical or (||) Interpretable. -func (p *planner) planCallLogicalOr(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallLogicalOr(expr ast.Expr, args []Interpretable) (Interpretable, error) { return &evalOr{ - id: expr.GetId(), + id: expr.ID(), terms: args, }, nil } // planCallConditional generates a conditional / ternary (c ? t : f) Interpretable. -func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) (Interpretable, error) { +func (p *planner) planCallConditional(expr ast.Expr, args []Interpretable) (Interpretable, error) { cond := args[0] t := args[1] var tAttr Attribute @@ -464,13 +442,13 @@ func (p *planner) planCallConditional(expr *exprpb.Expr, args []Interpretable) ( return &evalAttr{ adapter: p.adapter, - attr: p.attrFactory.ConditionalAttribute(expr.GetId(), cond, tAttr, fAttr), + attr: p.attrFactory.ConditionalAttribute(expr.ID(), cond, tAttr, fAttr), }, nil } // planCallIndex either extends an attribute with the argument to the index operation, or creates // a relative attribute based on the return of a function call or operation. 
-func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optional bool) (Interpretable, error) { +func (p *planner) planCallIndex(expr ast.Expr, args []Interpretable, optional bool) (Interpretable, error) { op := args[0] ind := args[1] opType := p.typeMap[op.ID()] @@ -489,11 +467,11 @@ func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optiona var qual Qualifier switch ind := ind.(type) { case InterpretableConst: - qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind.Value(), optional) + qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind.Value(), optional) case InterpretableAttribute: - qual, err = p.attrFactory.NewQualifier(opType, expr.GetId(), ind, optional) + qual, err = p.attrFactory.NewQualifier(opType, expr.ID(), ind, optional) default: - qual, err = p.relativeAttr(expr.GetId(), ind, optional) + qual, err = p.relativeAttr(expr.ID(), ind, optional) } if err != nil { return nil, err @@ -505,10 +483,10 @@ func (p *planner) planCallIndex(expr *exprpb.Expr, args []Interpretable, optiona } // planCreateList generates a list construction Interpretable. 
-func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) { - list := expr.GetListExpr() - optionalIndices := list.GetOptionalIndices() - elements := list.GetElements() +func (p *planner) planCreateList(expr ast.Expr) (Interpretable, error) { + list := expr.AsList() + optionalIndices := list.OptionalIndices() + elements := list.Elements() optionals := make([]bool, len(elements)) for _, index := range optionalIndices { if index < 0 || index >= int32(len(elements)) { @@ -525,7 +503,7 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) { elems[i] = elemVal } return &evalList{ - id: expr.GetId(), + id: expr.ID(), elems: elems, optionals: optionals, hasOptionals: len(optionals) != 0, @@ -534,31 +512,29 @@ func (p *planner) planCreateList(expr *exprpb.Expr) (Interpretable, error) { } // planCreateStruct generates a map or object construction Interpretable. -func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) { - str := expr.GetStructExpr() - if len(str.MessageName) != 0 { - return p.planCreateObj(expr) - } - entries := str.GetEntries() +func (p *planner) planCreateMap(expr ast.Expr) (Interpretable, error) { + m := expr.AsMap() + entries := m.Entries() optionals := make([]bool, len(entries)) keys := make([]Interpretable, len(entries)) vals := make([]Interpretable, len(entries)) - for i, entry := range entries { - keyVal, err := p.Plan(entry.GetMapKey()) + for i, e := range entries { + entry := e.AsMapEntry() + keyVal, err := p.Plan(entry.Key()) if err != nil { return nil, err } keys[i] = keyVal - valVal, err := p.Plan(entry.GetValue()) + valVal, err := p.Plan(entry.Value()) if err != nil { return nil, err } vals[i] = valVal - optionals[i] = entry.GetOptionalEntry() + optionals[i] = entry.IsOptional() } return &evalMap{ - id: expr.GetId(), + id: expr.ID(), keys: keys, vals: vals, optionals: optionals, @@ -568,27 +544,28 @@ func (p *planner) planCreateStruct(expr *exprpb.Expr) (Interpretable, error) { } // 
planCreateObj generates an object construction Interpretable. -func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) { - obj := expr.GetStructExpr() - typeName, defined := p.resolveTypeName(obj.GetMessageName()) +func (p *planner) planCreateStruct(expr ast.Expr) (Interpretable, error) { + obj := expr.AsStruct() + typeName, defined := p.resolveTypeName(obj.TypeName()) if !defined { - return nil, fmt.Errorf("unknown type: %s", obj.GetMessageName()) - } - entries := obj.GetEntries() - optionals := make([]bool, len(entries)) - fields := make([]string, len(entries)) - vals := make([]Interpretable, len(entries)) - for i, entry := range entries { - fields[i] = entry.GetFieldKey() - val, err := p.Plan(entry.GetValue()) + return nil, fmt.Errorf("unknown type: %s", obj.TypeName()) + } + objFields := obj.Fields() + optionals := make([]bool, len(objFields)) + fields := make([]string, len(objFields)) + vals := make([]Interpretable, len(objFields)) + for i, f := range objFields { + field := f.AsStructField() + fields[i] = field.Name() + val, err := p.Plan(field.Value()) if err != nil { return nil, err } vals[i] = val - optionals[i] = entry.GetOptionalEntry() + optionals[i] = field.IsOptional() } return &evalObj{ - id: expr.GetId(), + id: expr.ID(), typeName: typeName, fields: fields, vals: vals, @@ -599,33 +576,33 @@ func (p *planner) planCreateObj(expr *exprpb.Expr) (Interpretable, error) { } // planComprehension generates an Interpretable fold operation. 
-func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) { - fold := expr.GetComprehensionExpr() - accu, err := p.Plan(fold.GetAccuInit()) +func (p *planner) planComprehension(expr ast.Expr) (Interpretable, error) { + fold := expr.AsComprehension() + accu, err := p.Plan(fold.AccuInit()) if err != nil { return nil, err } - iterRange, err := p.Plan(fold.GetIterRange()) + iterRange, err := p.Plan(fold.IterRange()) if err != nil { return nil, err } - cond, err := p.Plan(fold.GetLoopCondition()) + cond, err := p.Plan(fold.LoopCondition()) if err != nil { return nil, err } - step, err := p.Plan(fold.GetLoopStep()) + step, err := p.Plan(fold.LoopStep()) if err != nil { return nil, err } - result, err := p.Plan(fold.GetResult()) + result, err := p.Plan(fold.Result()) if err != nil { return nil, err } return &evalFold{ - id: expr.GetId(), - accuVar: fold.AccuVar, + id: expr.ID(), + accuVar: fold.AccuVar(), accu: accu, - iterVar: fold.IterVar, + iterVar: fold.IterVar(), iterRange: iterRange, cond: cond, step: step, @@ -635,37 +612,8 @@ func (p *planner) planComprehension(expr *exprpb.Expr) (Interpretable, error) { } // planConst generates a constant valued Interpretable. -func (p *planner) planConst(expr *exprpb.Expr) (Interpretable, error) { - val, err := p.constValue(expr.GetConstExpr()) - if err != nil { - return nil, err - } - return NewConstValue(expr.GetId(), val), nil -} - -// constValue converts a proto Constant value to a ref.Val. 
-func (p *planner) constValue(c *exprpb.Constant) (ref.Val, error) { - switch c.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - return p.adapter.NativeToValue(c.GetBoolValue()), nil - case *exprpb.Constant_BytesValue: - return p.adapter.NativeToValue(c.GetBytesValue()), nil - case *exprpb.Constant_DoubleValue: - return p.adapter.NativeToValue(c.GetDoubleValue()), nil - case *exprpb.Constant_DurationValue: - return p.adapter.NativeToValue(c.GetDurationValue().AsDuration()), nil - case *exprpb.Constant_Int64Value: - return p.adapter.NativeToValue(c.GetInt64Value()), nil - case *exprpb.Constant_NullValue: - return p.adapter.NativeToValue(c.GetNullValue()), nil - case *exprpb.Constant_StringValue: - return p.adapter.NativeToValue(c.GetStringValue()), nil - case *exprpb.Constant_TimestampValue: - return p.adapter.NativeToValue(c.GetTimestampValue().AsTime()), nil - case *exprpb.Constant_Uint64Value: - return p.adapter.NativeToValue(c.GetUint64Value()), nil - } - return nil, fmt.Errorf("unknown constant type: %v", c) +func (p *planner) planConst(expr ast.Expr) (Interpretable, error) { + return NewConstValue(expr.ID(), expr.AsLiteral()), nil } // resolveTypeName takes a qualified string constructed at parse time, applies the proto @@ -687,17 +635,20 @@ func (p *planner) resolveTypeName(typeName string) (string, bool) { // - The target expression may only consist of ident and select expressions. // - The function is declared in the environment using its fully-qualified name. // - The fully-qualified function name matches the string serialized target value. -func (p *planner) resolveFunction(expr *exprpb.Expr) (*exprpb.Expr, string, string) { +func (p *planner) resolveFunction(expr ast.Expr) (ast.Expr, string, string) { // Note: similar logic exists within the `checker/checker.go`. If making changes here // please consider the impact on checker.go and consolidate implementations or mirror code // as appropriate. 
- call := expr.GetCallExpr() - target := call.GetTarget() - fnName := call.GetFunction() + call := expr.AsCall() + var target ast.Expr = nil + if call.IsMemberFunction() { + target = call.Target() + } + fnName := call.FunctionName() // Checked expressions always have a reference map entry, and _should_ have the fully qualified // function name as the fnName value. - oRef, hasOverload := p.refMap[expr.GetId()] + oRef, hasOverload := p.refMap[expr.ID()] if hasOverload { if len(oRef.OverloadIDs) == 1 { return target, fnName, oRef.OverloadIDs[0] @@ -771,16 +722,30 @@ func (p *planner) relativeAttr(id int64, eval Interpretable, opt bool) (Interpre // toQualifiedName converts an expression AST into a qualified name if possible, with a boolean // 'found' value that indicates if the conversion is successful. -func (p *planner) toQualifiedName(operand *exprpb.Expr) (string, bool) { +func (p *planner) toQualifiedName(operand ast.Expr) (string, bool) { // If the checker identified the expression as an attribute by the type-checker, then it can't // possibly be part of qualified name in a namespace. - _, isAttr := p.refMap[operand.GetId()] + _, isAttr := p.refMap[operand.ID()] if isAttr { return "", false } // Since functions cannot be both namespaced and receiver functions, if the operand is not an // qualified variable name, return the (possibly) qualified name given the expressions. - return containers.ToQualifiedName(operand) + switch operand.Kind() { + case ast.IdentKind: + id := operand.AsIdent() + return id, true + case ast.SelectKind: + sel := operand.AsSelect() + // Test only expressions are not valid as qualified names. + if sel.IsTestOnly() { + return "", false + } + if qual, found := p.toQualifiedName(sel.Operand()); found { + return qual + "." 
+ sel.FieldName(), true + } + } + return "", false } func stripLeadingDot(name string) string { diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index 2e1f15591..989369428 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/types" @@ -454,9 +455,9 @@ func TestPrune(t *testing.T) { t.Fatalf("parser.NewParser() failed: %v", err) } for i, tst := range testCases { - ast, iss := p.Parse(common.NewStringSource(tst.expr, "")) + parsed, iss := p.Parse(common.NewStringSource(tst.expr, "")) if len(iss.GetErrors()) > 0 { - t.Fatalf(iss.ToDisplayString()) + t.Fatalf("Parse(%q) failed: %v", tst.expr, iss.ToDisplayString()) } state := NewEvalState() reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) @@ -465,12 +466,17 @@ func TestPrune(t *testing.T) { addFunctionBindings(t, dispatcher) dispatcher.Add(funcBindings(t, optionalDecls(t)...)...) 
interp := NewInterpreter(dispatcher, containers.DefaultContainer, reg, reg, attrs) - - interpretable, _ := interp.NewUncheckedInterpretable( - ast.GetExpr(), + ex, err := ast.ProtoToExpr(parsed.GetExpr()) + if err != nil { + t.Fatalf("ast.ProtoToExpr() failed: %v", err) + } + interpretable, err := interp.NewInterpretable(ast.NewAST(ex, nil), ExhaustiveEval(), Observe(EvalStateObserver(state))) + if err != nil { + t.Fatalf("NewUncheckedInterpretable() failed: %v", err) + } interpretable.Eval(testActivation(t, tst.in)) - newExpr := PruneAst(ast.GetExpr(), ast.GetSourceInfo().GetMacroCalls(), state) + newExpr := PruneAst(parsed.GetExpr(), parsed.GetSourceInfo().GetMacroCalls(), state) if tst.iterRange != "" { compre := newExpr.GetExpr().GetComprehensionExpr() if compre == nil { From 337fc078a9514774ba32f6a57ca2d3604e857304 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 11 Aug 2023 17:53:54 -0400 Subject: [PATCH 06/31] Migrate the type-checker to a native AST representation (#793) * Migrate the type-checker to a native AST representation * Minor doc update to the type-checker * Patch cost_test.go fixes * Removed nil field initializer --- cel/env.go | 29 +-- cel/io.go | 4 +- checker/checker.go | 345 +++++++++++----------------- checker/checker_test.go | 17 +- checker/cost.go | 6 +- checker/cost_test.go | 4 +- checker/errors.go | 14 +- checker/printer.go | 34 +-- common/ast/BUILD.bazel | 1 + common/ast/ast.go | 16 ++ common/ast/conversion.go | 5 + common/ast/conversion_test.go | 26 +-- common/containers/BUILD.bazel | 4 +- common/containers/container.go | 22 +- common/containers/container_test.go | 36 +-- common/debug/BUILD.bazel | 4 +- common/debug/debug.go | 156 +++++++------ interpreter/interpreter_test.go | 26 +-- interpreter/prune.go | 25 +- interpreter/prune_test.go | 18 +- parser/BUILD.bazel | 3 + parser/helper_test.go | 21 +- parser/parser.go | 20 +- parser/parser_test.go | 114 +++++---- parser/unparser.go | 16 +- parser/unparser_test.go | 29 +-- 26 
files changed, 465 insertions(+), 530 deletions(-) diff --git a/cel/env.go b/cel/env.go index 49f79ed2d..d786012da 100644 --- a/cel/env.go +++ b/cel/env.go @@ -210,9 +210,6 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) { // It is possible to have both non-nil Ast and Issues values returned from this call: however, // the mere presence of an Ast does not imply that it is valid for use. func (e *Env) Check(ast *Ast) (*Ast, *Issues) { - // Note, errors aren't currently possible on the Ast to ParsedExpr conversion. - pe, _ := AstToParsedExpr(ast) - // Construct the internal checker env, erroring if there is an issue adding the declarations. chk, err := e.initChecker() if err != nil { @@ -221,7 +218,7 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo()) } - res, errs := checker.Check(pe, ast.Source(), chk) + res, errs := checker.Check(celast.NewAST(ast.expr, ast.info), ast.Source(), chk) if len(errs.GetErrors()) > 0 { return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo()) } @@ -433,22 +430,10 @@ func (e *Env) ParseSource(src Source) (*Ast, *Issues) { if len(errs.GetErrors()) > 0 { return nil, &Issues{errs: errs} } - // Manually create the Ast to ensure that the text source information is propagated on - // subsequent calls to Check. - expr, err := celast.ProtoToExpr(res.GetExpr()) - if err != nil { - errs.ReportError(common.NoLocation, "invalid parsed expression: %v", err) - return nil, &Issues{errs: errs} - } - info, err := celast.ProtoToSourceInfo(res.GetSourceInfo()) - if err != nil { - errs.ReportError(common.NoLocation, "invalid source info: %v", err) - return nil, &Issues{errs: errs} - } return &Ast{ source: src, - expr: expr, - info: info}, nil + expr: res.Expr(), + info: res.SourceInfo()}, nil } // Program generates an evaluable instance of the Ast within the environment (Env). 
@@ -544,8 +529,12 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) { // TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an // Ast format and then Program again. func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { - pruned := interpreter.PruneAst(a.Expr(), a.SourceInfo().GetMacroCalls(), details.State()) - expr, err := AstToString(ParsedExprToAst(pruned)) + pruned := interpreter.PruneAst(a.expr, a.info.MacroCalls(), details.State()) + newAST := &Ast{ + expr: pruned.Expr(), + info: pruned.SourceInfo(), + } + expr, err := AstToString(newAST) if err != nil { return nil, err } diff --git a/cel/io.go b/cel/io.go index 5107f564d..c8c9c911f 100644 --- a/cel/io.go +++ b/cel/io.go @@ -112,9 +112,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) { // Note, the conversion may not be an exact replica of the original expression, but will produce // a string that is semantically equivalent and whose textual representation is stable. func AstToString(a *Ast) (string, error) { - expr := a.Expr() - info := a.SourceInfo() - return parser.Unparse(expr, info) + return parser.Unparse(a.expr, a.info) } // RefValueToValue converts between ref.Val and api.expr.Value. diff --git a/checker/checker.go b/checker/checker.go index 20eaa050b..c8f99ff5b 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -25,143 +25,96 @@ import ( "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/types" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/types/ref" ) type checker struct { + *ast.AST + ast.ExprFactory env *Env errors *typeErrors mappings *mapping freeTypeVarCounter int - sourceInfo *exprpb.SourceInfo - types map[int64]*types.Type - references map[int64]*ast.ReferenceInfo } // Check performs type checking, giving a typed AST. 
// -// The input is a ParsedExpr proto and an env which encapsulates type binding of variables, +// The input is a parsed AST and an env which encapsulates type binding of variables, // declarations of built-in functions, descriptions of protocol buffers, and a registry for // errors. // // Returns a type-checked AST, which might not be usable if there are errors in the error // registry. -func Check(parsedExpr *exprpb.ParsedExpr, source common.Source, env *Env) (*ast.AST, *common.Errors) { +func Check(parsed *ast.AST, source common.Source, env *Env) (*ast.AST, *common.Errors) { errs := common.NewErrors(source) + typeMap := make(map[int64]*types.Type) + refMap := make(map[int64]*ast.ReferenceInfo) c := checker{ + AST: ast.NewCheckedAST(parsed, typeMap, refMap), + ExprFactory: ast.NewExprFactory(), env: env, errors: &typeErrors{errs: errs}, mappings: newMapping(), freeTypeVarCounter: 0, - sourceInfo: parsedExpr.GetSourceInfo(), - types: make(map[int64]*types.Type), - references: make(map[int64]*ast.ReferenceInfo), } - c.check(parsedExpr.GetExpr()) + c.check(c.Expr()) - // Walk over the final type map substituting any type parameters either by their bound value or - // by DYN. - m := make(map[int64]*types.Type) - for id, t := range c.types { - m[id] = substitute(c.mappings, t, true) - } - e, err := ast.ProtoToExpr(parsedExpr.GetExpr()) - if err != nil { - errs.ReportError(common.NoLocation, err.Error()) + // Walk over the final type map substituting any type parameters either by their bound value + // or by DYN. 
+ for id, t := range c.TypeMap() { + c.SetType(id, substitute(c.mappings, t, true)) } - info, err := ast.ProtoToSourceInfo(parsedExpr.GetSourceInfo()) - if err != nil { - errs.ReportError(common.NoLocation, err.Error()) - } - return ast.NewCheckedAST(ast.NewAST(e, info), m, c.references), errs + return c.AST, errs } -func (c *checker) check(e *exprpb.Expr) { +func (c *checker) check(e ast.Expr) { if e == nil { return } - switch e.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - literal := e.GetConstExpr() - switch literal.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - c.checkBoolLiteral(e) - case *exprpb.Constant_BytesValue: - c.checkBytesLiteral(e) - case *exprpb.Constant_DoubleValue: - c.checkDoubleLiteral(e) - case *exprpb.Constant_Int64Value: - c.checkInt64Literal(e) - case *exprpb.Constant_NullValue: - c.checkNullLiteral(e) - case *exprpb.Constant_StringValue: - c.checkStringLiteral(e) - case *exprpb.Constant_Uint64Value: - c.checkUint64Literal(e) + switch e.Kind() { + case ast.LiteralKind: + literal := ref.Val(e.AsLiteral()) + switch literal.Type() { + case types.BoolType, types.BytesType, types.DoubleType, types.IntType, + types.NullType, types.StringType, types.UintType: + c.setType(e, literal.Type().(*types.Type)) } - case *exprpb.Expr_IdentExpr: + case ast.IdentKind: c.checkIdent(e) - case *exprpb.Expr_SelectExpr: + case ast.SelectKind: c.checkSelect(e) - case *exprpb.Expr_CallExpr: + case ast.CallKind: c.checkCall(e) - case *exprpb.Expr_ListExpr: + case ast.ListKind: c.checkCreateList(e) - case *exprpb.Expr_StructExpr: + case ast.MapKind: + c.checkCreateMap(e) + case ast.StructKind: c.checkCreateStruct(e) - case *exprpb.Expr_ComprehensionExpr: + case ast.ComprehensionKind: c.checkComprehension(e) default: - c.errors.unexpectedASTType(e.GetId(), c.location(e), e) + c.errors.unexpectedASTType(e.ID(), c.location(e), e) } } -func (c *checker) checkInt64Literal(e *exprpb.Expr) { - c.setType(e, types.IntType) -} - -func (c *checker) 
checkUint64Literal(e *exprpb.Expr) { - c.setType(e, types.UintType) -} - -func (c *checker) checkStringLiteral(e *exprpb.Expr) { - c.setType(e, types.StringType) -} - -func (c *checker) checkBytesLiteral(e *exprpb.Expr) { - c.setType(e, types.BytesType) -} - -func (c *checker) checkDoubleLiteral(e *exprpb.Expr) { - c.setType(e, types.DoubleType) -} - -func (c *checker) checkBoolLiteral(e *exprpb.Expr) { - c.setType(e, types.BoolType) -} - -func (c *checker) checkNullLiteral(e *exprpb.Expr) { - c.setType(e, types.NullType) -} - -func (c *checker) checkIdent(e *exprpb.Expr) { - identExpr := e.GetIdentExpr() +func (c *checker) checkIdent(e ast.Expr) { + identName := e.AsIdent() // Check to see if the identifier is declared. - if ident := c.env.LookupIdent(identExpr.GetName()); ident != nil { + if ident := c.env.LookupIdent(identName); ident != nil { c.setType(e, ident.Type()) c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value())) // Overwrite the identifier with its fully qualified name. - identExpr.Name = ident.Name() + e.SetKindCase(c.NewIdent(e.ID(), ident.Name())) return } c.setType(e, types.ErrorType) - c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), identExpr.GetName()) + c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), identName) } -func (c *checker) checkSelect(e *exprpb.Expr) { - sel := e.GetSelectExpr() +func (c *checker) checkSelect(e ast.Expr) { + sel := e.AsSelect() // Before traversing down the tree, try to interpret as qualified name. qname, found := containers.ToQualifiedName(e) if found { @@ -174,31 +127,26 @@ func (c *checker) checkSelect(e *exprpb.Expr) { // variable name. 
c.setType(e, ident.Type()) c.setReference(e, ast.NewIdentReference(ident.Name(), ident.Value())) - identName := ident.Name() - e.ExprKind = &exprpb.Expr_IdentExpr{ - IdentExpr: &exprpb.Expr_Ident{ - Name: identName, - }, - } + e.SetKindCase(c.NewIdent(e.ID(), ident.Name())) return } } - resultType := c.checkSelectField(e, sel.GetOperand(), sel.GetField(), false) - if sel.TestOnly { + resultType := c.checkSelectField(e, sel.Operand(), sel.FieldName(), false) + if sel.IsTestOnly() { resultType = types.BoolType } c.setType(e, substitute(c.mappings, resultType, false)) } -func (c *checker) checkOptSelect(e *exprpb.Expr) { +func (c *checker) checkOptSelect(e ast.Expr) { // Collect metadata related to the opt select call packaged by the parser. - call := e.GetCallExpr() - operand := call.GetArgs()[0] - field := call.GetArgs()[1] + call := e.AsCall() + operand := call.Args()[0] + field := call.Args()[1] fieldName, isString := maybeUnwrapString(field) if !isString { - c.errors.notAnOptionalFieldSelection(field.GetId(), c.location(field), field) + c.errors.notAnOptionalFieldSelection(field.ID(), c.location(field), field) return } @@ -208,7 +156,7 @@ func (c *checker) checkOptSelect(e *exprpb.Expr) { c.setReference(e, ast.NewFunctionReference("select_optional_field")) } -func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, optional bool) *types.Type { +func (c *checker) checkSelectField(e, operand ast.Expr, field string, optional bool) *types.Type { // Interpret as field selection, first traversing down the operand. c.check(operand) operandType := substitute(c.mappings, c.getType(operand), false) @@ -226,7 +174,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option // Objects yield their field type declaration as the selection result type, but only if // the field is defined. 
messageType := targetType - if fieldType, found := c.lookupFieldType(e.GetId(), messageType.TypeName(), field); found { + if fieldType, found := c.lookupFieldType(e.ID(), messageType.TypeName(), field); found { resultType = fieldType } case types.TypeParamKind: @@ -240,7 +188,7 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option // Dynamic / error values are treated as DYN type. Errors are handled this way as well // in order to allow forward progress on the check. if !isDynOrError(targetType) { - c.errors.typeDoesNotSupportFieldSelection(e.GetId(), c.location(e), targetType) + c.errors.typeDoesNotSupportFieldSelection(e.ID(), c.location(e), targetType) } resultType = types.DynType } @@ -252,35 +200,34 @@ func (c *checker) checkSelectField(e, operand *exprpb.Expr, field string, option return resultType } -func (c *checker) checkCall(e *exprpb.Expr) { +func (c *checker) checkCall(e ast.Expr) { // Note: similar logic exists within the `interpreter/planner.go`. If making changes here // please consider the impact on planner.go and consolidate implementations or mirror code // as appropriate. - call := e.GetCallExpr() - fnName := call.GetFunction() + call := e.AsCall() + fnName := call.FunctionName() if fnName == operators.OptSelect { c.checkOptSelect(e) return } - args := call.GetArgs() + args := call.Args() // Traverse arguments. for _, arg := range args { c.check(arg) } - target := call.GetTarget() // Regular static call with simple name. - if target == nil { + if !call.IsMemberFunction() { // Check for the existence of the function. fn := c.env.LookupFunction(fnName) if fn == nil { - c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName) + c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), fnName) c.setType(e, types.ErrorType) return } // Overwrite the function name with its fully qualified resolved name. 
- call.Function = fn.Name() + e.SetKindCase(c.NewCall(e.ID(), fn.Name(), args...)) // Check to see whether the overload resolves. c.resolveOverloadOrError(e, fn, nil, args) return @@ -291,6 +238,7 @@ func (c *checker) checkCall(e *exprpb.Expr) { // target a.b. // // Check whether the target is a namespaced function name. + target := call.Target() qualifiedPrefix, maybeQualified := containers.ToQualifiedName(target) if maybeQualified { maybeQualifiedName := qualifiedPrefix + "." + fnName @@ -299,15 +247,14 @@ func (c *checker) checkCall(e *exprpb.Expr) { // The function name is namespaced and so preserving the target operand would // be an inaccurate representation of the desired evaluation behavior. // Overwrite with fully-qualified resolved function name sans receiver target. - call.Target = nil - call.Function = fn.Name() + e.SetKindCase(c.NewCall(e.ID(), fn.Name(), args...)) c.resolveOverloadOrError(e, fn, nil, args) return } } // Regular instance call. - c.check(call.Target) + c.check(target) fn := c.env.LookupFunction(fnName) // Function found, attempt overload resolution. if fn != nil { @@ -316,11 +263,11 @@ func (c *checker) checkCall(e *exprpb.Expr) { } // Function name not declared, record error. c.setType(e, types.ErrorType) - c.errors.undeclaredReference(e.GetId(), c.location(e), c.env.container.Name(), fnName) + c.errors.undeclaredReference(e.ID(), c.location(e), c.env.container.Name(), fnName) } func (c *checker) resolveOverloadOrError( - e *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) { + e ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) { // Attempt to resolve the overload. resolution := c.resolveOverload(e, fn, target, args) // No such overload, error noted in the resolveOverload call, type recorded here. 
@@ -334,7 +281,7 @@ func (c *checker) resolveOverloadOrError( } func (c *checker) resolveOverload( - call *exprpb.Expr, fn *decls.FunctionDecl, target *exprpb.Expr, args []*exprpb.Expr) *overloadResolution { + call ast.Expr, fn *decls.FunctionDecl, target ast.Expr, args []ast.Expr) *overloadResolution { var argTypes []*types.Type if target != nil { @@ -366,8 +313,8 @@ func (c *checker) resolveOverload( for i, argType := range argTypes { if !c.isAssignable(argType, types.BoolType) { c.errors.typeMismatch( - args[i].GetId(), - c.locationByID(args[i].GetId()), + args[i].ID(), + c.locationByID(args[i].ID()), types.BoolType, argType) resultType = types.ErrorType @@ -412,29 +359,29 @@ func (c *checker) resolveOverload( for i, argType := range argTypes { argTypes[i] = substitute(c.mappings, argType, true) } - c.errors.noMatchingOverload(call.GetId(), c.location(call), fn.Name(), argTypes, target != nil) + c.errors.noMatchingOverload(call.ID(), c.location(call), fn.Name(), argTypes, target != nil) return nil } return newResolution(checkedRef, resultType) } -func (c *checker) checkCreateList(e *exprpb.Expr) { - create := e.GetListExpr() +func (c *checker) checkCreateList(e ast.Expr) { + create := e.AsList() var elemsType *types.Type - optionalIndices := create.GetOptionalIndices() + optionalIndices := create.OptionalIndices() optionals := make(map[int32]bool, len(optionalIndices)) for _, optInd := range optionalIndices { optionals[optInd] = true } - for i, e := range create.GetElements() { + for i, e := range create.Elements() { c.check(e) elemType := c.getType(e) if optionals[int32(i)] { var isOptional bool elemType, isOptional = maybeUnwrapOptional(elemType) if !isOptional && !isDyn(elemType) { - c.errors.typeMismatch(e.GetId(), c.location(e), types.NewOptionalType(elemType), elemType) + c.errors.typeMismatch(e.ID(), c.location(e), types.NewOptionalType(elemType), elemType) } } elemsType = c.joinTypes(e, elemsType, elemType) @@ -446,32 +393,24 @@ func (c *checker) 
checkCreateList(e *exprpb.Expr) { c.setType(e, types.NewListType(elemsType)) } -func (c *checker) checkCreateStruct(e *exprpb.Expr) { - str := e.GetStructExpr() - if str.GetMessageName() != "" { - c.checkCreateMessage(e) - } else { - c.checkCreateMap(e) - } -} - -func (c *checker) checkCreateMap(e *exprpb.Expr) { - mapVal := e.GetStructExpr() +func (c *checker) checkCreateMap(e ast.Expr) { + mapVal := e.AsMap() var mapKeyType *types.Type var mapValueType *types.Type - for _, ent := range mapVal.GetEntries() { - key := ent.GetMapKey() + for _, e := range mapVal.Entries() { + entry := e.AsMapEntry() + key := entry.Key() c.check(key) mapKeyType = c.joinTypes(key, mapKeyType, c.getType(key)) - val := ent.GetValue() + val := entry.Value() c.check(val) valType := c.getType(val) - if ent.GetOptionalEntry() { + if entry.IsOptional() { var isOptional bool valType, isOptional = maybeUnwrapOptional(valType) if !isOptional && !isDyn(valType) { - c.errors.typeMismatch(val.GetId(), c.location(val), types.NewOptionalType(valType), valType) + c.errors.typeMismatch(val.ID(), c.location(val), types.NewOptionalType(valType), valType) } } mapValueType = c.joinTypes(val, mapValueType, valType) @@ -484,25 +423,28 @@ func (c *checker) checkCreateMap(e *exprpb.Expr) { c.setType(e, types.NewMapType(mapKeyType, mapValueType)) } -func (c *checker) checkCreateMessage(e *exprpb.Expr) { - msgVal := e.GetStructExpr() +func (c *checker) checkCreateStruct(e ast.Expr) { + msgVal := e.AsStruct() // Determine the type of the message. resultType := types.ErrorType - ident := c.env.LookupIdent(msgVal.GetMessageName()) + ident := c.env.LookupIdent(msgVal.TypeName()) if ident == nil { c.errors.undeclaredReference( - e.GetId(), c.location(e), c.env.container.Name(), msgVal.GetMessageName()) + e.ID(), c.location(e), c.env.container.Name(), msgVal.TypeName()) c.setType(e, types.ErrorType) return } // Ensure the type name is fully qualified in the AST. 
typeName := ident.Name() - msgVal.MessageName = typeName - c.setReference(e, ast.NewIdentReference(ident.Name(), nil)) + if msgVal.TypeName() != typeName { + e.SetKindCase(c.NewStruct(e.ID(), typeName, msgVal.Fields())) + msgVal = e.AsStruct() + } + c.setReference(e, ast.NewIdentReference(typeName, nil)) identKind := ident.Type().Kind() if identKind != types.ErrorKind { if identKind != types.TypeKind { - c.errors.notAType(e.GetId(), c.location(e), ident.Type().DeclaredTypeName()) + c.errors.notAType(e.ID(), c.location(e), ident.Type().DeclaredTypeName()) } else { resultType = ident.Type().Parameters()[0] // Backwards compatibility test between well-known types and message types @@ -513,7 +455,7 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) { } else if resultType.Kind() == types.StructKind { typeName = resultType.DeclaredTypeName() } else { - c.errors.notAMessageType(e.GetId(), c.location(e), resultType.DeclaredTypeName()) + c.errors.notAMessageType(e.ID(), c.location(e), resultType.DeclaredTypeName()) resultType = types.ErrorType } } @@ -521,37 +463,38 @@ func (c *checker) checkCreateMessage(e *exprpb.Expr) { c.setType(e, resultType) // Check the field initializers. 
- for _, ent := range msgVal.GetEntries() { - field := ent.GetFieldKey() - value := ent.GetValue() + for _, f := range msgVal.Fields() { + field := f.AsStructField() + fieldName := field.Name() + value := field.Value() c.check(value) fieldType := types.ErrorType - ft, found := c.lookupFieldType(ent.GetId(), typeName, field) + ft, found := c.lookupFieldType(f.ID(), typeName, fieldName) if found { fieldType = ft } valType := c.getType(value) - if ent.GetOptionalEntry() { + if field.IsOptional() { var isOptional bool valType, isOptional = maybeUnwrapOptional(valType) if !isOptional && !isDyn(valType) { - c.errors.typeMismatch(value.GetId(), c.location(value), types.NewOptionalType(valType), valType) + c.errors.typeMismatch(value.ID(), c.location(value), types.NewOptionalType(valType), valType) } } if !c.isAssignable(fieldType, valType) { - c.errors.fieldTypeMismatch(ent.GetId(), c.locationByID(ent.GetId()), field, fieldType, valType) + c.errors.fieldTypeMismatch(f.ID(), c.locationByID(f.ID()), fieldName, fieldType, valType) } } } -func (c *checker) checkComprehension(e *exprpb.Expr) { - comp := e.GetComprehensionExpr() - c.check(comp.GetIterRange()) - c.check(comp.GetAccuInit()) - accuType := c.getType(comp.GetAccuInit()) - rangeType := substitute(c.mappings, c.getType(comp.GetIterRange()), false) +func (c *checker) checkComprehension(e ast.Expr) { + comp := e.AsComprehension() + c.check(comp.IterRange()) + c.check(comp.AccuInit()) + accuType := c.getType(comp.AccuInit()) + rangeType := substitute(c.mappings, c.getType(comp.IterRange()), false) var varType *types.Type switch rangeType.Kind() { @@ -568,32 +511,32 @@ func (c *checker) checkComprehension(e *exprpb.Expr) { // Set the range iteration variable to type DYN as well. 
varType = types.DynType default: - c.errors.notAComprehensionRange(comp.GetIterRange().GetId(), c.location(comp.GetIterRange()), rangeType) + c.errors.notAComprehensionRange(comp.IterRange().ID(), c.location(comp.IterRange()), rangeType) varType = types.ErrorType } // Create a scope for the comprehension since it has a local accumulation variable. // This scope will contain the accumulation variable used to compute the result. c.env = c.env.enterScope() - c.env.AddIdents(decls.NewVariable(comp.GetAccuVar(), accuType)) + c.env.AddIdents(decls.NewVariable(comp.AccuVar(), accuType)) // Create a block scope for the loop. c.env = c.env.enterScope() - c.env.AddIdents(decls.NewVariable(comp.GetIterVar(), varType)) + c.env.AddIdents(decls.NewVariable(comp.IterVar(), varType)) // Check the variable references in the condition and step. - c.check(comp.GetLoopCondition()) - c.assertType(comp.GetLoopCondition(), types.BoolType) - c.check(comp.GetLoopStep()) - c.assertType(comp.GetLoopStep(), accuType) + c.check(comp.LoopCondition()) + c.assertType(comp.LoopCondition(), types.BoolType) + c.check(comp.LoopStep()) + c.assertType(comp.LoopStep(), accuType) // Exit the loop's block scope before checking the result. c.env = c.env.exitScope() - c.check(comp.GetResult()) + c.check(comp.Result()) // Exit the comprehension scope. c.env = c.env.exitScope() - c.setType(e, substitute(c.mappings, c.getType(comp.GetResult()), false)) + c.setType(e, substitute(c.mappings, c.getType(comp.Result()), false)) } // Checks compatibility of joined types, and returns the most general common type. 
-func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *types.Type { +func (c *checker) joinTypes(e ast.Expr, previous, current *types.Type) *types.Type { if previous == nil { return current } @@ -603,7 +546,7 @@ func (c *checker) joinTypes(e *exprpb.Expr, previous, current *types.Type) *type if c.dynAggregateLiteralElementTypesEnabled() { return types.DynType } - c.errors.typeMismatch(e.GetId(), c.location(e), previous, current) + c.errors.typeMismatch(e.ID(), c.location(e), previous, current) return types.ErrorType } @@ -637,41 +580,41 @@ func (c *checker) isAssignableList(l1, l2 []*types.Type) bool { return false } -func maybeUnwrapString(e *exprpb.Expr) (string, bool) { - switch e.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - literal := e.GetConstExpr() - switch literal.GetConstantKind().(type) { - case *exprpb.Constant_StringValue: - return literal.GetStringValue(), true +func maybeUnwrapString(e ast.Expr) (string, bool) { + switch e.Kind() { + case ast.LiteralKind: + literal := e.AsLiteral() + switch v := literal.(type) { + case types.String: + return string(v), true } } return "", false } -func (c *checker) setType(e *exprpb.Expr, t *types.Type) { - if old, found := c.types[e.GetId()]; found && !old.IsExactType(t) { - c.errors.incompatibleType(e.GetId(), c.location(e), e, old, t) +func (c *checker) setType(e ast.Expr, t *types.Type) { + if old, found := c.TypeMap()[e.ID()]; found && !old.IsExactType(t) { + c.errors.incompatibleType(e.ID(), c.location(e), e, old, t) return } - c.types[e.GetId()] = t + c.SetType(e.ID(), t) } -func (c *checker) getType(e *exprpb.Expr) *types.Type { - return c.types[e.GetId()] +func (c *checker) getType(e ast.Expr) *types.Type { + return c.TypeMap()[e.ID()] } -func (c *checker) setReference(e *exprpb.Expr, r *ast.ReferenceInfo) { - if old, found := c.references[e.GetId()]; found && !old.Equals(r) { - c.errors.referenceRedefinition(e.GetId(), c.location(e), e, old, r) +func (c *checker) 
setReference(e ast.Expr, r *ast.ReferenceInfo) { + if old, found := c.ReferenceMap()[e.ID()]; found && !old.Equals(r) { + c.errors.referenceRedefinition(e.ID(), c.location(e), e, old, r) return } - c.references[e.GetId()] = r + c.SetReference(e.ID(), r) } -func (c *checker) assertType(e *exprpb.Expr, t *types.Type) { +func (c *checker) assertType(e ast.Expr, t *types.Type) { if !c.isAssignable(t, c.getType(e)) { - c.errors.typeMismatch(e.GetId(), c.location(e), t, c.getType(e)) + c.errors.typeMismatch(e.ID(), c.location(e), t, c.getType(e)) } } @@ -687,26 +630,12 @@ func newResolution(r *ast.ReferenceInfo, t *types.Type) *overloadResolution { } } -func (c *checker) location(e *exprpb.Expr) common.Location { - return c.locationByID(e.GetId()) +func (c *checker) location(e ast.Expr) common.Location { + return c.locationByID(e.ID()) } func (c *checker) locationByID(id int64) common.Location { - positions := c.sourceInfo.GetPositions() - var line = 1 - if offset, found := positions[id]; found { - col := int(offset) - for _, lineOffset := range c.sourceInfo.GetLineOffsets() { - if lineOffset < offset { - line++ - col = int(offset - lineOffset) - } else { - break - } - } - return common.NewLocation(line, col) - } - return common.NoLocation + return c.SourceInfo().GetStartLocation(id) } func (c *checker) lookupFieldType(exprID int64, structType, fieldName string) (*types.Type, bool) { diff --git a/checker/checker_test.go b/checker/checker_test.go index e2ae6e801..16d0f6eb5 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -20,7 +20,6 @@ import ( "testing" "github.com/google/cel-go/common" - "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/stdlib" @@ -2350,7 +2349,7 @@ func TestCheck(t *testing.T) { t.Errorf("Expected error not thrown: %s", tc.err) } - actual := cAst.GetType(pAst.Expr.Id) + actual := cAst.GetType(pAst.Expr().ID()) if tc.err == "" { 
if actual == nil || !actual.IsEquivalentType(tc.outType) { t.Error(test.DiffMessage("Type Error", actual, tc.outType)) @@ -2358,11 +2357,7 @@ func TestCheck(t *testing.T) { } if tc.out != "" { - chkExpr, err := ast.ToProto(cAst) - if err != nil { - t.Fatalf("CheckedAstToCheckedExpr() failed: %v", err) - } - actualStr := Print(pAst.Expr, chkExpr) + actualStr := Print(pAst.Expr(), cAst) if !test.Compare(actualStr, tc.out) { t.Error(test.DiffMessage("Structure error", actualStr, tc.out)) } @@ -2445,7 +2440,7 @@ func BenchmarkCheck(b *testing.B) { b.Errorf("Expected error not thrown: %s", tc.err) } - actual := cAst.GetType(pAst.Expr.Id) + actual := cAst.GetType(pAst.Expr().ID()) if tc.err == "" { if actual == nil || !actual.IsEquivalentType(tc.outType) { b.Error(test.DiffMessage("Type Error", actual, tc.outType)) @@ -2453,11 +2448,7 @@ func BenchmarkCheck(b *testing.B) { } if tc.out != "" { - chkExpr, err := ast.ToProto(cAst) - if err != nil { - b.Fatalf("CheckedAstToCheckedExpr() failed: %v", err) - } - actualStr := Print(pAst.Expr, chkExpr) + actualStr := Print(pAst.Expr(), cAst) if !test.Compare(actualStr, tc.out) { b.Error(test.DiffMessage("Structure error", actualStr, tc.out)) } diff --git a/checker/cost.go b/checker/cost.go index 30d477c82..06c57c1e9 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -304,9 +304,9 @@ func PresenceTestHasCost(hasCost bool) CostOption { } // Cost estimates the cost of the parsed and type checked CEL expression. 
-func Cost(checker *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) { +func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEstimate, error) { c := &coster{ - checkedAST: checker, + checkedAST: checked, estimator: estimator, exprPath: map[int64][]string{}, iterRanges: map[string][]int64{}, @@ -319,7 +319,7 @@ func Cost(checker *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEs return CostEstimate{}, err } } - epb, err := ast.ExprToProto(checker.Expr()) + epb, err := ast.ExprToProto(checked.Expr()) if err != nil { return CostEstimate{}, err } diff --git a/checker/cost_test.go b/checker/cost_test.go index 9b47167ed..92f98d84c 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -263,7 +263,7 @@ func TestCost(t *testing.T) { }, { name: "bytes to string conversion equality", - decls: []*exprpb.Decl{decls.NewVar("input", decls.Bytes)}, + vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)}, hints: map[string]int64{"input": 500}, // equality check ensures that the resultSize calculation is included in cost expr: `string(input) == string(input)`, @@ -278,7 +278,7 @@ func TestCost(t *testing.T) { }, { name: "string to bytes conversion equality", - decls: []*exprpb.Decl{decls.NewVar("input", decls.String)}, + vars: []*decls.VariableDecl{decls.NewVariable("input", types.StringType)}, hints: map[string]int64{"input": 500}, // equality check ensures that the resultSize calculation is included in cost expr: `bytes(input) == bytes(input)`, diff --git a/checker/errors.go b/checker/errors.go index c2b96498d..c7032590d 100644 --- a/checker/errors.go +++ b/checker/errors.go @@ -20,8 +20,6 @@ import ( "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // typeErrors is a specialization of Errors. 
@@ -34,9 +32,9 @@ func (e *typeErrors) fieldTypeMismatch(id int64, l common.Location, name string, name, FormatCELType(field), FormatCELType(value)) } -func (e *typeErrors) incompatibleType(id int64, l common.Location, ex *exprpb.Expr, prev, next *types.Type) { +func (e *typeErrors) incompatibleType(id int64, l common.Location, ex ast.Expr, prev, next *types.Type) { e.errs.ReportErrorAtID(id, l, - "incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next) + "incompatible type already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next) } func (e *typeErrors) noMatchingOverload(id int64, l common.Location, name string, args []*types.Type, isInstance bool) { @@ -49,7 +47,7 @@ func (e *typeErrors) notAComprehensionRange(id int64, l common.Location, t *type FormatCELType(t)) } -func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field *exprpb.Expr) { +func (e *typeErrors) notAnOptionalFieldSelection(id int64, l common.Location, field ast.Expr) { e.errs.ReportErrorAtID(id, l, "unsupported optional field selection: %v", field) } @@ -61,9 +59,9 @@ func (e *typeErrors) notAMessageType(id int64, l common.Location, typeName strin e.errs.ReportErrorAtID(id, l, "'%s' is not a message type", typeName) } -func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex *exprpb.Expr, prev, next *ast.ReferenceInfo) { +func (e *typeErrors) referenceRedefinition(id int64, l common.Location, ex ast.Expr, prev, next *ast.ReferenceInfo) { e.errs.ReportErrorAtID(id, l, - "reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.GetId(), prev, next) + "reference already exists for expression: %v(%d) old:%v, new:%v", ex, ex.ID(), prev, next) } func (e *typeErrors) typeDoesNotSupportFieldSelection(id int64, l common.Location, t *types.Type) { @@ -87,6 +85,6 @@ func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typ e.errs.ReportErrorAtID(id, l, 
"unexpected failed resolution of '%s'", typeName) } -func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex *exprpb.Expr) { +func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex ast.Expr) { e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex)) } diff --git a/checker/printer.go b/checker/printer.go index 15cba06ee..7a3984f02 100644 --- a/checker/printer.go +++ b/checker/printer.go @@ -17,40 +17,40 @@ package checker import ( "sort" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/debug" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) type semanticAdorner struct { - checks *exprpb.CheckedExpr + checked *ast.AST } var _ debug.Adorner = &semanticAdorner{} func (a *semanticAdorner) GetMetadata(elem any) string { result := "" - e, isExpr := elem.(*exprpb.Expr) + e, isExpr := elem.(ast.Expr) if !isExpr { return result } - t := a.checks.TypeMap[e.GetId()] + t := a.checked.TypeMap()[e.ID()] if t != nil { result += "~" - result += FormatCheckedType(t) + result += FormatCELType(t) } - switch e.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr, - *exprpb.Expr_CallExpr, - *exprpb.Expr_StructExpr, - *exprpb.Expr_SelectExpr: - if ref, found := a.checks.ReferenceMap[e.GetId()]; found { - if len(ref.GetOverloadId()) == 0 { + switch e.Kind() { + case ast.IdentKind, + ast.CallKind, + ast.ListKind, + ast.StructKind, + ast.SelectKind: + if ref, found := a.checked.ReferenceMap()[e.ID()]; found { + if len(ref.OverloadIDs) == 0 { result += "^" + ref.Name } else { - sort.Strings(ref.GetOverloadId()) - for i, overload := range ref.GetOverloadId() { + sort.Strings(ref.OverloadIDs) + for i, overload := range ref.OverloadIDs { if i == 0 { result += "^" } else { @@ -68,7 +68,7 @@ func (a *semanticAdorner) GetMetadata(elem any) string { // Print returns a string representation of the Expr message, // annotated with types from the CheckedExpr. 
The Expr must // be a sub-expression embedded in the CheckedExpr. -func Print(e *exprpb.Expr, checks *exprpb.CheckedExpr) string { - a := &semanticAdorner{checks: checks} +func Print(e ast.Expr, checked *ast.AST) string { + a := &semanticAdorner{checked: checked} return debug.ToAdornedDebugString(e, a) } diff --git a/common/ast/BUILD.bazel b/common/ast/BUILD.bazel index a29f201fa..c92a0f179 100644 --- a/common/ast/BUILD.bazel +++ b/common/ast/BUILD.bazel @@ -7,6 +7,7 @@ package( "//common:__subpackages__", "//ext:__subpackages__", "//interpreter:__subpackages__", + "//parser:__subpackages__", ], licenses = ["notice"], # Apache 2.0 ) diff --git a/common/ast/ast.go b/common/ast/ast.go index 652272789..8612e99bb 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -54,6 +54,14 @@ func (a *AST) GetType(id int64) *types.Type { return types.DynType } +// SetType sets the type of the expression node at the given id. +func (a *AST) SetType(id int64, t *types.Type) { + if a == nil { + return + } + a.typeMap[id] = t +} + // TypeMap returns the map of expression ids to type-checked types. // // If the AST is not type-checked, the map will be empty. @@ -82,6 +90,14 @@ func (a *AST) ReferenceMap() map[int64]*ReferenceInfo { return a.refMap } +// SetReference adds a reference to the checked AST type map. +func (a *AST) SetReference(id int64, r *ReferenceInfo) { + if a == nil { + return + } + a.refMap[id] = r +} + // IsChecked returns whether the AST is type-checked. func (a *AST) IsChecked() bool { return a != nil && a.checked diff --git a/common/ast/conversion.go b/common/ast/conversion.go index d8516b187..f2cc01ff5 100644 --- a/common/ast/conversion.go +++ b/common/ast/conversion.go @@ -273,6 +273,11 @@ func ExprToProto(e Expr) (*exprpb.Expr, error) { case StructKind: return protoStruct(e.ID(), e.AsStruct()) case UnspecifiedExprKind: + // Handle the case where a macro reference may be getting translated. 
+ // A nested macro 'pointer' is a non-zero expression id with no kind set. + if e.ID() != 0 { + return &exprpb.Expr{Id: e.ID()}, nil + } return &exprpb.Expr{}, nil } return nil, fmt.Errorf("unsupported expr kind: %v", e) diff --git a/common/ast/conversion_test.go b/common/ast/conversion_test.go index 2ed5e096b..79226bf31 100644 --- a/common/ast/conversion_test.go +++ b/common/ast/conversion_test.go @@ -203,25 +203,23 @@ func TestConvertExpr(t *testing.T) { if len(errs.GetErrors()) != 0 { t.Fatalf("Parse() failed: %s", errs.ToDisplayString()) } + gotPBExpr, err := ast.ExprToProto(parsed.Expr()) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } wantPBExpr, err := ast.ExprToProto(tc.wantExpr) if err != nil { t.Fatalf("ast.ExprToProto() failed: %v", err) } - if !proto.Equal(parsed.GetExpr(), wantPBExpr) { + if !proto.Equal(gotPBExpr, wantPBExpr) { t.Errorf("got %v\n, wanted %v", - prototext.Format(parsed.GetExpr()), prototext.Format(wantPBExpr)) - } - gotExpr, err := ast.ProtoToExpr(parsed.GetExpr()) - if err != nil { - t.Fatalf("ast.ProtoToExpr() failed: %v", err) + prototext.Format(gotPBExpr), prototext.Format(wantPBExpr)) } + gotExpr := parsed.Expr() if !reflect.DeepEqual(gotExpr, tc.wantExpr) { t.Errorf("got %v, wanted %v", gotExpr, tc.wantExpr) } - info, err := ast.ProtoToSourceInfo(parsed.GetSourceInfo()) - if err != nil { - t.Fatalf("ast.ProtoToSourceInfo() failed: %v", err) - } + info := parsed.SourceInfo() for id, wantCall := range tc.macroCalls { call, found := info.GetMacroCall(id) if !found { @@ -231,14 +229,6 @@ func TestConvertExpr(t *testing.T) { t.Errorf("macro call got %v, wanted %v", call, wantCall) } } - pbInfo, err := ast.SourceInfoToProto(info) - if err != nil { - t.Fatalf("ast.SourceInfoToProto() failed: %v", err) - } - if !proto.Equal(parsed.GetSourceInfo(), pbInfo) { - t.Errorf("source info roundtrip got %v, wanted %v", - prototext.Format(pbInfo), prototext.Format(parsed.GetSourceInfo())) - } }) } } diff --git 
a/common/containers/BUILD.bazel b/common/containers/BUILD.bazel index 3f3f07887..81197f064 100644 --- a/common/containers/BUILD.bazel +++ b/common/containers/BUILD.bazel @@ -12,7 +12,7 @@ go_library( ], importpath = "github.com/google/cel-go/common/containers", deps = [ - "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "//common/ast:go_default_library", ], ) @@ -26,6 +26,6 @@ go_test( ":go_default_library", ], deps = [ - "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "//common/ast:go_default_library", ], ) diff --git a/common/containers/container.go b/common/containers/container.go index d46698d3c..52153d4cd 100644 --- a/common/containers/container.go +++ b/common/containers/container.go @@ -20,7 +20,7 @@ import ( "fmt" "strings" - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/ast" ) var ( @@ -297,19 +297,19 @@ func Name(name string) ContainerOption { // ToQualifiedName converts an expression AST into a qualified name if possible, with a boolean // 'found' value that indicates if the conversion is successful. -func ToQualifiedName(e *exprpb.Expr) (string, bool) { - switch e.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr: - id := e.GetIdentExpr() - return id.GetName(), true - case *exprpb.Expr_SelectExpr: - sel := e.GetSelectExpr() +func ToQualifiedName(e ast.Expr) (string, bool) { + switch e.Kind() { + case ast.IdentKind: + id := e.AsIdent() + return id, true + case ast.SelectKind: + sel := e.AsSelect() // Test only expressions are not valid as qualified names. - if sel.GetTestOnly() { + if sel.IsTestOnly() { return "", false } - if qual, found := ToQualifiedName(sel.GetOperand()); found { - return qual + "." + sel.GetField(), true + if qual, found := ToQualifiedName(sel.Operand()); found { + return qual + "." 
+ sel.FieldName(), true } } return "", false diff --git a/common/containers/container_test.go b/common/containers/container_test.go index d2b89ef17..224490e95 100644 --- a/common/containers/container_test.go +++ b/common/containers/container_test.go @@ -18,7 +18,7 @@ import ( "reflect" "testing" - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/ast" ) func TestContainers_ResolveCandidateNames(t *testing.T) { @@ -204,13 +204,8 @@ func TestContainers_Extend_Name(t *testing.T) { } func TestContainers_ToQualifiedName(t *testing.T) { - ident := &exprpb.Expr{ - ExprKind: &exprpb.Expr_IdentExpr{ - IdentExpr: &exprpb.Expr_Ident{ - Name: "var", - }, - }, - } + fac := ast.NewExprFactory() + ident := fac.NewIdent(1, "var") idName, found := ToQualifiedName(ident) if !found { t.Errorf("got not found from %v expr, wanted found", ident) @@ -218,14 +213,7 @@ func TestContainers_ToQualifiedName(t *testing.T) { if idName != "var" { t.Errorf("got %v, wanted 'var'", idName) } - sel := &exprpb.Expr{ - ExprKind: &exprpb.Expr_SelectExpr{ - SelectExpr: &exprpb.Expr_Select{ - Operand: ident, - Field: "qualifier", - }, - }, - } + sel := fac.NewSelect(2, ident, "qualifier") qualName, found := ToQualifiedName(sel) if !found { t.Errorf("got not found from %v expr, wanted found", sel) @@ -234,22 +222,14 @@ func TestContainers_ToQualifiedName(t *testing.T) { t.Errorf("got %v, wanted 'var.qualifier'", qualName) } - sel.GetSelectExpr().TestOnly = true - _, found = ToQualifiedName(sel) + pres := fac.NewPresenceTest(2, ident, "qualifier") + _, found = ToQualifiedName(pres) if found { t.Error("got found, wanted not found for test-only expression") } - unary := &exprpb.Expr{ - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Function: "!_", - Args: []*exprpb.Expr{ident}, - }, - }, - } - sel.GetSelectExpr().TestOnly = false - sel.GetSelectExpr().Operand = unary + unary := fac.NewCall(2, "!_", ident) + sel = fac.NewSelect(3, unary, 
"qualifier") _, found = ToQualifiedName(sel) if found { t.Errorf("got found, wanted not found for %v", sel) diff --git a/common/debug/BUILD.bazel b/common/debug/BUILD.bazel index 1f029839c..724ed3404 100644 --- a/common/debug/BUILD.bazel +++ b/common/debug/BUILD.bazel @@ -13,6 +13,8 @@ go_library( importpath = "github.com/google/cel-go/common/debug", deps = [ "//common:go_default_library", - "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", + "//common/ast:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", ], ) diff --git a/common/debug/debug.go b/common/debug/debug.go index 5dab156ef..e4c01ac6e 100644 --- a/common/debug/debug.go +++ b/common/debug/debug.go @@ -22,7 +22,9 @@ import ( "strconv" "strings" - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" ) // Adorner returns debug metadata that will be tacked on to the string @@ -38,7 +40,7 @@ type Writer interface { // Buffer pushes an expression into an internal queue of expressions to // write to a string. - Buffer(e *exprpb.Expr) + Buffer(e ast.Expr) } type emptyDebugAdorner struct { @@ -51,12 +53,12 @@ func (a *emptyDebugAdorner) GetMetadata(e any) string { } // ToDebugString gives the unadorned string representation of the Expr. -func ToDebugString(e *exprpb.Expr) string { +func ToDebugString(e ast.Expr) string { return ToAdornedDebugString(e, emptyAdorner) } // ToAdornedDebugString gives the adorned string representation of the Expr. 
-func ToAdornedDebugString(e *exprpb.Expr, adorner Adorner) string { +func ToAdornedDebugString(e ast.Expr, adorner Adorner) string { w := newDebugWriter(adorner) w.Buffer(e) return w.String() @@ -78,49 +80,51 @@ func newDebugWriter(a Adorner) *debugWriter { } } -func (w *debugWriter) Buffer(e *exprpb.Expr) { +func (w *debugWriter) Buffer(e ast.Expr) { if e == nil { return } - switch e.ExprKind.(type) { - case *exprpb.Expr_ConstExpr: - w.append(formatLiteral(e.GetConstExpr())) - case *exprpb.Expr_IdentExpr: - w.append(e.GetIdentExpr().Name) - case *exprpb.Expr_SelectExpr: - w.appendSelect(e.GetSelectExpr()) - case *exprpb.Expr_CallExpr: - w.appendCall(e.GetCallExpr()) - case *exprpb.Expr_ListExpr: - w.appendList(e.GetListExpr()) - case *exprpb.Expr_StructExpr: - w.appendStruct(e.GetStructExpr()) - case *exprpb.Expr_ComprehensionExpr: - w.appendComprehension(e.GetComprehensionExpr()) + switch e.Kind() { + case ast.LiteralKind: + w.append(formatLiteral(e.AsLiteral())) + case ast.IdentKind: + w.append(e.AsIdent()) + case ast.SelectKind: + w.appendSelect(e.AsSelect()) + case ast.CallKind: + w.appendCall(e.AsCall()) + case ast.ListKind: + w.appendList(e.AsList()) + case ast.MapKind: + w.appendMap(e.AsMap()) + case ast.StructKind: + w.appendStruct(e.AsStruct()) + case ast.ComprehensionKind: + w.appendComprehension(e.AsComprehension()) } w.adorn(e) } -func (w *debugWriter) appendSelect(sel *exprpb.Expr_Select) { - w.Buffer(sel.GetOperand()) +func (w *debugWriter) appendSelect(sel ast.SelectExpr) { + w.Buffer(sel.Operand()) w.append(".") - w.append(sel.GetField()) - if sel.TestOnly { + w.append(sel.FieldName()) + if sel.IsTestOnly() { w.append("~test-only~") } } -func (w *debugWriter) appendCall(call *exprpb.Expr_Call) { - if call.Target != nil { - w.Buffer(call.GetTarget()) +func (w *debugWriter) appendCall(call ast.CallExpr) { + if call.IsMemberFunction() { + w.Buffer(call.Target()) w.append(".") } - w.append(call.GetFunction()) + w.append(call.FunctionName()) 
w.append("(") - if len(call.GetArgs()) > 0 { + if len(call.Args()) > 0 { w.addIndent() w.appendLine() - for i, arg := range call.GetArgs() { + for i, arg := range call.Args() { if i > 0 { w.append(",") w.appendLine() @@ -133,12 +137,12 @@ func (w *debugWriter) appendCall(call *exprpb.Expr_Call) { w.append(")") } -func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) { +func (w *debugWriter) appendList(list ast.ListExpr) { w.append("[") - if len(list.GetElements()) > 0 { + if len(list.Elements()) > 0 { w.appendLine() w.addIndent() - for i, elem := range list.GetElements() { + for i, elem := range list.Elements() { if i > 0 { w.append(",") w.appendLine() @@ -151,32 +155,25 @@ func (w *debugWriter) appendList(list *exprpb.Expr_CreateList) { w.append("]") } -func (w *debugWriter) appendStruct(obj *exprpb.Expr_CreateStruct) { - if obj.MessageName != "" { - w.appendObject(obj) - } else { - w.appendMap(obj) - } -} - -func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) { - w.append(obj.GetMessageName()) +func (w *debugWriter) appendStruct(obj ast.StructExpr) { + w.append(obj.TypeName()) w.append("{") - if len(obj.GetEntries()) > 0 { + if len(obj.Fields()) > 0 { w.appendLine() w.addIndent() - for i, entry := range obj.GetEntries() { + for i, f := range obj.Fields() { + field := f.AsStructField() if i > 0 { w.append(",") w.appendLine() } - if entry.GetOptionalEntry() { + if field.IsOptional() { w.append("?") } - w.append(entry.GetFieldKey()) + w.append(field.Name()) w.append(":") - w.Buffer(entry.GetValue()) - w.adorn(entry) + w.Buffer(field.Value()) + w.adorn(f) } w.removeIndent() w.appendLine() @@ -184,23 +181,24 @@ func (w *debugWriter) appendObject(obj *exprpb.Expr_CreateStruct) { w.append("}") } -func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) { +func (w *debugWriter) appendMap(m ast.MapExpr) { w.append("{") - if len(obj.GetEntries()) > 0 { + if m.Size() > 0 { w.appendLine() w.addIndent() - for i, entry := range obj.GetEntries() { + 
for i, e := range m.Entries() { + entry := e.AsMapEntry() if i > 0 { w.append(",") w.appendLine() } - if entry.GetOptionalEntry() { + if entry.IsOptional() { w.append("?") } - w.Buffer(entry.GetMapKey()) + w.Buffer(entry.Key()) w.append(":") - w.Buffer(entry.GetValue()) - w.adorn(entry) + w.Buffer(entry.Value()) + w.adorn(e) } w.removeIndent() w.appendLine() @@ -208,62 +206,62 @@ func (w *debugWriter) appendMap(obj *exprpb.Expr_CreateStruct) { w.append("}") } -func (w *debugWriter) appendComprehension(comprehension *exprpb.Expr_Comprehension) { +func (w *debugWriter) appendComprehension(comprehension ast.ComprehensionExpr) { w.append("__comprehension__(") w.addIndent() w.appendLine() w.append("// Variable") w.appendLine() - w.append(comprehension.GetIterVar()) + w.append(comprehension.IterVar()) w.append(",") w.appendLine() w.append("// Target") w.appendLine() - w.Buffer(comprehension.GetIterRange()) + w.Buffer(comprehension.IterRange()) w.append(",") w.appendLine() w.append("// Accumulator") w.appendLine() - w.append(comprehension.GetAccuVar()) + w.append(comprehension.AccuVar()) w.append(",") w.appendLine() w.append("// Init") w.appendLine() - w.Buffer(comprehension.GetAccuInit()) + w.Buffer(comprehension.AccuInit()) w.append(",") w.appendLine() w.append("// LoopCondition") w.appendLine() - w.Buffer(comprehension.GetLoopCondition()) + w.Buffer(comprehension.LoopCondition()) w.append(",") w.appendLine() w.append("// LoopStep") w.appendLine() - w.Buffer(comprehension.GetLoopStep()) + w.Buffer(comprehension.LoopStep()) w.append(",") w.appendLine() w.append("// Result") w.appendLine() - w.Buffer(comprehension.GetResult()) + w.Buffer(comprehension.Result()) w.append(")") w.removeIndent() } -func formatLiteral(c *exprpb.Constant) string { - switch c.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - return fmt.Sprintf("%t", c.GetBoolValue()) - case *exprpb.Constant_BytesValue: - return fmt.Sprintf("b\"%s\"", string(c.GetBytesValue())) - case 
*exprpb.Constant_DoubleValue: - return fmt.Sprintf("%v", c.GetDoubleValue()) - case *exprpb.Constant_Int64Value: - return fmt.Sprintf("%d", c.GetInt64Value()) - case *exprpb.Constant_StringValue: - return strconv.Quote(c.GetStringValue()) - case *exprpb.Constant_Uint64Value: - return fmt.Sprintf("%du", c.GetUint64Value()) - case *exprpb.Constant_NullValue: +func formatLiteral(c ref.Val) string { + switch v := c.(type) { + case types.Bool: + return fmt.Sprintf("%t", v) + case types.Bytes: + return fmt.Sprintf("b\"%s\"", string(v)) + case types.Double: + return fmt.Sprintf("%v", float64(v)) + case types.Int: + return fmt.Sprintf("%d", int64(v)) + case types.String: + return strconv.Quote(string(v)) + case types.Uint: + return fmt.Sprintf("%du", uint64(v)) + case types.Null: return "null" default: panic("Unknown constant type") diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 73269bf07..2b0f4864a 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -1608,15 +1608,11 @@ func TestInterpreter_ProtoAttributeOpt(t *testing.T) { func TestInterpreter_LogicalAndMissingType(t *testing.T) { parsed := testMustParse(t, `a && TestProto{c: true}.c`) - ex, err := ast.ProtoToExpr(parsed.GetExpr()) - if err != nil { - t.Fatalf("ast.ProtoToExpr() failed: %v", err) - } reg := newTestRegistry(t) cont := containers.DefaultContainer attrs := NewAttributeFactory(cont, reg, reg) intr := newStandardInterpreter(t, cont, reg, reg, attrs) - i, err := intr.NewInterpretable(ast.NewAST(ex, nil)) + i, err := intr.NewInterpretable(parsed) if err == nil { t.Errorf("Got '%v', wanted error", i) } @@ -1624,16 +1620,12 @@ func TestInterpreter_LogicalAndMissingType(t *testing.T) { func TestInterpreter_ExhaustiveConditionalExpr(t *testing.T) { parsed := testMustParse(t, `a ? 
b < 1.0 : c == ['hello']`) - ex, err := ast.ProtoToExpr(parsed.GetExpr()) - if err != nil { - t.Fatalf("ast.ProtoToExpr() failed: %v", err) - } state := NewEvalState() cont := containers.DefaultContainer reg := newTestRegistry(t, &exprpb.ParsedExpr{}) attrs := NewAttributeFactory(cont, reg, reg) intr := newStandardInterpreter(t, cont, reg, reg, attrs) - interpretable, _ := intr.NewInterpretable(ast.NewAST(ex, nil), + interpretable, _ := intr.NewInterpretable(parsed, ExhaustiveEval(), Observe(EvalStateObserver(state))) vars, _ := NewActivation(map[string]any{ "a": types.True, @@ -1711,16 +1703,12 @@ func TestInterpreter_ExhaustiveLogicalOrEquals(t *testing.T) { // a || b == "b" // Operator "==" is at Expr 4, should be evaluated though "a" is true parsed := testMustParse(t, `a || b == "b"`) - ex, err := ast.ProtoToExpr(parsed.GetExpr()) - if err != nil { - t.Fatalf("ast.ProtoToExpr() failed: %v", err) - } state := NewEvalState() reg := newTestRegistry(t, &exprpb.Expr{}) cont := testContainer("test") attrs := NewAttributeFactory(cont, reg, reg) interp := newStandardInterpreter(t, cont, reg, reg, attrs) - i, _ := interp.NewInterpretable(ast.NewAST(ex, nil), + i, _ := interp.NewInterpretable(parsed, ExhaustiveEval(), Observe(EvalStateObserver(state))) vars, _ := NewActivation(map[string]any{ "a": true, @@ -1997,12 +1985,8 @@ func program(t testing.TB, tst *testCase, opts ...InterpretableDecorator) (Inter return nil, nil, errors.New(errs.ToDisplayString()) } if tst.unchecked { - ex, err := ast.ProtoToExpr(parsed.GetExpr()) - if err != nil { - t.Fatalf("ast.ProtoToExpr() failed: %v", err) - } // Build the program plan. - prg, err := interp.NewInterpretable(ast.NewAST(ex, nil), opts...) + prg, err := interp.NewInterpretable(parsed, opts...) 
if err != nil { return nil, nil, err } @@ -2045,7 +2029,7 @@ func isFieldQual(q Qualifier, fieldName string) bool { return f.Name == fieldName } -func testMustParse(t testing.TB, data any) *exprpb.ParsedExpr { +func testMustParse(t testing.TB, data any) *ast.AST { t.Helper() p, err := parser.NewParser() if err != nil { diff --git a/interpreter/prune.go b/interpreter/prune.go index b8834b1cb..504e5bf6d 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -15,6 +15,7 @@ package interpreter import ( + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" @@ -67,22 +68,30 @@ type astPruner struct { // compiled and constant folded expressions, but is not willing to constant // fold(and thus cache results of) some external calls, then they can prepare // the overloads accordingly. -func PruneAst(expr *exprpb.Expr, macroCalls map[int64]*exprpb.Expr, state EvalState) *exprpb.ParsedExpr { +func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *ast.AST { pruneState := NewEvalState() for _, id := range state.IDs() { v, _ := state.Value(id) pruneState.SetValue(id, v) } + pbExpr, _ := ast.ExprToProto(expr) + pbMacroCalls := make(map[int64]*exprpb.Expr, len(macroCalls)) + for id, call := range macroCalls { + pbMacroCalls[id], _ = ast.ExprToProto(call) + } pruner := &astPruner{ - expr: expr, - macroCalls: macroCalls, + expr: pbExpr, + macroCalls: pbMacroCalls, state: pruneState, - nextExprID: getMaxID(expr)} - newExpr, _ := pruner.maybePrune(expr) - return &exprpb.ParsedExpr{ - Expr: newExpr, - SourceInfo: &exprpb.SourceInfo{MacroCalls: pruner.macroCalls}, + nextExprID: getMaxID(pbExpr)} + newPBExpr, _ := pruner.maybePrune(pbExpr) + newExpr, _ := ast.ProtoToExpr(newPBExpr) + newInfo := ast.NewSourceInfo(nil) + for id, pbCall := range pruner.macroCalls { + call, _ := ast.ProtoToExpr(pbCall) + newInfo.SetMacroCall(id, call) } + return 
ast.NewAST(newExpr, newInfo) } func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr { diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index 989369428..c32ca0f39 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -466,23 +466,19 @@ func TestPrune(t *testing.T) { addFunctionBindings(t, dispatcher) dispatcher.Add(funcBindings(t, optionalDecls(t)...)...) interp := NewInterpreter(dispatcher, containers.DefaultContainer, reg, reg, attrs) - ex, err := ast.ProtoToExpr(parsed.GetExpr()) - if err != nil { - t.Fatalf("ast.ProtoToExpr() failed: %v", err) - } - interpretable, err := interp.NewInterpretable(ast.NewAST(ex, nil), + interpretable, err := interp.NewInterpretable(parsed, ExhaustiveEval(), Observe(EvalStateObserver(state))) if err != nil { t.Fatalf("NewUncheckedInterpretable() failed: %v", err) } interpretable.Eval(testActivation(t, tst.in)) - newExpr := PruneAst(parsed.GetExpr(), parsed.GetSourceInfo().GetMacroCalls(), state) + newExpr := PruneAst(parsed.Expr(), parsed.SourceInfo().MacroCalls(), state) if tst.iterRange != "" { - compre := newExpr.GetExpr().GetComprehensionExpr() - if compre == nil { - t.Fatalf("iter range check cannot operate on non comprehension output: %v", newExpr.GetExpr()) + if newExpr.Expr().Kind() != ast.ComprehensionKind { + t.Fatalf("iter range check cannot operate on non comprehension output: %v", newExpr.Expr()) } - gotIterRange, err := parser.Unparse(compre.GetIterRange(), newExpr.GetSourceInfo()) + compre := newExpr.Expr().AsComprehension() + gotIterRange, err := parser.Unparse(compre.IterRange(), newExpr.SourceInfo()) if err != nil { t.Fatalf("parser.Unparse() failed: %v", err) } @@ -490,7 +486,7 @@ func TestPrune(t *testing.T) { t.Errorf("iter range unparse got: %v, wanted %v", gotIterRange, tst.iterRange) } } - actual, err := parser.Unparse(newExpr.GetExpr(), newExpr.GetSourceInfo()) + actual, err := parser.Unparse(newExpr.Expr(), newExpr.SourceInfo()) if err != nil { 
t.Fatalf("parser.Unparse() failed: %v", err) } diff --git a/parser/BUILD.bazel b/parser/BUILD.bazel index 67ecc9554..af15000cd 100644 --- a/parser/BUILD.bazel +++ b/parser/BUILD.bazel @@ -20,6 +20,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//common:go_default_library", + "//common/ast:go_default_library", "//common/operators:go_default_library", "//common/runes:go_default_library", "//parser/gen:go_default_library", @@ -43,7 +44,9 @@ go_test( ":go_default_library", ], deps = [ + "//common/ast:go_default_library", "//common/debug:go_default_library", + "//common/types:go_default_library", "//parser/gen:go_default_library", "//test:go_default_library", "@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library", diff --git a/parser/helper_test.go b/parser/helper_test.go index c262b0dc3..f57c851c2 100644 --- a/parser/helper_test.go +++ b/parser/helper_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "google.golang.org/protobuf/proto" @@ -44,15 +45,25 @@ func TestExprHelperCopy(t *testing.T) { if err != nil { t.Fatalf("NewParser() failed: %v", err) } - ast, errs := p.Parse(src) + parsed, errs := p.Parse(src) if errs != nil && len(errs.GetErrors()) != 0 { t.Fatalf("p.Parse(%v) failed: %v", src, errs.ToDisplayString()) } - // id generation is consistent between runs, and in this case '27' refers to the // macro expression that's the sole argument to the noop() macro. 
- macroTarget := ast.GetSourceInfo().GetMacroCalls()[int64(27)] - if proto.Equal(ast.GetExpr(), macroTarget) { - t.Errorf("Copy() failed to provide a unique ids: %v", macroTarget) + macroTarget, found := parsed.SourceInfo().GetMacroCall(27) + if !found { + t.Fatal("GetMacroCall(27) failed to find call") + } + macroExpr, err := ast.ExprToProto(macroTarget) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + protoExpr, err := ast.ExprToProto(parsed.Expr()) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + if proto.Equal(protoExpr, macroExpr) { + t.Errorf("Copy() failed to provide a unique ids: %v", macroExpr) } } diff --git a/parser/parser.go b/parser/parser.go index 109326a93..292b64cb6 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -26,6 +26,7 @@ import ( antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/runes" "github.com/google/cel-go/parser/gen" @@ -88,7 +89,7 @@ func mustNewParser(opts ...Option) *Parser { } // Parse parses the expression represented by source and returns the result. 
-func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) { +func (p *Parser) Parse(source common.Source) (*ast.AST, *common.Errors) { errs := common.NewErrors(source) impl := parser{ errors: &parseErrors{errs}, @@ -114,10 +115,17 @@ func (p *Parser) Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors } else { e = impl.parse(buf, source.Description()) } - return &exprpb.ParsedExpr{ - Expr: e, - SourceInfo: impl.helper.getSourceInfo(), - }, errs + sourceInfo, err := ast.ProtoToSourceInfo(impl.helper.getSourceInfo()) + if err != nil { + impl.reportError(common.NoLocation, + "failed to convert proto source info: %v", impl.helper.getSourceInfo()) + } + out, err := ast.ProtoToExpr(e) + if err != nil { + impl.reportError(common.NoLocation, + "failed to convert proto expression: %v", e) + } + return ast.NewAST(out, sourceInfo), errs } // reservedIds are not legal to use as variables. We exclude them post-parse, as they *are* valid @@ -150,7 +158,7 @@ var reservedIds = map[string]struct{}{ // This function calls ParseWithMacros with AllMacros. // // Deprecated: Use NewParser().Parse() instead. 
-func Parse(source common.Source) (*exprpb.ParsedExpr, *common.Errors) { +func Parse(source common.Source) (*ast.AST, *common.Errors) { return mustNewParser(Macros(AllMacros...)).Parse(source) } diff --git a/parser/parser_test.go b/parser/parser_test.go index 14e0d8ed2..d34247d90 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -22,7 +22,9 @@ import ( "testing" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/debug" + "github.com/google/cel-go/common/types" "github.com/google/cel-go/test" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -1777,71 +1779,82 @@ type metadata interface { } type kindAndIDAdorner struct { - sourceInfo *exprpb.SourceInfo + sourceInfo *ast.SourceInfo } func (k *kindAndIDAdorner) GetMetadata(elem any) string { - switch elem.(type) { - case *exprpb.Expr: - e := elem.(*exprpb.Expr) - macroCalls := k.sourceInfo.GetMacroCalls() - if macroCalls != nil { - if val, found := macroCalls[e.GetId()]; found { - return fmt.Sprintf("^#%d:%s#", e.GetId(), val.GetCallExpr().GetFunction()) - } + switch e := elem.(type) { + case ast.Expr: + if macroCall, found := k.sourceInfo.GetMacroCall(e.ID()); found { + return fmt.Sprintf("^#%d:%s#", e.ID(), macroCall.AsCall().FunctionName()) } - var valType any = e.ExprKind - switch valType.(type) { - case *exprpb.Expr_ConstExpr: - valType = e.GetConstExpr().GetConstantKind() + var valType string + switch e.Kind() { + case ast.CallKind: + valType = "*expr.Expr_CallExpr" + case ast.ComprehensionKind: + valType = "*expr.Expr_ComprehensionExpr" + case ast.IdentKind: + valType = "*expr.Expr_IdentExpr" + case ast.LiteralKind: + lit := e.AsLiteral() + switch lit.(type) { + case types.Bool: + valType = "*expr.Constant_BoolValue" + case types.Bytes: + valType = "*expr.Constant_BytesValue" + case types.Double: + valType = "*expr.Constant_DoubleValue" + case types.Int: + valType = "*expr.Constant_Int64Value" + case types.Null: + valType = 
"*expr.Constant_NullValue" + case types.String: + valType = "*expr.Constant_StringValue" + case types.Uint: + valType = "*expr.Constant_Uint64Value" + default: + valType = reflect.TypeOf(lit).String() + } + case ast.ListKind: + valType = "*expr.Expr_ListExpr" + case ast.MapKind, ast.StructKind: + valType = "*expr.Expr_StructExpr" + case ast.SelectKind: + valType = "*expr.Expr_SelectExpr" } - return fmt.Sprintf("^#%d:%s#", e.GetId(), reflect.TypeOf(valType)) - case *exprpb.Expr_CreateStruct_Entry: - entry := elem.(*exprpb.Expr_CreateStruct_Entry) - return fmt.Sprintf("^#%d:%s#", entry.GetId(), "*expr.Expr_CreateStruct_Entry") + return fmt.Sprintf("^#%d:%s#", e.ID(), valType) + case ast.EntryExpr: + return fmt.Sprintf("^#%d:%s#", e.ID(), "*expr.Expr_CreateStruct_Entry") } return "" } type locationAdorner struct { - sourceInfo *exprpb.SourceInfo + sourceInfo *ast.SourceInfo } var _ metadata = &locationAdorner{} func (l *locationAdorner) GetLocation(exprID int64) (common.Location, bool) { - if pos, found := l.sourceInfo.GetPositions()[exprID]; found { - var line = 1 - for _, lineOffset := range l.sourceInfo.GetLineOffsets() { - if lineOffset > pos { - break - } else { - line++ - } - } - var column = pos - if line > 1 { - column = pos - l.sourceInfo.GetLineOffsets()[line-2] - } - return common.NewLocation(line, int(column)), true - } - return common.NoLocation, false + loc := l.sourceInfo.GetStartLocation(exprID) + return loc, loc != common.NoLocation } func (l *locationAdorner) GetMetadata(elem any) string { var elemID int64 - switch elem.(type) { - case *exprpb.Expr: - elemID = elem.(*exprpb.Expr).GetId() - case *exprpb.Expr_CreateStruct_Entry: - elemID = elem.(*exprpb.Expr_CreateStruct_Entry).GetId() + switch elem := elem.(type) { + case ast.Expr: + elemID = elem.ID() + case ast.EntryExpr: + elemID = elem.ID() } location, _ := l.GetLocation(elemID) return fmt.Sprintf("^#%d[%d,%d]#", elemID, location.Line(), location.Column()) } -func convertMacroCallsToString(source 
*exprpb.SourceInfo) string { - macroCalls := source.GetMacroCalls() +func convertMacroCallsToString(source *ast.SourceInfo) string { + macroCalls := source.MacroCalls() keys := make([]int64, len(macroCalls)) adornedStrings := make([]string, len(macroCalls)) i := 0 @@ -1849,14 +1862,17 @@ func convertMacroCallsToString(source *exprpb.SourceInfo) string { keys[i] = k i++ } + fac := ast.NewExprFactory() // Sort the keys in descending order to create a stable ordering for tests and improve readability. sort.Slice(keys, func(i, j int) bool { return keys[i] > keys[j] }) i = 0 for _, key := range keys { - call := macroCalls[int64(key)] - callWithID := &exprpb.Expr{ - Id: int64(key), - ExprKind: call.GetExprKind(), + call := macroCalls[int64(key)].AsCall() + var callWithID ast.Expr + if call.IsMemberFunction() { + callWithID = fac.NewMemberCall(int64(key), call.FunctionName(), call.Target(), call.Args()...) + } else { + callWithID = fac.NewCall(int64(key), call.FunctionName(), call.Args()...) } adornedStrings[i] = debug.ToAdornedDebugString( callWithID, @@ -1882,7 +1898,7 @@ func TestParse(t *testing.T) { p = newTestParser(t, tc.Opts...) 
} src := common.NewTextSource(tc.I) - parsedExpr, errors := p.Parse(src) + parsed, errors := p.Parse(src) if len(errors.GetErrors()) > 0 { actualErr := errors.ToDisplayString() if tc.E == "" { @@ -1895,20 +1911,20 @@ func TestParse(t *testing.T) { t.Fatalf("Expected error not thrown: '%s'", tc.E) } failureDisplayMethod := fmt.Sprintf("Parse(\"%s\")", tc.I) - actualWithKind := debug.ToAdornedDebugString(parsedExpr.GetExpr(), &kindAndIDAdorner{}) + actualWithKind := debug.ToAdornedDebugString(parsed.Expr(), &kindAndIDAdorner{}) if !test.Compare(actualWithKind, tc.P) { t.Fatal(test.DiffMessage(fmt.Sprintf("Structure - %s", failureDisplayMethod), actualWithKind, tc.P)) } if tc.L != "" { - actualWithLocation := debug.ToAdornedDebugString(parsedExpr.GetExpr(), &locationAdorner{parsedExpr.GetSourceInfo()}) + actualWithLocation := debug.ToAdornedDebugString(parsed.Expr(), &locationAdorner{parsed.SourceInfo()}) if !test.Compare(actualWithLocation, tc.L) { t.Fatal(test.DiffMessage(fmt.Sprintf("Location - %s", failureDisplayMethod), actualWithLocation, tc.L)) } } if tc.M != "" { - actualAdornedMacroCalls := convertMacroCallsToString(parsedExpr.GetSourceInfo()) + actualAdornedMacroCalls := convertMacroCallsToString(parsed.SourceInfo()) if !test.Compare(actualAdornedMacroCalls, tc.M) { t.Fatal(test.DiffMessage(fmt.Sprintf("Macro Calls - %s", failureDisplayMethod), actualAdornedMacroCalls, tc.M)) } diff --git a/parser/unparser.go b/parser/unparser.go index c3c40a0dd..4c4d125b7 100644 --- a/parser/unparser.go +++ b/parser/unparser.go @@ -20,6 +20,7 @@ import ( "strconv" "strings" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -39,14 +40,21 @@ import ( // // This function optionally takes in one or more UnparserOption to alter the unparsing behavior, such as // performing word wrapping on expressions. 
-func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo, opts ...UnparserOption) (string, error) { +func Unparse(expr ast.Expr, info *ast.SourceInfo, opts ...UnparserOption) (string, error) { + pbExpr, err := ast.ExprToProto(expr) + if err != nil { + return "", err + } + pbInfo, err := ast.SourceInfoToProto(info) + if err != nil { + return "", err + } unparserOpts := &unparserOption{ wrapOnColumn: defaultWrapOnColumn, wrapAfterColumnLimit: defaultWrapAfterColumnLimit, operatorsToWrapOn: defaultOperatorsToWrapOn, } - var err error for _, opt := range opts { unparserOpts, err = opt(unparserOpts) if err != nil { @@ -55,10 +63,10 @@ func Unparse(expr *exprpb.Expr, info *exprpb.SourceInfo, opts ...UnparserOption) } un := &unparser{ - info: info, + info: pbInfo, options: unparserOpts, } - err = un.visit(expr) + err = un.visit(pbExpr) if err != nil { return "", err } diff --git a/parser/unparser_test.go b/parser/unparser_test.go index c464c1010..3415f6e46 100644 --- a/parser/unparser_test.go +++ b/parser/unparser_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "google.golang.org/protobuf/proto" @@ -463,7 +464,7 @@ func TestUnparse(t *testing.T) { if len(iss.GetErrors()) > 0 { t.Fatalf("parser.Parse(%s) failed: %v", tc.in, iss.ToDisplayString()) } - out, err := Unparse(p.GetExpr(), p.GetSourceInfo(), tc.unparserOptions...) + out, err := Unparse(p.Expr(), p.SourceInfo(), tc.unparserOptions...) 
if err != nil { t.Fatalf("Unparse(%s) failed: %v", tc.in, err) @@ -479,8 +480,14 @@ func TestUnparse(t *testing.T) { if len(iss.GetErrors()) > 0 { t.Fatalf("parser.Parse(%s) roundtrip failed: %v", tc.in, iss.ToDisplayString()) } - before := p.GetExpr() - after := p2.GetExpr() + before, err := ast.ExprToProto(p.Expr()) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + after, err := ast.ExprToProto(p2.Expr()) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } if !proto.Equal(before, after) { t.Errorf("Roundtrip Parse() differs from original. Got '%v', wanted '%v'", before, after) } @@ -503,15 +510,6 @@ func TestUnparseErrors(t *testing.T) { unparserOptions []UnparserOption }{ {name: "empty_expr", in: &exprpb.Expr{}, err: errors.New("unsupported expression")}, - { - name: "bad_constant", - in: &exprpb.Expr{ - ExprKind: &exprpb.Expr_ConstExpr{ - ConstExpr: &exprpb.Constant{}, - }, - }, - err: errors.New("unsupported constant"), - }, { name: "bad_args", in: &exprpb.Expr{ @@ -600,7 +598,12 @@ func TestUnparseErrors(t *testing.T) { for _, tst := range tests { tc := tst t.Run(tc.name, func(t *testing.T) { - out, err := Unparse(tc.in, &exprpb.SourceInfo{}, tc.unparserOptions...) + info := ast.NewSourceInfo(nil) + e, err := ast.ProtoToExpr(tc.in) + if err != nil { + t.Fatalf("ast.ProtoToExpr(%v) failed: %v", tc.in, err) + } + out, err := Unparse(e, info, tc.unparserOptions...) 
if err == nil { t.Fatalf("Unparse(%v) got %v, wanted error %v", tc.in, out, tc.err) } From b99d1226802fad1e6c69414042329124b40b1e70 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 14 Aug 2023 12:55:02 -0400 Subject: [PATCH 07/31] Add test coverage for SourceInfo to proto (#802) --- common/ast/conversion_test.go | 174 ++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) diff --git a/common/ast/conversion_test.go b/common/ast/conversion_test.go index 79226bf31..523c5a812 100644 --- a/common/ast/conversion_test.go +++ b/common/ast/conversion_test.go @@ -233,6 +233,180 @@ func TestConvertExpr(t *testing.T) { } } +func TestSourceInfoToProto(t *testing.T) { + expr := `[{}, {'field': true}].exists(i, has(i.field))` + p, err := parser.NewParser( + parser.Macros(parser.AllMacros...), + parser.EnableOptionalSyntax(true), + parser.PopulateMacroCalls(true)) + if err != nil { + t.Fatalf("parser.NewParser() failed: %v", err) + } + parsed, errs := p.Parse(common.NewTextSource(expr)) + if len(errs.GetErrors()) != 0 { + t.Fatalf("Parse() failed: %s", errs.ToDisplayString()) + } + pbInfo, err := ast.SourceInfoToProto(parsed.SourceInfo()) + if err != nil { + t.Fatalf("SourceInfoToProto() failed: %v", err) + } + wantInfo := ` + location: "" + line_offsets: 46 + positions: { + key: 1 + value: 0 + } + positions: { + key: 2 + value: 1 + } + positions: { + key: 3 + value: 5 + } + positions: { + key: 4 + value: 13 + } + positions: { + key: 5 + value: 6 + } + positions: { + key: 6 + value: 15 + } + positions: { + key: 7 + value: 28 + } + positions: { + key: 8 + value: 29 + } + positions: { + key: 9 + value: 35 + } + positions: { + key: 10 + value: 36 + } + positions: { + key: 11 + value: 37 + } + positions: { + key: 12 + value: 35 + } + positions: { + key: 13 + value: 28 + } + positions: { + key: 14 + value: 28 + } + positions: { + key: 15 + value: 28 + } + positions: { + key: 16 + value: 28 + } + positions: { + key: 17 + value: 28 + } + positions: { + key: 18 + 
value: 28 + } + positions: { + key: 19 + value: 28 + } + positions: { + key: 20 + value: 28 + } + macro_calls: { + key: 12 + value: { + call_expr: { + function: "has" + args: { + id: 11 + select_expr: { + operand: { + id: 10 + ident_expr: { + name: "i" + } + } + field: "field" + } + } + } + } + } + macro_calls: { + key: 20 + value: { + call_expr: { + target: { + id: 1 + list_expr: { + elements: { + id: 2 + struct_expr: {} + } + elements: { + id: 3 + struct_expr: { + entries: { + id: 4 + map_key: { + id: 5 + const_expr: { + string_value: "field" + } + } + value: { + id: 6 + const_expr: { + bool_value: true + } + } + } + } + } + } + } + function: "exists" + args: { + id: 8 + ident_expr: { + name: "i" + } + } + args: { + id: 12 + } + } + } + }` + wantPBInfo := &exprpb.SourceInfo{} + prototext.Unmarshal([]byte(wantInfo), wantPBInfo) + if !proto.Equal(pbInfo, wantPBInfo) { + t.Errorf("SourceInfoToProto() got %v, wanted %v", + prototext.Format(pbInfo), prototext.Format(wantPBInfo)) + } +} + func TestReferenceInfoToProtoError(t *testing.T) { out, err := ast.ReferenceInfoToProto( ast.NewIdentReference("SECOND", types.Duration{Duration: time.Duration(1) * time.Second})) From bf4f82c46582c592ac020281e32b2fd8bcfb1a32 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Mon, 14 Aug 2023 16:40:48 -0400 Subject: [PATCH 08/31] Update RenumberIDs to support stable id generation (#803) --- common/ast/expr.go | 8 ++++---- common/ast/expr_test.go | 7 ++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/common/ast/expr.go b/common/ast/expr.go index 1fad19fb1..044be5d03 100644 --- a/common/ast/expr.go +++ b/common/ast/expr.go @@ -155,8 +155,8 @@ type EntryExpr interface { isEntryExpr() } -// IDGenerator produces monotonically increasing ids suitable for tagging expression nodes. 
-type IDGenerator func() int64 +// IDGenerator produces unique ids suitable for tagging expression nodes +type IDGenerator func(originalID int64) int64 // CallExpr defines an interface for inspecting a function call and its arugments. type CallExpr interface { @@ -435,7 +435,7 @@ func (e *expr) RenumberIDs(idGen IDGenerator) { if e.Kind() == UnspecifiedExprKind { return } - e.id = idGen() + e.id = idGen(e.id) e.exprKindCase.renumberIDs(idGen) } @@ -751,7 +751,7 @@ func (e *entryExpr) AsStructField() StructField { } func (e *entryExpr) RenumberIDs(idGen IDGenerator) { - e.id = idGen() + e.id = idGen(e.id) e.entryExprKindCase.renumberIDs(idGen) } diff --git a/common/ast/expr_test.go b/common/ast/expr_test.go index 5b38464db..d430068c6 100644 --- a/common/ast/expr_test.go +++ b/common/ast/expr_test.go @@ -474,8 +474,13 @@ func nilTestExpr(t testing.TB) ast.Expr { } func testIDGen(seed int64) ast.IDGenerator { - return func() int64 { + seen := map[int64]int64{} + return func(originalID int64) int64 { + if id, found := seen[originalID]; found { + return id + } seed++ + seen[originalID] = seed return seed } } From 2ef121b0e843304d908ffd02ecb7b4bd79594db5 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 15 Aug 2023 17:33:59 -0400 Subject: [PATCH 09/31] Fix SetKindCase and Renumbering bugs (#805) --- common/ast/expr.go | 7 +++++-- common/ast/expr_test.go | 29 +++++++++++++++++++++++++---- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/common/ast/expr.go b/common/ast/expr.go index 044be5d03..5811e3957 100644 --- a/common/ast/expr.go +++ b/common/ast/expr.go @@ -387,6 +387,7 @@ func (e *expr) SetKindCase(other Expr) { function: c.FunctionName(), target: c.Target(), args: c.Args(), + isMember: c.IsMemberFunction(), } case ComprehensionKind: c := other.AsComprehension() @@ -432,11 +433,13 @@ func (e *expr) SetKindCase(other Expr) { } func (e *expr) RenumberIDs(idGen IDGenerator) { - if e.Kind() == UnspecifiedExprKind { + if e == nil { return } e.id = 
idGen(e.id) - e.exprKindCase.renumberIDs(idGen) + if e.exprKindCase != nil { + e.exprKindCase.renumberIDs(idGen) + } } type baseCallExpr struct { diff --git a/common/ast/expr_test.go b/common/ast/expr_test.go index d430068c6..b6a5a41df 100644 --- a/common/ast/expr_test.go +++ b/common/ast/expr_test.go @@ -20,13 +20,16 @@ import ( "testing" "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" ) func TestSetKindCase(t *testing.T) { fac := ast.NewExprFactory() tests := []ast.Expr{ + fac.NewUnspecifiedExpr(1), fac.NewCall(1, "_==_", fac.NewLiteral(2, types.True), fac.NewLiteral(3, types.False)), + fac.NewMemberCall(1, overloads.Size, fac.NewLiteral(2, types.String("hello"))), fac.NewComprehension(12, fac.NewList(1, []ast.Expr{}, []int32{}), "i", @@ -80,6 +83,10 @@ func TestSetKindCase(t *testing.T) { if !reflect.DeepEqual(expr.AsList(), tst.AsList()) { t.Errorf("got %v, wanted %v", expr.AsList(), tst.AsList()) } + case ast.UnspecifiedExprKind: + if !reflect.DeepEqual(expr, tst) { + t.Errorf("got %v, wanted %v", expr, tst) + } case ast.MapKind: case ast.SelectKind: case ast.StructKind: @@ -162,8 +169,8 @@ func TestCallNil(t *testing.T) { t.Errorf("empty Target() got %d, wanted 0", call.Target().ID()) } expr.RenumberIDs(testIDGen(100)) - if expr.ID() != 1 { - t.Errorf("Renumbering an unspecified expression mutated the value: %v", expr) + if expr.ID() != 101 { + t.Errorf("RenumberIDs() got %d, wanted 101", expr.ID()) } } @@ -241,8 +248,8 @@ func TestComprehensionNil(t *testing.T) { t.Errorf("Result() got %v, wanted unspecified", comp.Result().Kind()) } expr.RenumberIDs(testIDGen(100)) - if expr.ID() != 1 { - t.Errorf("Renumbering an unspecified expression mutated the value: %v", expr) + if expr.ID() != 101 { + t.Errorf("RenumberIDs() got %d, wanted 101", expr.ID()) } } @@ -462,6 +469,20 @@ func TestStructFieldNil(t *testing.T) { } } +func TestRenumberIDs(t *testing.T) { + fac := ast.NewExprFactory() + 
e := fac.NewUnspecifiedExpr(10) + e.RenumberIDs(testIDGen(100)) + if e.ID() != 101 { + t.Errorf("RenumberIDs() got %d, wanted 101", e.ID()) + } + e = fac.NewLiteral(20, types.True) + e.RenumberIDs(testIDGen(200)) + if e.ID() != 201 { + t.Errorf("RenumberIDs() got %d, wanted 201", e.ID()) + } +} + func nilTestExpr(t testing.TB) ast.Expr { t.Helper() fac := ast.NewExprFactory() From 036015e53069fa955324298e2d4a84beca915668 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 15 Aug 2023 19:12:59 -0400 Subject: [PATCH 10/31] Add NavigableExpr.Depth() and a helper method for navigating from Expr values (#807) --- common/ast/navigable.go | 36 +++++++++++++++++++++++--- common/ast/navigable_test.go | 50 ++++++++++++++++++++++++++++++++++-- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/common/ast/navigable.go b/common/ast/navigable.go index 27b0fcbb6..d4fe2395a 100644 --- a/common/ast/navigable.go +++ b/common/ast/navigable.go @@ -21,11 +21,31 @@ type NavigableExpr interface { // Children returns a list of child expression nodes. Children() []NavigableExpr + + // Depth indicates the depth in the expression tree. + // + // The root expression has depth 0. + Depth() int } // NavigateAST converts an AST to a NavigableExpr func NavigateAST(ast *AST) NavigableExpr { - return newNavigableExpr(ast, nil, ast.Expr()) + return NavigateExpr(ast, ast.Expr()) +} + +// NavigateExpr creates a NavigableExpr whose type information is backed by the input AST. +// +// If the expression is already a NavigableExpr, the parent and depth information will be +// propagated on the new NavigableExpr value; otherwise, the expr value will be treated +// as though it is the root of the expression graph with a depth of 0. 
+func NavigateExpr(ast *AST, expr Expr) NavigableExpr { + depth := 0 + var parent NavigableExpr = nil + if nav, ok := expr.(NavigableExpr); ok { + depth = nav.Depth() + parent, _ = nav.Parent() + } + return newNavigableExpr(ast, parent, expr, depth) } // ExprMatcher takes a NavigableExpr in and indicates whether the value is a match. @@ -112,9 +132,14 @@ func matchIsConstantValue(e NavigableExpr) bool { return false } -func newNavigableExpr(ast *AST, parent NavigableExpr, expr Expr) NavigableExpr { +func newNavigableExpr(ast *AST, parent NavigableExpr, expr Expr, depth int) NavigableExpr { + // Reduce navigable expression nesting by unwrapping the embedded Expr value. + if nav, ok := expr.(*navigableExprImpl); ok { + expr = nav.Expr + } nav := &navigableExprImpl{ Expr: expr, + depth: depth, ast: ast, parent: parent, createChildren: getChildFactory(expr), @@ -124,6 +149,7 @@ func newNavigableExpr(ast *AST, parent NavigableExpr, expr Expr) NavigableExpr { type navigableExprImpl struct { Expr + depth int ast *AST parent NavigableExpr createChildren childFactory @@ -152,6 +178,10 @@ func (nav *navigableExprImpl) Children() []NavigableExpr { return nav.createChildren(nav) } +func (nav *navigableExprImpl) Depth() int { + return nav.depth +} + func (nav *navigableExprImpl) AsCall() CallExpr { return navigableCallImpl{navigableExprImpl: nav} } @@ -185,7 +215,7 @@ func (nav *navigableExprImpl) AsStruct() StructExpr { } func (nav *navigableExprImpl) createChild(e Expr) NavigableExpr { - return newNavigableExpr(nav.ast, nav, e) + return newNavigableExpr(nav.ast, nav, e, nav.depth+1) } func (nav *navigableExprImpl) isExpr() {} diff --git a/common/ast/navigable_test.go b/common/ast/navigable_test.go index 0145aadf8..e3dac7a56 100644 --- a/common/ast/navigable_test.go +++ b/common/ast/navigable_test.go @@ -31,56 +31,66 @@ import ( proto3pb "github.com/google/cel-go/test/proto3pb" ) -func TestNavigateExpr(t *testing.T) { +func TestNavigateAST(t *testing.T) { tests := []struct { 
expr string descendantCount int callCount int + maxDepth int }{ { expr: `'a' == 'b'`, descendantCount: 3, callCount: 1, + maxDepth: 1, }, { expr: `'a'.size()`, descendantCount: 2, callCount: 1, + maxDepth: 1, }, { expr: `[1, 2, 3]`, descendantCount: 4, callCount: 0, + maxDepth: 1, }, { expr: `[1, 2, 3][0]`, descendantCount: 6, callCount: 1, + maxDepth: 2, }, { expr: `{1u: 'hello'}`, descendantCount: 3, callCount: 0, + maxDepth: 1, }, { expr: `{'hello': 'world'}.hello`, descendantCount: 4, callCount: 0, + maxDepth: 2, }, { expr: `type(1) == int`, descendantCount: 4, callCount: 2, + maxDepth: 2, }, { expr: `google.expr.proto3.test.TestAllTypes{single_int32: 1}`, descendantCount: 2, callCount: 0, + maxDepth: 1, }, { expr: `[true].exists(i, i)`, descendantCount: 11, // 2 for iter range, 1 for accu init, 4 for loop condition, 3 for loop step, 1 for result callCount: 3, // @not_strictly_false(!result), accu_init || i + maxDepth: 3, }, } @@ -93,6 +103,15 @@ func TestNavigateExpr(t *testing.T) { if len(descendants) != tc.descendantCount { t.Errorf("ast.MatchDescendants(%v) got %d descendants, wanted %d", checked.Expr(), len(descendants), tc.descendantCount) } + maxDepth := 0 + for _, d := range descendants { + if d.Depth() > maxDepth { + maxDepth = d.Depth() + } + } + if maxDepth != tc.maxDepth { + t.Errorf("got max NavigableExpr.Depth() of %d, wanted %d", maxDepth, tc.maxDepth) + } calls := ast.MatchSubset(descendants, ast.KindMatcher(ast.CallKind)) if len(calls) != tc.callCount { t.Errorf("ast.MatchSubset(%v) got %d calls, wanted %d", checked.Expr(), len(calls), tc.callCount) @@ -101,7 +120,7 @@ func TestNavigateExpr(t *testing.T) { } } -func TestNavigateExprNilSafety(t *testing.T) { +func TestNavigableASTNilSafety(t *testing.T) { tests := []struct { name string e ast.NavigableExpr @@ -143,6 +162,33 @@ func TestNavigateExprNilSafety(t *testing.T) { } } +func TestNavigableExpr(t *testing.T) { + checkedAST := mustTypeCheck(t, `'a' == 'b'`) + navAST := 
ast.NavigateAST(checkedAST) + literals := ast.MatchDescendants(navAST, func(expr ast.NavigableExpr) bool { + return expr.Kind() == ast.LiteralKind && + expr.AsLiteral().Equal(types.String("a")) == types.True + }) + if len(literals) != 1 { + t.Fatalf("MatchDescendents('a') got %d results, wanted 1", len(literals)) + } + litA := literals[0] + if litA.Depth() != 1 { + t.Fatalf("litA.Depth() got %d, wanted 1", litA.Depth()) + } + parent, found := litA.Parent() + if !found { + t.Fatal("litA.Parent() returned nil") + } + if parent.Kind() != ast.CallKind && parent.AsCall().FunctionName() != "==" { + t.Fatalf("litA.Parent() got %v, watned '==' function call", parent) + } + litAPrime := ast.NavigateExpr(checkedAST, litA) + if litAPrime.Depth() != litA.Depth() { + t.Errorf("litAPrime.Depth() != litA.Depth(), got %d, wanted %d", litAPrime.Depth(), litA.Depth()) + } +} + func TestNavigableCallExprMember(t *testing.T) { checked := mustTypeCheck(t, `'hello'.size()`) expr := ast.NavigateAST(checked) From 51cf846fcbb3c59212e00350bede7935d9e43988 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Tue, 15 Aug 2023 19:30:40 -0400 Subject: [PATCH 11/31] Migrate cel.Ast to be a thin layer on ast.AST (#806) * Migrate cel.Ast to be a thin layer on ast.AST * Adjust IsChecked() criteria --- cel/cel_test.go | 2 +- cel/env.go | 82 ++++++++++++----------------------------------- cel/io.go | 18 +++-------- cel/program.go | 14 +------- cel/validator.go | 6 +--- common/ast/ast.go | 5 +-- 6 files changed, 29 insertions(+), 98 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 27c633253..3dd8f75cf 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -2063,7 +2063,7 @@ func TestOptionalValuesCompile(t *testing.T) { if iss.Err() != nil { t.Fatalf("%v failed: %v", tc.expr, iss.Err()) } - for id, reference := range ast.refMap { + for id, reference := range ast.impl.ReferenceMap() { other, found := tc.references[id] if !found { t.Errorf("Compile(%v) expected reference %d: %v", 
tc.expr, id, reference) diff --git a/cel/env.go b/cel/env.go index d786012da..96b61ad70 100644 --- a/cel/env.go +++ b/cel/env.go @@ -38,11 +38,8 @@ type Source = common.Source // Ast representing the checked or unchecked expression, its source, and related metadata such as // source position information. type Ast struct { - source Source - expr celast.Expr - info *celast.SourceInfo - refMap map[int64]*celast.ReferenceInfo - typeMap map[int64]*types.Type + source Source + impl *celast.AST } // Expr returns the proto serializable instance of the parsed/checked expression. @@ -50,7 +47,7 @@ func (ast *Ast) Expr() *exprpb.Expr { if ast == nil { return nil } - pbExpr, _ := celast.ExprToProto(ast.expr) + pbExpr, _ := celast.ExprToProto(ast.impl.Expr()) return pbExpr } @@ -59,7 +56,7 @@ func (ast *Ast) IsChecked() bool { if ast == nil { return false } - return ast.typeMap != nil && len(ast.typeMap) > 0 + return ast.impl.IsChecked() } // SourceInfo returns character offset and newline position information about expression elements. @@ -67,7 +64,7 @@ func (ast *Ast) SourceInfo() *exprpb.SourceInfo { if ast == nil { return nil } - pbInfo, _ := celast.SourceInfoToProto(ast.info) + pbInfo, _ := celast.SourceInfoToProto(ast.impl.SourceInfo()) return pbInfo } @@ -90,11 +87,7 @@ func (ast *Ast) OutputType() *Type { if ast == nil { return types.ErrorType } - t, found := ast.typeMap[ast.expr.ID()] - if !found { - return DynType - } - return t + return ast.impl.GetType(ast.impl.Expr().ID()) } // Source returns a view of the input used to create the Ast. 
This source may be complete or @@ -215,21 +208,18 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { if err != nil { errs := common.NewErrors(ast.Source()) errs.ReportError(common.NoLocation, err.Error()) - return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo()) + return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo()) } - res, errs := checker.Check(celast.NewAST(ast.expr, ast.info), ast.Source(), chk) + checked, errs := checker.Check(ast.impl, ast.Source(), chk) if len(errs.GetErrors()) > 0 { - return nil, NewIssuesWithSourceInfo(errs, ast.SourceInfo()) + return nil, NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo()) } // Manually create the Ast to ensure that the Ast source information (which may be more // detailed than the information provided by Check), is returned to the caller. ast = &Ast{ - source: ast.Source(), - expr: res.Expr(), - info: res.SourceInfo(), - refMap: res.ReferenceMap(), - typeMap: res.TypeMap()} + source: ast.Source(), + impl: checked} // Generate a validator configuration from the set of configured validators. vConfig := newValidatorConfig() @@ -239,9 +229,9 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { } } // Apply additional validators on the type-checked result. - iss := NewIssuesWithSourceInfo(errs, ast.SourceInfo()) + iss := NewIssuesWithSourceInfo(errs, ast.impl.SourceInfo()) for _, v := range e.validators { - v.Validate(e, vConfig, res, iss) + v.Validate(e, vConfig, checked, iss) } if iss.Err() != nil { return nil, iss @@ -426,14 +416,11 @@ func (e *Env) Parse(txt string) (*Ast, *Issues) { // It is possible to have both non-nil Ast and Issues values returned from this call; however, // the mere presence of an Ast does not imply that it is valid for use. 
func (e *Env) ParseSource(src Source) (*Ast, *Issues) { - res, errs := e.prsr.Parse(src) + parsed, errs := e.prsr.Parse(src) if len(errs.GetErrors()) > 0 { return nil, &Issues{errs: errs} } - return &Ast{ - source: src, - expr: res.Expr(), - info: res.SourceInfo()}, nil + return &Ast{source: src, impl: parsed}, nil } // Program generates an evaluable instance of the Ast within the environment (Env). @@ -529,11 +516,8 @@ func (e *Env) PartialVars(vars any) (interpreter.PartialActivation, error) { // TODO: Consider adding an option to generate a Program.Residual to avoid round-tripping to an // Ast format and then Program again. func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { - pruned := interpreter.PruneAst(a.expr, a.info.MacroCalls(), details.State()) - newAST := &Ast{ - expr: pruned.Expr(), - info: pruned.SourceInfo(), - } + pruned := interpreter.PruneAst(a.impl.Expr(), a.impl.SourceInfo().MacroCalls(), details.State()) + newAST := &Ast{source: a.Source(), impl: pruned} expr, err := AstToString(newAST) if err != nil { return nil, err @@ -555,8 +539,7 @@ func (e *Env) ResidualAst(a *Ast, details *EvalDetails) (*Ast, error) { // EstimateCost estimates the cost of a type checked CEL expression using the length estimates of input data and // extension functions provided by estimator. func (e *Env) EstimateCost(ast *Ast, estimator checker.CostEstimator, opts ...checker.CostOption) (checker.CostEstimate, error) { - checked := celast.NewCheckedAST(celast.NewAST(ast.expr, ast.info), ast.typeMap, ast.refMap) - return checker.Cost(checked, estimator, opts...) + return checker.Cost(ast.impl, estimator, opts...) } // configure applies a series of EnvOptions to the current environment. @@ -698,7 +681,7 @@ type Error = common.Error // Note: in the future, non-fatal warnings and notices may be inspectable via the Issues struct. 
type Issues struct { errs *common.Errors - info *exprpb.SourceInfo + info *celast.SourceInfo } // NewIssues returns an Issues struct from a common.Errors object. @@ -709,7 +692,7 @@ func NewIssues(errs *common.Errors) *Issues { // NewIssuesWithSourceInfo returns an Issues struct from a common.Errors object with SourceInfo metatata // which can be used with the `ReportErrorAtID` method for additional error reports within the context // information that's inferred from an expression id. -func NewIssuesWithSourceInfo(errs *common.Errors, info *exprpb.SourceInfo) *Issues { +func NewIssuesWithSourceInfo(errs *common.Errors, info *celast.SourceInfo) *Issues { return &Issues{ errs: errs, info: info, @@ -759,30 +742,7 @@ func (i *Issues) String() string { // The source metadata for the expression at `id`, if present, is attached to the error report. // To ensure that source metadata is attached to error reports, use NewIssuesWithSourceInfo. func (i *Issues) ReportErrorAtID(id int64, message string, args ...any) { - i.errs.ReportErrorAtID(id, locationByID(id, i.info), message, args...) -} - -// locationByID returns a common.Location given an expression id. -// -// TODO: move this functionality into the native SourceInfo and an overhaul of the common.Source -// as this implementation relies on the abstractions present in the protobuf SourceInfo object, -// and is replicated in the checker. -func locationByID(id int64, sourceInfo *exprpb.SourceInfo) common.Location { - positions := sourceInfo.GetPositions() - var line = 1 - if offset, found := positions[id]; found { - col := int(offset) - for _, lineOffset := range sourceInfo.GetLineOffsets() { - if lineOffset < offset { - line++ - col = int(offset - lineOffset) - } else { - break - } - } - return common.NewLocation(line, col) - } - return common.NoLocation + i.errs.ReportErrorAtID(id, i.info.GetStartLocation(id), message, args...) } // getStdEnv lazy initializes the CEL standard environment. 
diff --git a/cel/io.go b/cel/io.go index c8c9c911f..a5471743a 100644 --- a/cel/io.go +++ b/cel/io.go @@ -47,17 +47,11 @@ func CheckedExprToAst(checkedExpr *exprpb.CheckedExpr) *Ast { // // Prefer CheckedExprToAst if loading expressions from storage. func CheckedExprToAstWithSource(checkedExpr *exprpb.CheckedExpr, src Source) (*Ast, error) { - checkedAST, err := ast.ToAST(checkedExpr) + checked, err := ast.ToAST(checkedExpr) if err != nil { return nil, err } - return &Ast{ - source: src, - expr: checkedAST.Expr(), - info: checkedAST.SourceInfo(), - refMap: checkedAST.ReferenceMap(), - typeMap: checkedAST.TypeMap(), - }, nil + return &Ast{source: src, impl: checked}, nil } // AstToCheckedExpr converts an Ast to an protobuf CheckedExpr value. @@ -92,11 +86,7 @@ func ParsedExprToAstWithSource(parsedExpr *exprpb.ParsedExpr, src Source) *Ast { src = common.NewInfoSource(parsedExpr.GetSourceInfo()) } e, _ := ast.ProtoToExpr(parsedExpr.GetExpr()) - return &Ast{ - expr: e, - info: info, - source: src, - } + return &Ast{source: src, impl: ast.NewAST(e, info)} } // AstToParsedExpr converts an Ast to an protobuf ParsedExpr value. @@ -112,7 +102,7 @@ func AstToParsedExpr(a *Ast) (*exprpb.ParsedExpr, error) { // Note, the conversion may not be an exact replica of the original expression, but will produce // a string that is semantically equivalent and whose textual representation is stable. func AstToString(a *Ast) (string, error) { - return parser.Unparse(a.expr, a.info) + return parser.Unparse(a.impl.Expr(), a.impl.SourceInfo()) } // RefValueToValue converts between ref.Val and api.expr.Value. 
diff --git a/cel/program.go b/cel/program.go index 2f829a2a3..1762c324a 100644 --- a/cel/program.go +++ b/cel/program.go @@ -518,19 +518,7 @@ func (p *evalActivationPool) Put(value any) { } func astToExprAST(a *Ast) (*ast.AST, error) { - expr, err := ast.ProtoToExpr(a.Expr()) - if err != nil { - return nil, err - } - info, err := ast.ProtoToSourceInfo(a.SourceInfo()) - if err != nil { - return nil, err - } - baseAST := ast.NewAST(expr, info) - if !a.IsChecked() { - return baseAST, nil - } - return ast.NewCheckedAST(baseAST, a.typeMap, a.refMap), nil + return a.impl, nil } var ( diff --git a/cel/validator.go b/cel/validator.go index 59df3f4c3..a8a86a746 100644 --- a/cel/validator.go +++ b/cel/validator.go @@ -220,11 +220,7 @@ func (v formatValidator) Validate(e *Env, _ ValidatorConfig, a *ast.AST, iss *Is } func evalCall(env *Env, call, arg ast.Expr) error { - ast := &Ast{ - expr: call, - info: ast.NewSourceInfo(nil), - source: nil, - } + ast := &Ast{impl: ast.NewAST(call, ast.NewSourceInfo(nil))} prg, err := env.Program(ast) if err != nil { return err diff --git a/common/ast/ast.go b/common/ast/ast.go index 8612e99bb..ec770d8dc 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -27,7 +27,6 @@ type AST struct { sourceInfo *SourceInfo typeMap map[int64]*types.Type refMap map[int64]*ReferenceInfo - checked bool } // Expr returns the root ast.Expr value in the AST. @@ -100,7 +99,7 @@ func (a *AST) SetReference(id int64, r *ReferenceInfo) { // IsChecked returns whether the AST is type-checked. func (a *AST) IsChecked() bool { - return a != nil && a.checked + return a != nil && len(a.TypeMap()) > 0 } // NewAST creates a base AST instance with an ast.Expr and ast.SourceInfo value. 
@@ -113,7 +112,6 @@ func NewAST(e Expr, sourceInfo *SourceInfo) *AST { sourceInfo: sourceInfo, typeMap: make(map[int64]*types.Type), refMap: make(map[int64]*ReferenceInfo), - checked: false, } } @@ -124,7 +122,6 @@ func NewCheckedAST(parsed *AST, typeMap map[int64]*types.Type, refMap map[int64] sourceInfo: parsed.SourceInfo(), typeMap: typeMap, refMap: refMap, - checked: len(typeMap) != 0 || len(refMap) != 0, } } From fc3b794b884fc4da23db1fd44f01a9f7c6cd659d Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 16 Aug 2023 13:38:57 -0400 Subject: [PATCH 12/31] Expose Copy() methods for AST, SourceInfo, and Expr values (#808) --- common/ast/ast.go | 46 ++++++++++++++++++++ common/ast/ast_test.go | 25 +++++++++++ common/ast/expr_test.go | 9 ++++ common/ast/factory.go | 83 ++++++++++++++++++++++++++++++++++++ common/ast/navigable_test.go | 5 ++- 5 files changed, 167 insertions(+), 1 deletion(-) diff --git a/common/ast/ast.go b/common/ast/ast.go index ec770d8dc..0a0cbdc77 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -125,6 +125,28 @@ func NewCheckedAST(parsed *AST, typeMap map[int64]*types.Type, refMap map[int64] } } +// Copy creates a deep copy of the Expr and SourceInfo values in the input AST. +// +// Copies of the Expr value are generated using an internal default ExprFactory. +func Copy(a *AST) *AST { + if a == nil { + return nil + } + e := defaultFactory.CopyExpr(a.expr) + if !a.IsChecked() { + return NewAST(e, CopySourceInfo(a.SourceInfo())) + } + typesCopy := make(map[int64]*types.Type, len(a.typeMap)) + for id, t := range a.typeMap { + typesCopy[id] = t + } + refsCopy := make(map[int64]*ReferenceInfo, len(a.refMap)) + for id, r := range a.refMap { + refsCopy[id] = r + } + return NewCheckedAST(NewAST(e, CopySourceInfo(a.SourceInfo())), typesCopy, refsCopy) +} + // NewSourceInfo creates a simple SourceInfo object from an input common.Source value. 
func NewSourceInfo(src common.Source) *SourceInfo { var lineOffsets []int32 @@ -141,6 +163,30 @@ func NewSourceInfo(src common.Source) *SourceInfo { } } +// CopySourceInfo creates a deep copy of the MacroCalls within the input SourceInfo. +// +// Copies of macro Expr values are generated using an internal default ExprFactory. +func CopySourceInfo(info *SourceInfo) *SourceInfo { + if info == nil { + return nil + } + rangesCopy := make(map[int64]OffsetRange, len(info.offsetRanges)) + for id, off := range info.offsetRanges { + rangesCopy[id] = off + } + callsCopy := make(map[int64]Expr, len(info.macroCalls)) + for id, call := range info.macroCalls { + callsCopy[id] = defaultFactory.CopyExpr(call) + } + return &SourceInfo{ + syntax: info.syntax, + desc: info.desc, + lines: info.lines, + offsetRanges: rangesCopy, + macroCalls: callsCopy, + } +} + // SourceInfo records basic information about the expression as a textual input and // as a parsed expression value. type SourceInfo struct { diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go index eef521827..d152b2a2d 100644 --- a/common/ast/ast_test.go +++ b/common/ast/ast_test.go @@ -25,6 +25,31 @@ import ( "github.com/google/cel-go/common/types" ) +func TestASTCopy(t *testing.T) { + tests := []string{ + `'a' == 'b'`, + `'a'.size()`, + `size('a')`, + `has({'a': 1}.a)`, + `{'a': 1}`, + `{'a': 1}['a']`, + `[1, 2, 3].exists(i, i % 2 == 1)`, + `google.expr.proto3.test.TestAllTypes{}`, + `google.expr.proto3.test.TestAllTypes{repeated_int32: [1, 2]}`, + } + + for _, tst := range tests { + checked := mustTypeCheck(t, tst) + copied := ast.Copy(checked) + if !reflect.DeepEqual(copied.Expr(), checked.Expr()) { + t.Errorf("Copy() got expr %v, wanted %v", copied.Expr(), checked.Expr()) + } + if !reflect.DeepEqual(copied.SourceInfo(), checked.SourceInfo()) { + t.Errorf("Copy() got source info %v, wanted %v", copied.SourceInfo(), checked.SourceInfo()) + } + } +} + func TestASTNilSafety(t *testing.T) { ex, err := 
ast.ProtoToExpr(nil) if err != nil { diff --git a/common/ast/expr_test.go b/common/ast/expr_test.go index b6a5a41df..d32bf3fee 100644 --- a/common/ast/expr_test.go +++ b/common/ast/expr_test.go @@ -88,8 +88,17 @@ func TestSetKindCase(t *testing.T) { t.Errorf("got %v, wanted %v", expr, tst) } case ast.MapKind: + if !reflect.DeepEqual(expr.AsMap(), tst.AsMap()) { + t.Errorf("got %v, wanted %v", expr, tst) + } case ast.SelectKind: + if !reflect.DeepEqual(expr.AsSelect(), tst.AsSelect()) { + t.Errorf("got %v, wanted %v", expr, tst) + } case ast.StructKind: + if !reflect.DeepEqual(expr.AsStruct(), tst.AsStruct()) { + t.Errorf("got %v, wanted %v", expr, tst) + } default: t.Errorf("unable to determine kind case: %v", tst) } diff --git a/common/ast/factory.go b/common/ast/factory.go index 9e7c57869..8f4a31f34 100644 --- a/common/ast/factory.go +++ b/common/ast/factory.go @@ -4,6 +4,12 @@ import "github.com/google/cel-go/common/types/ref" // ExprFactory interfaces defines a set of methods necessary for building native expression values. type ExprFactory interface { + // CopyExpr creates a deep copy of the input Expr value. + CopyExpr(Expr) Expr + + // CopyEntryExpr creates a deep copy of the input EntryExpr value. + CopyEntryExpr(EntryExpr) EntryExpr + // NewCall creates an Expr value representing a global function call. NewCall(id int64, function string, args ...Expr) Expr @@ -176,6 +182,79 @@ func (fac *baseExprFactory) NewUnspecifiedExpr(id int64) Expr { return fac.newExpr(id, nil) } +func (fac *baseExprFactory) CopyExpr(e Expr) Expr { + switch e.Kind() { + case CallKind: + c := e.AsCall() + argsCopy := make([]Expr, len(c.Args())) + for i, arg := range c.Args() { + argsCopy[i] = fac.CopyExpr(arg) + } + if !c.IsMemberFunction() { + return fac.NewCall(e.ID(), c.FunctionName(), argsCopy...) + } + return fac.NewMemberCall(e.ID(), c.FunctionName(), fac.CopyExpr(c.Target()), argsCopy...) 
+ case ComprehensionKind: + compre := e.AsComprehension() + return fac.NewComprehension(e.ID(), + fac.CopyExpr(compre.IterRange()), + compre.IterVar(), + compre.AccuVar(), + fac.CopyExpr(compre.AccuInit()), + fac.CopyExpr(compre.LoopCondition()), + fac.CopyExpr(compre.LoopStep()), + fac.CopyExpr(compre.Result())) + case IdentKind: + return fac.NewIdent(e.ID(), e.AsIdent()) + case ListKind: + l := e.AsList() + elemsCopy := make([]Expr, l.Size()) + for i, elem := range l.Elements() { + elemsCopy[i] = fac.CopyExpr(elem) + } + return fac.NewList(e.ID(), elemsCopy, l.OptionalIndices()) + case LiteralKind: + return fac.NewLiteral(e.ID(), e.AsLiteral()) + case MapKind: + m := e.AsMap() + entriesCopy := make([]EntryExpr, m.Size()) + for i, entry := range m.Entries() { + entriesCopy[i] = fac.CopyEntryExpr(entry) + } + return fac.NewMap(e.ID(), entriesCopy) + case SelectKind: + s := e.AsSelect() + if s.IsTestOnly() { + return fac.NewPresenceTest(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName()) + } + return fac.NewSelect(e.ID(), fac.CopyExpr(s.Operand()), s.FieldName()) + case StructKind: + s := e.AsStruct() + fieldsCopy := make([]EntryExpr, len(s.Fields())) + for i, field := range s.Fields() { + fieldsCopy[i] = fac.CopyEntryExpr(field) + } + return fac.NewStruct(e.ID(), s.TypeName(), fieldsCopy) + default: + return fac.NewUnspecifiedExpr(e.ID()) + } +} + +func (fac *baseExprFactory) CopyEntryExpr(e EntryExpr) EntryExpr { + switch e.Kind() { + case MapEntryKind: + entry := e.AsMapEntry() + return fac.NewMapEntry(e.ID(), + fac.CopyExpr(entry.Key()), fac.CopyExpr(entry.Value()), entry.IsOptional()) + case StructFieldKind: + field := e.AsStructField() + return fac.NewStructField(e.ID(), + field.Name(), fac.CopyExpr(field.Value()), field.IsOptional()) + default: + return fac.newEntryExpr(e.ID(), nil) + } +} + func (*baseExprFactory) isExprFactory() {} func (fac *baseExprFactory) newExpr(id int64, e exprKindCase) Expr { @@ -191,3 +270,7 @@ func (fac *baseExprFactory) 
newEntryExpr(id int64, e entryExprKindCase) EntryExp entryExprKindCase: e, } } + +var ( + defaultFactory = &baseExprFactory{} +) diff --git a/common/ast/navigable_test.go b/common/ast/navigable_test.go index e3dac7a56..e479a8ede 100644 --- a/common/ast/navigable_test.go +++ b/common/ast/navigable_test.go @@ -446,7 +446,10 @@ func TestNavigableSelectExpr_TestOnly(t *testing.T) { func mustTypeCheck(t testing.TB, expr string) *ast.AST { t.Helper() - p, err := parser.NewParser(parser.Macros(parser.AllMacros...), parser.EnableOptionalSyntax(true)) + p, err := parser.NewParser( + parser.Macros(parser.AllMacros...), + parser.EnableOptionalSyntax(true), + parser.PopulateMacroCalls(true)) if err != nil { t.Fatalf("parser.NewParser() failed: %v", err) } From 6643a4a1f12ffb93b41f84ac05e04d91f424a74c Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 16 Aug 2023 13:39:12 -0400 Subject: [PATCH 13/31] Catch invalid literals created from expression factories (#810) --- checker/checker.go | 5 ++++- checker/checker_test.go | 19 +++++++++++++++++++ checker/errors.go | 6 ++---- 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/checker/checker.go b/checker/checker.go index c8f99ff5b..57fb3ce5e 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -18,6 +18,7 @@ package checker import ( "fmt" + "reflect" "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" @@ -78,6 +79,8 @@ func (c *checker) check(e ast.Expr) { case types.BoolType, types.BytesType, types.DoubleType, types.IntType, types.NullType, types.StringType, types.UintType: c.setType(e, literal.Type().(*types.Type)) + default: + c.errors.unexpectedASTType(e.ID(), c.location(e), "literal", literal.Type().TypeName()) } case ast.IdentKind: c.checkIdent(e) @@ -94,7 +97,7 @@ func (c *checker) check(e ast.Expr) { case ast.ComprehensionKind: c.checkComprehension(e) default: - c.errors.unexpectedASTType(e.ID(), c.location(e), e) + c.errors.unexpectedASTType(e.ID(), c.location(e), 
"unspecified", reflect.TypeOf(e).Name()) } } diff --git a/checker/checker_test.go b/checker/checker_test.go index 16d0f6eb5..da69bd891 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -18,8 +18,10 @@ import ( "fmt" "strings" "testing" + "time" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/containers" "github.com/google/cel-go/common/decls" "github.com/google/cel-go/common/stdlib" @@ -2542,6 +2544,23 @@ func TestCheckErrorData(t *testing.T) { } } +func TestCheckInvalidLiteral(t *testing.T) { + fac := ast.NewExprFactory() + durLiteral := fac.NewLiteral(1, types.Duration{Duration: time.Second}) + // This is not valid syntax, just for illustration purposes. + src := common.NewTextSource(`1s`) + parsed := ast.NewAST(durLiteral, ast.NewSourceInfo(src)) + reg := newTestRegistry(t) + env, err := NewEnv(containers.DefaultContainer, reg) + if err != nil { + t.Fatalf("NewEnv(cont, reg) failed: %v", err) + } + _, iss := Check(parsed, src, env) + if !strings.Contains(iss.ToDisplayString(), "unexpected literal type") { + t.Errorf("got %s, wanted 'unexpected literal type'", iss.ToDisplayString()) + } +} + func testFunction(t testing.TB, name string, opts ...decls.FunctionOpt) *decls.FunctionDecl { t.Helper() fn, err := decls.NewFunction(name, opts...) 
diff --git a/checker/errors.go b/checker/errors.go index c7032590d..8b3bf0b8b 100644 --- a/checker/errors.go +++ b/checker/errors.go @@ -15,8 +15,6 @@ package checker import ( - "reflect" - "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" @@ -85,6 +83,6 @@ func (e *typeErrors) unexpectedFailedResolution(id int64, l common.Location, typ e.errs.ReportErrorAtID(id, l, "unexpected failed resolution of '%s'", typeName) } -func (e *typeErrors) unexpectedASTType(id int64, l common.Location, ex ast.Expr) { - e.errs.ReportErrorAtID(id, l, "unrecognized ast type: %v", reflect.TypeOf(ex)) +func (e *typeErrors) unexpectedASTType(id int64, l common.Location, kind, typeName string) { + e.errs.ReportErrorAtID(id, l, "unexpected %s type: %v", kind, typeName) } From 0bd4d3922263e5e6919ca496e509967ddb35fdf0 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 16 Aug 2023 13:39:45 -0400 Subject: [PATCH 14/31] Minor test, code, and error hygiene changes (#809) --- cel/BUILD.bazel | 1 + cel/io.go | 6 +----- cel/io_test.go | 2 +- cel/program.go | 11 +---------- common/errors.go | 2 +- 5 files changed, 5 insertions(+), 17 deletions(-) diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index 0905f6353..aa978e064 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -57,6 +57,7 @@ go_test( "decls_test.go", "env_test.go", "io_test.go", + "validator_test.go", ], data = [ "//cel/testdata:gen_test_fds", diff --git a/cel/io.go b/cel/io.go index a5471743a..3133fb9d7 100644 --- a/cel/io.go +++ b/cel/io.go @@ -61,11 +61,7 @@ func AstToCheckedExpr(a *Ast) (*exprpb.CheckedExpr, error) { if !a.IsChecked() { return nil, fmt.Errorf("cannot convert unchecked ast") } - checked, err := astToExprAST(a) - if err != nil { - return nil, err - } - return ast.ToProto(checked) + return ast.ToProto(a.impl) } // ParsedExprToAst converts a parsed expression proto message to an Ast. 
diff --git a/cel/io_test.go b/cel/io_test.go index a19f9cc14..d1710f4dc 100644 --- a/cel/io_test.go +++ b/cel/io_test.go @@ -107,7 +107,7 @@ func TestAstToProto(t *testing.T) { } checked, err := AstToCheckedExpr(ast) if err != nil { - t.Fatalf("AstToCheckeExpr(ast) failed: %v", err) + t.Fatalf("AstToCheckedExpr(ast) failed: %v", err) } ast4 := CheckedExprToAst(checked) if !proto.Equal(ast4.Expr(), ast.Expr()) { diff --git a/cel/program.go b/cel/program.go index 1762c324a..cec4839df 100644 --- a/cel/program.go +++ b/cel/program.go @@ -19,7 +19,6 @@ import ( "fmt" "sync" - "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/interpreter" @@ -244,11 +243,7 @@ func newProgram(e *Env, a *Ast, opts []ProgramOption) (Program, error) { func (p *prog) initInterpretable(a *Ast, decs []interpreter.InterpretableDecorator) (*prog, error) { // When the AST has been exprAST it contains metadata that can be used to speed up program execution. - exprAST, err := astToExprAST(a) - if err != nil { - return nil, err - } - interpretable, err := p.interpreter.NewInterpretable(exprAST, decs...) + interpretable, err := p.interpreter.NewInterpretable(a.impl, decs...) if err != nil { return nil, err } @@ -517,10 +512,6 @@ func (p *evalActivationPool) Put(value any) { p.Pool.Put(a) } -func astToExprAST(a *Ast) (*ast.AST, error) { - return a.impl, nil -} - var ( // activationPool is an internally managed pool of Activation values that wrap map[string]any inputs activationPool = newEvalActivationPool() diff --git a/common/errors.go b/common/errors.go index 63919714e..25adc73d8 100644 --- a/common/errors.go +++ b/common/errors.go @@ -64,7 +64,7 @@ func (e *Errors) GetErrors() []*Error { // Append creates a new Errors object with the current and input errors. 
func (e *Errors) Append(errs []*Error) *Errors { return &Errors{ - errors: append(e.errors, errs...), + errors: append(e.errors[:], errs...), source: e.source, numErrors: e.numErrors + len(errs), maxErrorsToReport: e.maxErrorsToReport, From 5be9464ec501625a563055e1b471d626d57a3428 Mon Sep 17 00:00:00 2001 From: bboogler <141976241+bboogler@users.noreply.github.com> Date: Thu, 17 Aug 2023 20:29:09 -0700 Subject: [PATCH 15/31] Creating a Function that Reverses a String (#796) Creating a reverse function for CEL that takes a string and returns a new string whose characters are in the reverse order of the original. --- ext/README.md | 14 ++++++++++++++ ext/strings.go | 31 +++++++++++++++++++++++++++++++ ext/strings_test.go | 7 +++++++ 3 files changed, 52 insertions(+) diff --git a/ext/README.md b/ext/README.md index 6f621ac4a..2fac0cb22 100644 --- a/ext/README.md +++ b/ext/README.md @@ -414,3 +414,17 @@ Examples: 'TacoCat'.upperAscii() // returns 'TACOCAT' 'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII' + +### Reverse + +Returns a new string whose characters are the same as the target string, only formatted in +reverse order. +This function relies on converting strings to rune arrays in order to reverse. +It can be located in Version 3 of strings. + + <string>.reverse() -> <string> + +Examples: + + 'gums'.reverse() // returns 'smug' + 'John Smith'.reverse() // returns 'htimS nhoJ' \ No newline at end of file diff --git a/ext/strings.go b/ext/strings.go index 935c9b6a1..13e919545 100644 --- a/ext/strings.go +++ b/ext/strings.go @@ -267,6 +267,19 @@ const ( // // 'TacoCat'.upperAscii() // returns 'TACOCAT' // 'TacoCÆt Xii'.upperAscii() // returns 'TACOCÆT XII' +// +// # Reverse +// +// Returns a new string whose characters are the same as the target string, only formatted in +// reverse order.
+// This function relies on converting strings to rune arrays in order to reverse. +// +// <string>.reverse() -> <string> +// +// Examples: +// +// 'gums'.reverse() // returns 'smug' +// 'John Smith'.reverse() // returns 'htimS nhoJ' func Strings(options ...StringsOption) cel.EnvOption { s := &stringLib{ version: math.MaxUint32, @@ -500,6 +513,16 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption { }))), ) } + if lib.version >= 3 { + opts = append( opts, + cel.Function("reverse", + cel.MemberOverload("reverse", []*cel.Type{cel.StringType}, cel.StringType, + cel.UnaryBinding(func(str ref.Val) ref.Val { + s := str.(types.String) + return stringOrError(reverse(string(s))) + }))), + ) + } if lib.validateFormat { opts = append(opts, cel.ASTValidators(stringFormatValidator{})) } @@ -653,6 +676,14 @@ func upperASCII(str string) (string, error) { return string(runes), nil } +func reverse(str string) (string, error) { + chars := []rune(str) + for i, j := 0, len(chars)-1; i < j; i, j = i+1, j-1 { + chars[i], chars[j] = chars[j], chars[i] + } + return string(chars), nil +} + func joinSeparator(strs []string, separator string) (string, error) { return strings.Join(strs, separator), nil } diff --git a/ext/strings_test.go b/ext/strings_test.go index 002226eb2..6c1583ad3 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -105,6 +105,13 @@ var stringTests = []struct { // Upper ASCII tests. {expr: `'tacoCat'.upperAscii() == 'TACOCAT'`}, {expr: `'tacoCαt'.upperAscii() == 'TACOCαT'`}, + // Reverse tests. + {expr: `'gums'.reverse() == 'smug'`}, + {expr: `'palindromes'.reverse() == 'semordnilap'`}, + {expr: `'John Smith'.reverse() == 'htimS nhoJ'`}, + {expr: `'u180etext'.reverse() == 'txete081u'`}, + {expr: `'2600+U'.reverse() == 'U+0062'`}, + {expr: `'\u180e\u200b\u200c\u200d\u2060\ufeff'.reverse() == '\ufeff\u2060\u200d\u200c\u200b\u180e'`}, // Join tests.
{expr: `['x', 'y'].join() == 'xy'`}, {expr: `['x', 'y'].join('-') == 'x-y'`}, From 0b8fcf3578818774b1b1a945797ba357173708ca Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 18 Aug 2023 12:05:01 -0400 Subject: [PATCH 16/31] Migrate the parser to the Go-native Expr (#797) * Migrate the type-checker to a native AST representation * Parser expr update * Unify cel.MacroExprFactory with ast.ExprFactory --- cel/cel_test.go | 137 ++++++++++ cel/library.go | 61 ++--- cel/macro.go | 456 +++++++++++++++++++++++++++++- common/ast/conversion.go | 12 + ext/BUILD.bazel | 2 - ext/bindings.go | 24 +- ext/guards.go | 11 +- ext/math.go | 67 +++-- ext/protos.go | 45 ++- parser/BUILD.bazel | 2 + parser/helper.go | 578 +++++++++++++++------------------------ parser/helper_test.go | 6 +- parser/macro.go | 194 ++++++------- parser/parser.go | 138 +++++----- parser/parser_test.go | 4 +- 15 files changed, 1076 insertions(+), 661 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index 3dd8f75cf..b7dd7996f 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -37,6 +37,7 @@ import ( "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/interpreter" + "github.com/google/cel-go/parser" "github.com/google/cel-go/test" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -699,6 +700,142 @@ func TestCustomMacro(t *testing.T) { } } +func TestMacroInterop(t *testing.T) { + existsOneMacro := NewReceiverMacro("exists_one", 2, + func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { + return ExistsOneMacroExpander(meh, iterRange, args) + }) + transformMacro := NewReceiverMacro("transform", 2, + func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { + return MapMacroExpander(meh, iterRange, args) + }) + filterMacro := NewReceiverMacro("filter", 2, + func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) 
(*exprpb.Expr, *Error) { + return FilterMacroExpander(meh, iterRange, args) + }) + pairMacro := NewGlobalMacro("pair", 2, + func(meh MacroExprHelper, iterRange *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { + return meh.NewMap(meh.NewMapEntry(args[0], args[1], false)), nil + }) + getMacro := NewReceiverMacro("get", 2, + func(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { + return meh.GlobalCall( + operators.Conditional, + meh.PresenceTest(meh.Copy(target), args[0].GetIdentExpr().GetName()), + meh.Select(meh.Copy(target), args[0].GetIdentExpr().GetName()), + meh.Copy(args[1]), + ), nil + }) + env := testEnv(t, Macros(existsOneMacro, transformMacro, filterMacro, pairMacro, getMacro)) + tests := []struct { + expr string + out ref.Val + }{ + { + expr: `['tr', 's', 'fri'].filter(i, i.size() > 1).transform(i, i + 'end').exists_one(i, i == 'friend')`, + out: types.True, + }, + { + expr: `pair('a', 'b')`, + out: types.DefaultTypeAdapter.NativeToValue(map[string]string{"a": "b"}), + }, + { + expr: `{}.get(a, 'default')`, + out: types.String("default"), + }, + { + expr: `{'a': 'b'}.get(a, 'default')`, + out: types.String("b"), + }, + } + + for _, tst := range tests { + ast, iss := env.Compile(tst.expr) + if iss.Err() != nil { + t.Fatal(iss.Err()) + } + prg, err := env.Program(ast, EvalOptions(OptExhaustiveEval)) + if err != nil { + t.Fatalf("program creation error: %s\n", err) + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + t.Fatal(err) + } + if out.Equal(tst.out) != types.True { + t.Errorf("got %v, wanted %v", out, tst.out) + } + } +} + +func TestMacroModern(t *testing.T) { + existsOneMacro := ReceiverMacro("exists_one", 2, + func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) { + return parser.MakeExistsOne(mef, iterRange, args) + }) + transformMacro := ReceiverMacro("transform", 2, + func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, 
*Error) { + return parser.MakeMap(mef, iterRange, args) + }) + filterMacro := ReceiverMacro("filter", 2, + func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) { + return parser.MakeFilter(mef, iterRange, args) + }) + pairMacro := GlobalMacro("pair", 2, + func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) { + return mef.NewMap(mef.NewMapEntry(args[0], args[1], false)), nil + }) + getMacro := ReceiverMacro("get", 2, + func(mef MacroExprFactory, target celast.Expr, args []celast.Expr) (celast.Expr, *Error) { + return mef.NewCall( + operators.Conditional, + mef.NewPresenceTest(mef.Copy(target), args[0].AsIdent()), + mef.NewSelect(mef.Copy(target), args[0].AsIdent()), + mef.Copy(args[1]), + ), nil + }) + env := testEnv(t, Macros(existsOneMacro, transformMacro, filterMacro, pairMacro, getMacro)) + tests := []struct { + expr string + out ref.Val + }{ + { + expr: `['tr', 's', 'fri'].filter(i, i.size() > 1).transform(i, i + 'end').exists_one(i, i == 'friend')`, + out: types.True, + }, + { + expr: `pair('a', 'b')`, + out: types.DefaultTypeAdapter.NativeToValue(map[string]string{"a": "b"}), + }, + { + expr: `{}.get(a, 'default')`, + out: types.String("default"), + }, + { + expr: `{'a': 'b'}.get(a, 'default')`, + out: types.String("b"), + }, + } + + for _, tst := range tests { + ast, iss := env.Compile(tst.expr) + if iss.Err() != nil { + t.Fatal(iss.Err()) + } + prg, err := env.Program(ast, EvalOptions(OptExhaustiveEval)) + if err != nil { + t.Fatalf("program creation error: %s\n", err) + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + t.Fatal(err) + } + if out.Equal(tst.out) != types.True { + t.Errorf("got %v, wanted %v", out, tst.out) + } + } +} + func TestCustomExistsMacro(t *testing.T) { env := testEnv(t, Variable("attr", MapType(StringType, BoolType)), diff --git a/cel/library.go b/cel/library.go index 4d232085c..deddc14e5 100644 --- a/cel/library.go +++ b/cel/library.go @@ -20,6 
+20,7 @@ import ( "strings" "time" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/stdlib" @@ -28,8 +29,6 @@ import ( "github.com/google/cel-go/common/types/traits" "github.com/google/cel-go/interpreter" "github.com/google/cel-go/parser" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) const ( @@ -313,7 +312,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption { Types(types.OptionalType), // Configure the optMap and optFlatMap macros. - Macros(NewReceiverMacro(optMapMacro, 2, optMap)), + Macros(ReceiverMacro(optMapMacro, 2, optMap)), // Global and member functions for working with optional values. Function(optionalOfFunc, @@ -374,7 +373,7 @@ func (lib *optionalLib) CompileOptions() []EnvOption { Overload("optional_map_index_value", []*Type{OptionalType(mapTypeKV), paramTypeK}, optionalTypeV)), } if lib.version >= 1 { - opts = append(opts, Macros(NewReceiverMacro(optFlatMapMacro, 2, optFlatMap))) + opts = append(opts, Macros(ReceiverMacro(optFlatMapMacro, 2, optFlatMap))) } return opts } @@ -386,57 +385,57 @@ func (lib *optionalLib) ProgramOptions() []ProgramOption { } } -func optMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { +func optMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) { varIdent := args[0] varName := "" - switch varIdent.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr: - varName = varIdent.GetIdentExpr().GetName() + switch varIdent.Kind() { + case ast.IdentKind: + varName = varIdent.AsIdent() default: - return nil, meh.NewError(varIdent.GetId(), "optMap() variable name must be a simple identifier") + return nil, meh.NewError(varIdent.ID(), "optMap() variable name must be a simple identifier") } mapExpr := args[1] - return meh.GlobalCall( + return meh.NewCall( operators.Conditional, - meh.ReceiverCall(hasValueFunc, target), - 
meh.GlobalCall(optionalOfFunc, - meh.Fold( - unusedIterVar, + meh.NewMemberCall(hasValueFunc, target), + meh.NewCall(optionalOfFunc, + meh.NewComprehension( meh.NewList(), + unusedIterVar, varName, - meh.ReceiverCall(valueFunc, target), - meh.LiteralBool(false), - meh.Ident(varName), + meh.NewMemberCall(valueFunc, target), + meh.NewLiteral(types.False), + meh.NewIdent(varName), mapExpr, ), ), - meh.GlobalCall(optionalNoneFunc), + meh.NewCall(optionalNoneFunc), ), nil } -func optFlatMap(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { +func optFlatMap(meh MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *Error) { varIdent := args[0] varName := "" - switch varIdent.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr: - varName = varIdent.GetIdentExpr().GetName() + switch varIdent.Kind() { + case ast.IdentKind: + varName = varIdent.AsIdent() default: - return nil, meh.NewError(varIdent.GetId(), "optFlatMap() variable name must be a simple identifier") + return nil, meh.NewError(varIdent.ID(), "optFlatMap() variable name must be a simple identifier") } mapExpr := args[1] - return meh.GlobalCall( + return meh.NewCall( operators.Conditional, - meh.ReceiverCall(hasValueFunc, target), - meh.Fold( - unusedIterVar, + meh.NewMemberCall(hasValueFunc, target), + meh.NewComprehension( meh.NewList(), + unusedIterVar, varName, - meh.ReceiverCall(valueFunc, target), - meh.LiteralBool(false), - meh.Ident(varName), + meh.NewMemberCall(valueFunc, target), + meh.NewLiteral(types.False), + meh.NewIdent(varName), mapExpr, ), - meh.GlobalCall(optionalNoneFunc), + meh.NewCall(optionalNoneFunc), ), nil } diff --git a/cel/macro.go b/cel/macro.go index 1eb414c8b..4db1fd57a 100644 --- a/cel/macro.go +++ b/cel/macro.go @@ -15,6 +15,11 @@ package cel import ( + "fmt" + + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" "github.com/google/cel-go/parser" exprpb 
"google.golang.org/genproto/googleapis/api/expr/v1alpha1" @@ -26,7 +31,14 @@ import ( // a Macro should be created per arg-count or as a var arg macro. type Macro = parser.Macro -// MacroExpander converts a call and its associated arguments into a new CEL abstract syntax tree. +// MacroFactory defines an expansion function which converts a call and its arguments to a cel.Expr value. +type MacroFactory = parser.MacroExpander + +// MacroExprFactory assists with the creation of Expr values in a manner which is consistent +// the internal semantics and id generation behaviors of the parser and checker libraries. +type MacroExprFactory = parser.ExprHelper + +// MacroExpander converts a call and its associated arguments into a protobuf Expr representation. // // If the MacroExpander determines within the implementation that an expansion is not needed it may return // a nil Expr value to indicate a non-match. However, if an expansion is to be performed, but the arguments @@ -36,48 +48,197 @@ type Macro = parser.Macro // and produces as output an Expr ast node. // // Note: when the Macro.IsReceiverStyle() method returns true, the target argument will be nil. -type MacroExpander = parser.MacroExpander +type MacroExpander func(eh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) // MacroExprHelper exposes helper methods for creating new expressions within a CEL abstract syntax tree. -type MacroExprHelper = parser.ExprHelper +// ExprHelper assists with the manipulation of proto-based Expr values in a manner which is +// consistent with the source position and expression id generation code leveraged by both +// the parser and type-checker. +type MacroExprHelper interface { + // Copy the input expression with a brand new set of identifiers. + Copy(*exprpb.Expr) *exprpb.Expr + + // LiteralBool creates an Expr value for a bool literal. + LiteralBool(value bool) *exprpb.Expr + + // LiteralBytes creates an Expr value for a byte literal. 
+ LiteralBytes(value []byte) *exprpb.Expr + + // LiteralDouble creates an Expr value for a double literal. + LiteralDouble(value float64) *exprpb.Expr + + // LiteralInt creates an Expr value for an int literal. + LiteralInt(value int64) *exprpb.Expr + + // LiteralString creates an Expr value for a string literal. + LiteralString(value string) *exprpb.Expr + + // LiteralUint creates an Expr value for a uint literal. + LiteralUint(value uint64) *exprpb.Expr + + // NewList creates a CreateList instruction where the list is comprised of the optional set + // of elements provided as arguments. + NewList(elems ...*exprpb.Expr) *exprpb.Expr + + // NewMap creates a CreateStruct instruction for a map where the map is comprised of the + // optional set of key, value entries. + NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + + // NewMapEntry creates a Map Entry for the key, value pair. + NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + + // NewObject creates a CreateStruct instruction for an object with a given type name and + // optional set of field initializers. + NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + + // NewObjectFieldInit creates a new Object field initializer from the field name and value. + NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + + // Fold creates a fold comprehension instruction. + // + // - iterVar is the iteration variable name. + // - iterRange represents the expression that resolves to a list or map where the elements or + // keys (respectively) will be iterated over. + // - accuVar is the accumulation variable name, typically parser.AccumulatorName. + // - accuInit is the initial expression whose value will be set for the accuVar prior to + // folding. + // - condition is the expression to test to determine whether to continue folding.
+ // - step is the expression to evaluate at the conclusion of a single fold iteration. + // - result is the computation to evaluate at the conclusion of the fold. + // + // The accuVar should not shadow variable names that you would like to reference within the + // environment in the step and condition expressions. Presently, the name __result__ is commonly + // used by built-in macros but this may change in the future. + Fold(iterVar string, + iterRange *exprpb.Expr, + accuVar string, + accuInit *exprpb.Expr, + condition *exprpb.Expr, + step *exprpb.Expr, + result *exprpb.Expr) *exprpb.Expr + + // Ident creates an identifier Expr value. + Ident(name string) *exprpb.Expr + + // AccuIdent returns an accumulator identifier for use with comprehension results. + AccuIdent() *exprpb.Expr + + // GlobalCall creates a function call Expr value for a global (free) function. + GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr + + // ReceiverCall creates a function call Expr value for a receiver-style function. + ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr + + // PresenceTest creates a Select TestOnly Expr value for modelling has() semantics. + PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr + + // Select creates a field traversal Expr value. + Select(operand *exprpb.Expr, field string) *exprpb.Expr + + // OffsetLocation returns the Location of the expression identifier. + OffsetLocation(exprID int64) common.Location + + // NewError associates an error message with a given expression id. + NewError(exprID int64, message string) *Error +} + +// GlobalMacro creates a Macro for a global function with the specified arg count. +func GlobalMacro(function string, argCount int, factory MacroFactory) Macro { + return parser.NewGlobalMacro(function, argCount, factory) +} + +// ReceiverMacro creates a Macro for a receiver function matching the specified arg count.
+func ReceiverMacro(function string, argCount int, factory MacroFactory) Macro { + return parser.NewReceiverMacro(function, argCount, factory) +} + +// GlobalVarArgMacro creates a Macro for a global function with a variable arg count. +func GlobalVarArgMacro(function string, factory MacroFactory) Macro { + return parser.NewGlobalVarArgMacro(function, factory) +} + +// ReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count. +func ReceiverVarArgMacro(function string, factory MacroFactory) Macro { + return parser.NewReceiverVarArgMacro(function, factory) +} // NewGlobalMacro creates a Macro for a global function with the specified arg count. +// +// Deprecated: use GlobalMacro func NewGlobalMacro(function string, argCount int, expander MacroExpander) Macro { - return parser.NewGlobalMacro(function, argCount, expander) + expand := adaptingExpander{expander} + return parser.NewGlobalMacro(function, argCount, expand.Expander) } // NewReceiverMacro creates a Macro for a receiver function matching the specified arg count. +// +// Deprecated: use ReceiverMacro func NewReceiverMacro(function string, argCount int, expander MacroExpander) Macro { - return parser.NewReceiverMacro(function, argCount, expander) + expand := adaptingExpander{expander} + return parser.NewReceiverMacro(function, argCount, expand.Expander) } // NewGlobalVarArgMacro creates a Macro for a global function with a variable arg count. +// +// Deprecated: use GlobalVarArgMacro func NewGlobalVarArgMacro(function string, expander MacroExpander) Macro { - return parser.NewGlobalVarArgMacro(function, expander) + expand := adaptingExpander{expander} + return parser.NewGlobalVarArgMacro(function, expand.Expander) } // NewReceiverVarArgMacro creates a Macro for a receiver function matching a variable arg count. 
+// +// Deprecated: use ReceiverVarArgMacro func NewReceiverVarArgMacro(function string, expander MacroExpander) Macro { - return parser.NewReceiverVarArgMacro(function, expander) + expand := adaptingExpander{expander} + return parser.NewReceiverVarArgMacro(function, expand.Expander) } // HasMacroExpander expands the input call arguments into a presence test, e.g. has(.field) func HasMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeHas(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + arg, err := adaptToExpr(args[0]) + if err != nil { + return nil, err + } + if arg.Kind() == ast.SelectKind { + s := arg.AsSelect() + return adaptToProto(ph.NewPresenceTest(s.Operand(), s.FieldName())) + } + return nil, ph.NewError(arg.ID(), "invalid argument to has() macro") } // ExistsMacroExpander expands the input call arguments into a comprehension that returns true if any of the // elements in the range match the predicate expressions: // .exists(, ) func ExistsMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeExists(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeExists(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) } // ExistsOneMacroExpander expands the input call arguments into a comprehension that returns true if exactly // one of the elements in the range match the predicate expressions: // .exists_one(, ) func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeExistsOne(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeExistsOne(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + 
return nil, err + } + return adaptToProto(out) } // MapMacroExpander expands the input call arguments into a comprehension that transforms each element in the @@ -91,14 +252,30 @@ func ExistsOneMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*ex // In the second form only iterVar values which return true when provided to the predicate expression // are transformed. func MapMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeMap(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeMap(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) } // FilterMacroExpander expands the input call arguments into a comprehension which produces a list which contains // only elements which match the provided predicate expression: // .filter(, ) func FilterMacroExpander(meh MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *Error) { - return parser.MakeFilter(meh, target, args) + ph, err := toParserHelper(meh) + if err != nil { + return nil, err + } + out, err := parser.MakeFilter(ph, mustAdaptToExpr(target), mustAdaptToExprs(args)) + if err != nil { + return nil, err + } + return adaptToProto(out) } var ( @@ -142,3 +319,258 @@ var ( // NoMacros provides an alias to an empty list of macros NoMacros = []Macro{} ) + +type adaptingExpander struct { + legacyExpander MacroExpander +} + +func (adapt *adaptingExpander) Expander(eh parser.ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + var legacyTarget *exprpb.Expr = nil + var err *Error = nil + if target != nil { + legacyTarget, err = adaptToProto(target) + if err != nil { + return nil, err + } + } + legacyArgs := make([]*exprpb.Expr, len(args)) + for i, arg := range args { + legacyArgs[i], err = adaptToProto(arg) + if err != nil { + return nil, err + } + } + ah := 
&adaptingHelper{modernHelper: eh} + legacyExpr, err := adapt.legacyExpander(ah, legacyTarget, legacyArgs) + if err != nil { + return nil, err + } + ex, err := adaptToExpr(legacyExpr) + if err != nil { + return nil, err + } + return ex, nil +} + +func wrapErr(id int64, message string, err error) *common.Error { + return &common.Error{ + Location: common.NoLocation, + Message: fmt.Sprintf("%s: %v", message, err), + ExprID: id, + } +} + +type adaptingHelper struct { + modernHelper parser.ExprHelper +} + +// Copy the input expression with a brand new set of identifiers. +func (ah *adaptingHelper) Copy(e *exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.Copy(mustAdaptToExpr(e))) +} + +// LiteralBool creates an Expr value for a bool literal. +func (ah *adaptingHelper) LiteralBool(value bool) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bool(value))) +} + +// LiteralBytes creates an Expr value for a byte literal. +func (ah *adaptingHelper) LiteralBytes(value []byte) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Bytes(value))) +} + +// LiteralDouble creates an Expr value for double literal. +func (ah *adaptingHelper) LiteralDouble(value float64) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Double(value))) +} + +// LiteralInt creates an Expr value for an int literal. +func (ah *adaptingHelper) LiteralInt(value int64) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Int(value))) +} + +// LiteralString creates am Expr value for a string literal. +func (ah *adaptingHelper) LiteralString(value string) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.String(value))) +} + +// LiteralUint creates an Expr value for a uint literal. 
+func (ah *adaptingHelper) LiteralUint(value uint64) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewLiteral(types.Uint(value))) +} + +// NewList creates a CreateList instruction where the list is comprised of the optional set +// of elements provided as arguments. +func (ah *adaptingHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewList(mustAdaptToExprs(elems)...)) +} + +// NewMap creates a CreateStruct instruction for a map where the map is comprised of the +// optional set of key, value entries. +func (ah *adaptingHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { + adaptedEntries := make([]ast.EntryExpr, len(entries)) + for i, e := range entries { + adaptedEntries[i] = mustAdaptToEntryExpr(e) + } + return mustAdaptToProto(ah.modernHelper.NewMap(adaptedEntries...)) +} + +// NewMapEntry creates a Map Entry for the key, value pair. +func (ah *adaptingHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { + return mustAdaptToProtoEntry( + ah.modernHelper.NewMapEntry(mustAdaptToExpr(key), mustAdaptToExpr(val), optional)) +} + +// NewObject creates a CreateStruct instruction for an object with a given type name and +// optional set of field initializers. +func (ah *adaptingHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { + adaptedEntries := make([]ast.EntryExpr, len(fieldInits)) + for i, e := range fieldInits { + adaptedEntries[i] = mustAdaptToEntryExpr(e) + } + return mustAdaptToProto(ah.modernHelper.NewStruct(typeName, adaptedEntries...)) +} + +// NewObjectFieldInit creates a new Object field initializer from the field name and value. 
+func (ah *adaptingHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { + return mustAdaptToProtoEntry( + ah.modernHelper.NewStructField(field, mustAdaptToExpr(init), optional)) +} + +// Fold creates a fold comprehension instruction. +// +// - iterVar is the iteration variable name. +// - iterRange represents the expression that resolves to a list or map where the elements or +// keys (respectively) will be iterated over. +// - accuVar is the accumulation variable name, typically parser.AccumulatorName. +// - accuInit is the initial expression whose value will be set for the accuVar prior to +// folding. +// - condition is the expression to test to determine whether to continue folding. +// - step is the expression to evaluate at the conclusion of a single fold iteration. +// - result is the computation to evaluate at the conclusion of the fold. +// +// The accuVar should not shadow variable names that you would like to reference within the +// environment in the step and condition expressions. Presently, the name __result__ is commonly +// used by built-in macros but this may change in the future. +func (ah *adaptingHelper) Fold(iterVar string, + iterRange *exprpb.Expr, + accuVar string, + accuInit *exprpb.Expr, + condition *exprpb.Expr, + step *exprpb.Expr, + result *exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto( + ah.modernHelper.NewComprehension( + mustAdaptToExpr(iterRange), + iterVar, + accuVar, + mustAdaptToExpr(accuInit), + mustAdaptToExpr(condition), + mustAdaptToExpr(step), + mustAdaptToExpr(result), + ), + ) +} + +// Ident creates an identifier Expr value. +func (ah *adaptingHelper) Ident(name string) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewIdent(name)) +} + +// AccuIdent returns an accumulator identifier for use with comprehension results.
+func (ah *adaptingHelper) AccuIdent() *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewAccuIdent()) +} + +// GlobalCall creates a function call Expr value for a global (free) function. +func (ah *adaptingHelper) GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto(ah.modernHelper.NewCall(function, mustAdaptToExprs(args)...)) +} + +// ReceiverCall creates a function call Expr value for a receiver-style function. +func (ah *adaptingHelper) ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr { + return mustAdaptToProto( + ah.modernHelper.NewMemberCall(function, mustAdaptToExpr(target), mustAdaptToExprs(args)...)) +} + +// PresenceTest creates a Select TestOnly Expr value for modelling has() semantics. +func (ah *adaptingHelper) PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr { + op := mustAdaptToExpr(operand) + return mustAdaptToProto(ah.modernHelper.NewPresenceTest(op, field)) +} + +// Select creates a field traversal Expr value. +func (ah *adaptingHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr { + op := mustAdaptToExpr(operand) + return mustAdaptToProto(ah.modernHelper.NewSelect(op, field)) +} + +// OffsetLocation returns the Location of the expression identifier. +func (ah *adaptingHelper) OffsetLocation(exprID int64) common.Location { + return ah.modernHelper.OffsetLocation(exprID) +} + +// NewError associates an error message with a given expression id.
+func (ah *adaptingHelper) NewError(exprID int64, message string) *Error { + return ah.modernHelper.NewError(exprID, message) +} + +func mustAdaptToExprs(exprs []*exprpb.Expr) []ast.Expr { + adapted := make([]ast.Expr, len(exprs)) + for i, e := range exprs { + adapted[i] = mustAdaptToExpr(e) + } + return adapted +} + +func mustAdaptToExpr(e *exprpb.Expr) ast.Expr { + out, _ := adaptToExpr(e) + return out +} + +func adaptToExpr(e *exprpb.Expr) (ast.Expr, *Error) { + if e == nil { + return nil, nil + } + out, err := ast.ProtoToExpr(e) + if err != nil { + return nil, wrapErr(e.GetId(), "proto conversion failure", err) + } + return out, nil +} + +func mustAdaptToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) ast.EntryExpr { + out, _ := ast.ProtoToEntryExpr(e) + return out +} + +func mustAdaptToProto(e ast.Expr) *exprpb.Expr { + out, _ := adaptToProto(e) + return out +} + +func adaptToProto(e ast.Expr) (*exprpb.Expr, *Error) { + if e == nil { + return nil, nil + } + out, err := ast.ExprToProto(e) + if err != nil { + return nil, wrapErr(e.ID(), "expr conversion failure", err) + } + return out, nil +} + +func mustAdaptToProtoEntry(e ast.EntryExpr) *exprpb.Expr_CreateStruct_Entry { + out, _ := ast.EntryExprToProto(e) + return out +} + +func toParserHelper(meh MacroExprHelper) (parser.ExprHelper, *Error) { + ah, ok := meh.(*adaptingHelper) + if !ok { + return nil, common.NewError(0, + fmt.Sprintf("unsupported macro helper: %v (%T)", meh, meh), + common.NoLocation) + } + return ah.modernHelper, nil +} diff --git a/common/ast/conversion.go b/common/ast/conversion.go index f2cc01ff5..16c57871e 100644 --- a/common/ast/conversion.go +++ b/common/ast/conversion.go @@ -95,6 +95,18 @@ func ProtoToExpr(e *exprpb.Expr) (Expr, error) { return exprInternal(factory, e) } +// ProtoToEntryExpr converts a protobuf struct/map entry to an ast.EntryExpr +func ProtoToEntryExpr(e *exprpb.Expr_CreateStruct_Entry) (EntryExpr, error) { + factory := NewExprFactory() + switch e.GetKeyKind().(type) { + 
case *exprpb.Expr_CreateStruct_Entry_FieldKey: + return exprStructField(factory, e.GetId(), e) + case *exprpb.Expr_CreateStruct_Entry_MapKey: + return exprMapEntry(factory, e.GetId(), e) + } + return nil, fmt.Errorf("unsupported expr entry kind: %v", e) +} + func exprInternal(factory ExprFactory, e *exprpb.Expr) (Expr, error) { id := e.GetId() switch e.GetExprKind().(type) { diff --git a/ext/BUILD.bazel b/ext/BUILD.bazel index c6e57392c..74b6c780f 100644 --- a/ext/BUILD.bazel +++ b/ext/BUILD.bazel @@ -29,7 +29,6 @@ go_library( "//common/types/ref:go_default_library", "//common/types/traits:go_default_library", "//interpreter:go_default_library", - "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//reflect/protoreflect:go_default_library", "@org_golang_google_protobuf//types/known/structpb", @@ -62,7 +61,6 @@ go_test( "//test:go_default_library", "//test/proto2pb:go_default_library", "//test/proto3pb:go_default_library", - "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", "@org_golang_google_protobuf//proto:go_default_library", "@org_golang_google_protobuf//types/known/wrapperspb:go_default_library", "@org_golang_google_protobuf//encoding/protojson:go_default_library", diff --git a/ext/bindings.go b/ext/bindings.go index 4ac9a7f07..2c6cc627f 100644 --- a/ext/bindings.go +++ b/ext/bindings.go @@ -16,8 +16,8 @@ package ext import ( "github.com/google/cel-go/cel" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" ) // Bindings returns a cel.EnvOption to configure support for local variable @@ -61,7 +61,7 @@ func (celBindings) CompileOptions() []cel.EnvOption { return []cel.EnvOption{ cel.Macros( // cel.bind(var, , ) - cel.NewReceiverMacro(bindMacro, 3, celBind), + cel.ReceiverMacro(bindMacro, 3, celBind), ), } } @@ -70,27 
+70,27 @@ func (celBindings) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } -func celBind(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { +func celBind(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { if !macroTargetMatchesNamespace(celNamespace, target) { return nil, nil } varIdent := args[0] varName := "" - switch varIdent.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr: - varName = varIdent.GetIdentExpr().GetName() + switch varIdent.Kind() { + case ast.IdentKind: + varName = varIdent.AsIdent() default: - return nil, meh.NewError(varIdent.GetId(), "cel.bind() variable names must be simple identifiers") + return nil, mef.NewError(varIdent.ID(), "cel.bind() variable names must be simple identifiers") } varInit := args[1] resultExpr := args[2] - return meh.Fold( + return mef.NewComprehension( + mef.NewList(), unusedIterVar, - meh.NewList(), varName, varInit, - meh.LiteralBool(false), - meh.Ident(varName), + mef.NewLiteral(types.False), + mef.NewIdent(varName), resultExpr, ), nil } diff --git a/ext/guards.go b/ext/guards.go index 785c8675b..2c00bfe3a 100644 --- a/ext/guards.go +++ b/ext/guards.go @@ -15,10 +15,9 @@ package ext import ( + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // function invocation guards for common call signatures within extension functions. 
@@ -51,10 +50,10 @@ func listStringOrError(strs []string, err error) ref.Val { return types.DefaultTypeAdapter.NativeToValue(strs) } -func macroTargetMatchesNamespace(ns string, target *exprpb.Expr) bool { - switch target.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr: - if target.GetIdentExpr().GetName() != ns { +func macroTargetMatchesNamespace(ns string, target ast.Expr) bool { + switch target.Kind() { + case ast.IdentKind: + if target.AsIdent() != ns { return false } return true diff --git a/ext/math.go b/ext/math.go index 0b9a36103..65d7e2eb0 100644 --- a/ext/math.go +++ b/ext/math.go @@ -19,11 +19,10 @@ import ( "strings" "github.com/google/cel-go/cel" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // Math returns a cel.EnvOption to configure namespaced math helper macros and @@ -111,9 +110,9 @@ func (mathLib) CompileOptions() []cel.EnvOption { return []cel.EnvOption{ cel.Macros( // math.least(num, ...) - cel.NewReceiverVarArgMacro(leastMacro, mathLeast), + cel.ReceiverVarArgMacro(leastMacro, mathLeast), // math.greatest(num, ...) 
- cel.NewReceiverVarArgMacro(greatestMacro, mathGreatest), + cel.ReceiverVarArgMacro(greatestMacro, mathGreatest), ), cel.Function(minFunc, cel.Overload("math_@min_double", []*cel.Type{cel.DoubleType}, cel.DoubleType, @@ -187,57 +186,57 @@ func (mathLib) ProgramOptions() []cel.ProgramOption { return []cel.ProgramOption{} } -func mathLeast(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { +func mathLeast(meh cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { if !macroTargetMatchesNamespace(mathNamespace, target) { return nil, nil } switch len(args) { case 0: - return nil, meh.NewError(target.GetId(), "math.least() requires at least one argument") + return nil, meh.NewError(target.ID(), "math.least() requires at least one argument") case 1: if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) { - return meh.GlobalCall(minFunc, args[0]), nil + return meh.NewCall(minFunc, args[0]), nil } - return nil, meh.NewError(args[0].GetId(), "math.least() invalid single argument value") + return nil, meh.NewError(args[0].ID(), "math.least() invalid single argument value") case 2: err := checkInvalidArgs(meh, "math.least()", args) if err != nil { return nil, err } - return meh.GlobalCall(minFunc, args...), nil + return meh.NewCall(minFunc, args...), nil default: err := checkInvalidArgs(meh, "math.least()", args) if err != nil { return nil, err } - return meh.GlobalCall(minFunc, meh.NewList(args...)), nil + return meh.NewCall(minFunc, meh.NewList(args...)), nil } } -func mathGreatest(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { +func mathGreatest(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { if !macroTargetMatchesNamespace(mathNamespace, target) { return nil, nil } switch len(args) { case 0: - return nil, meh.NewError(target.GetId(), "math.greatest() requires at least one argument") + return nil, 
mef.NewError(target.ID(), "math.greatest() requires at least one argument") case 1: if isListLiteralWithValidArgs(args[0]) || isValidArgType(args[0]) { - return meh.GlobalCall(maxFunc, args[0]), nil + return mef.NewCall(maxFunc, args[0]), nil } - return nil, meh.NewError(args[0].GetId(), "math.greatest() invalid single argument value") + return nil, mef.NewError(args[0].ID(), "math.greatest() invalid single argument value") case 2: - err := checkInvalidArgs(meh, "math.greatest()", args) + err := checkInvalidArgs(mef, "math.greatest()", args) if err != nil { return nil, err } - return meh.GlobalCall(maxFunc, args...), nil + return mef.NewCall(maxFunc, args...), nil default: - err := checkInvalidArgs(meh, "math.greatest()", args) + err := checkInvalidArgs(mef, "math.greatest()", args) if err != nil { return nil, err } - return meh.GlobalCall(maxFunc, meh.NewList(args...)), nil + return mef.NewCall(maxFunc, mef.NewList(args...)), nil } } @@ -311,48 +310,48 @@ func maxList(numList ref.Val) ref.Val { } } -func checkInvalidArgs(meh cel.MacroExprHelper, funcName string, args []*exprpb.Expr) *cel.Error { +func checkInvalidArgs(meh cel.MacroExprFactory, funcName string, args []ast.Expr) *cel.Error { for _, arg := range args { err := checkInvalidArgLiteral(funcName, arg) if err != nil { - return meh.NewError(arg.GetId(), err.Error()) + return meh.NewError(arg.ID(), err.Error()) } } return nil } -func checkInvalidArgLiteral(funcName string, arg *exprpb.Expr) error { +func checkInvalidArgLiteral(funcName string, arg ast.Expr) error { if !isValidArgType(arg) { return fmt.Errorf("%s simple literal arguments must be numeric", funcName) } return nil } -func isValidArgType(arg *exprpb.Expr) bool { - switch arg.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - c := arg.GetConstExpr() - switch c.GetConstantKind().(type) { - case *exprpb.Constant_DoubleValue, *exprpb.Constant_Int64Value, *exprpb.Constant_Uint64Value: +func isValidArgType(arg ast.Expr) bool { + switch arg.Kind() 
{ + case ast.LiteralKind: + c := ref.Val(arg.AsLiteral()) + switch c.(type) { + case types.Double, types.Int, types.Uint: return true default: return false } - case *exprpb.Expr_ListExpr, *exprpb.Expr_StructExpr: + case ast.ListKind, ast.MapKind, ast.StructKind: return false default: return true } } -func isListLiteralWithValidArgs(arg *exprpb.Expr) bool { - switch arg.GetExprKind().(type) { - case *exprpb.Expr_ListExpr: - list := arg.GetListExpr() - if len(list.GetElements()) == 0 { +func isListLiteralWithValidArgs(arg ast.Expr) bool { + switch arg.Kind() { + case ast.ListKind: + list := arg.AsList() + if list.Size() == 0 { return false } - for _, e := range list.GetElements() { + for _, e := range list.Elements() { if !isValidArgType(e) { return false } diff --git a/ext/protos.go b/ext/protos.go index a7ca27a6a..68796f60a 100644 --- a/ext/protos.go +++ b/ext/protos.go @@ -16,8 +16,7 @@ package ext import ( "github.com/google/cel-go/cel" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/ast" ) // Protos returns a cel.EnvOption to configure extended macros and functions for @@ -72,9 +71,9 @@ func (protoLib) CompileOptions() []cel.EnvOption { return []cel.EnvOption{ cel.Macros( // proto.getExt(msg, select_expression) - cel.NewReceiverMacro(getExtension, 2, getProtoExt), + cel.ReceiverMacro(getExtension, 2, getProtoExt), // proto.hasExt(msg, select_expression) - cel.NewReceiverMacro(hasExtension, 2, hasProtoExt), + cel.ReceiverMacro(hasExtension, 2, hasProtoExt), ), } } @@ -85,56 +84,56 @@ func (protoLib) ProgramOptions() []cel.ProgramOption { } // hasProtoExt generates a test-only select expression for a fully-qualified extension name on a protobuf message. 
-func hasProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { +func hasProtoExt(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { if !macroTargetMatchesNamespace(protoNamespace, target) { return nil, nil } - extensionField, err := getExtFieldName(meh, args[1]) + extensionField, err := getExtFieldName(mef, args[1]) if err != nil { return nil, err } - return meh.PresenceTest(args[0], extensionField), nil + return mef.NewPresenceTest(args[0], extensionField), nil } // getProtoExt generates a select expression for a fully-qualified extension name on a protobuf message. -func getProtoExt(meh cel.MacroExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *cel.Error) { +func getProtoExt(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (ast.Expr, *cel.Error) { if !macroTargetMatchesNamespace(protoNamespace, target) { return nil, nil } - extFieldName, err := getExtFieldName(meh, args[1]) + extFieldName, err := getExtFieldName(mef, args[1]) if err != nil { return nil, err } - return meh.Select(args[0], extFieldName), nil + return mef.NewSelect(args[0], extFieldName), nil } -func getExtFieldName(meh cel.MacroExprHelper, expr *exprpb.Expr) (string, *cel.Error) { +func getExtFieldName(mef cel.MacroExprFactory, expr ast.Expr) (string, *cel.Error) { isValid := false extensionField := "" - switch expr.GetExprKind().(type) { - case *exprpb.Expr_SelectExpr: + switch expr.Kind() { + case ast.SelectKind: extensionField, isValid = validateIdentifier(expr) } if !isValid { - return "", meh.NewError(expr.GetId(), "invalid extension field") + return "", mef.NewError(expr.ID(), "invalid extension field") } return extensionField, nil } -func validateIdentifier(expr *exprpb.Expr) (string, bool) { - switch expr.GetExprKind().(type) { - case *exprpb.Expr_IdentExpr: - return expr.GetIdentExpr().GetName(), true - case *exprpb.Expr_SelectExpr: - sel := expr.GetSelectExpr() - if 
sel.GetTestOnly() { +func validateIdentifier(expr ast.Expr) (string, bool) { + switch expr.Kind() { + case ast.IdentKind: + return expr.AsIdent(), true + case ast.SelectKind: + sel := expr.AsSelect() + if sel.IsTestOnly() { return "", false } - opStr, isIdent := validateIdentifier(sel.GetOperand()) + opStr, isIdent := validateIdentifier(sel.Operand()) if !isIdent { return "", false } - return opStr + "." + sel.GetField(), true + return opStr + "." + sel.FieldName(), true default: return "", false } diff --git a/parser/BUILD.bazel b/parser/BUILD.bazel index af15000cd..a4a09ad75 100644 --- a/parser/BUILD.bazel +++ b/parser/BUILD.bazel @@ -23,6 +23,8 @@ go_library( "//common/ast:go_default_library", "//common/operators:go_default_library", "//common/runes:go_default_library", + "//common/types:go_default_library", + "//common/types/ref:go_default_library", "//parser/gen:go_default_library", "@com_github_antlr_antlr4_runtime_go_antlr_v4//:go_default_library", "@org_golang_google_genproto_googleapis_api//expr/v1alpha1:go_default_library", diff --git a/parser/helper.go b/parser/helper.go index a5f29e3d7..07656e139 100644 --- a/parser/helper.go +++ b/parser/helper.go @@ -20,281 +20,206 @@ import ( antlr "github.com/antlr/antlr4/runtime/Go/antlr/v4" "github.com/google/cel-go/common" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" ) type parserHelper struct { - source common.Source - nextID int64 - positions map[int64]int32 - macroCalls map[int64]*exprpb.Expr + exprFactory ast.ExprFactory + source common.Source + sourceInfo *ast.SourceInfo + nextID int64 } -func newParserHelper(source common.Source) *parserHelper { +func newParserHelper(source common.Source, fac ast.ExprFactory) *parserHelper { return &parserHelper{ - source: source, - nextID: 1, - positions: make(map[int64]int32), - macroCalls: make(map[int64]*exprpb.Expr), + 
exprFactory: fac, + source: source, + sourceInfo: ast.NewSourceInfo(source), + nextID: 1, } } -func (p *parserHelper) getSourceInfo() *exprpb.SourceInfo { - return &exprpb.SourceInfo{ - Location: p.source.Description(), - Positions: p.positions, - LineOffsets: p.source.LineOffsets(), - MacroCalls: p.macroCalls} +func (p *parserHelper) getSourceInfo() *ast.SourceInfo { + return p.sourceInfo } -func (p *parserHelper) newLiteral(ctx any, value *exprpb.Constant) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_ConstExpr{ConstExpr: value} - return exprNode +func (p *parserHelper) newLiteral(ctx any, value ref.Val) ast.Expr { + return p.exprFactory.NewLiteral(p.newID(ctx), value) } -func (p *parserHelper) newLiteralBool(ctx any, value bool) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: value}}) +func (p *parserHelper) newLiteralBool(ctx any, value bool) ast.Expr { + return p.newLiteral(ctx, types.Bool(value)) } -func (p *parserHelper) newLiteralString(ctx any, value string) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: value}}) +func (p *parserHelper) newLiteralString(ctx any, value string) ast.Expr { + return p.newLiteral(ctx, types.String(value)) } -func (p *parserHelper) newLiteralBytes(ctx any, value []byte) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: value}}) +func (p *parserHelper) newLiteralBytes(ctx any, value []byte) ast.Expr { + return p.newLiteral(ctx, types.Bytes(value)) } -func (p *parserHelper) newLiteralInt(ctx any, value int64) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: value}}) +func (p *parserHelper) newLiteralInt(ctx any, value int64) ast.Expr { + return p.newLiteral(ctx, types.Int(value)) } -func (p *parserHelper) newLiteralUint(ctx any, value 
uint64) *exprpb.Expr { - return p.newLiteral(ctx, &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: value}}) +func (p *parserHelper) newLiteralUint(ctx any, value uint64) ast.Expr { + return p.newLiteral(ctx, types.Uint(value)) } -func (p *parserHelper) newLiteralDouble(ctx any, value float64) *exprpb.Expr { - return p.newLiteral(ctx, - &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: value}}) +func (p *parserHelper) newLiteralDouble(ctx any, value float64) ast.Expr { + return p.newLiteral(ctx, types.Double(value)) } -func (p *parserHelper) newIdent(ctx any, name string) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_IdentExpr{IdentExpr: &exprpb.Expr_Ident{Name: name}} - return exprNode +func (p *parserHelper) newIdent(ctx any, name string) ast.Expr { + return p.exprFactory.NewIdent(p.newID(ctx), name) } -func (p *parserHelper) newSelect(ctx any, operand *exprpb.Expr, field string) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_SelectExpr{ - SelectExpr: &exprpb.Expr_Select{Operand: operand, Field: field}} - return exprNode +func (p *parserHelper) newSelect(ctx any, operand ast.Expr, field string) ast.Expr { + return p.exprFactory.NewSelect(p.newID(ctx), operand, field) } -func (p *parserHelper) newPresenceTest(ctx any, operand *exprpb.Expr, field string) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_SelectExpr{ - SelectExpr: &exprpb.Expr_Select{Operand: operand, Field: field, TestOnly: true}} - return exprNode +func (p *parserHelper) newPresenceTest(ctx any, operand ast.Expr, field string) ast.Expr { + return p.exprFactory.NewPresenceTest(p.newID(ctx), operand, field) } -func (p *parserHelper) newGlobalCall(ctx any, function string, args ...*exprpb.Expr) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{Function: function, Args: args}} - return exprNode 
+func (p *parserHelper) newGlobalCall(ctx any, function string, args ...ast.Expr) ast.Expr { + return p.exprFactory.NewCall(p.newID(ctx), function, args...) } -func (p *parserHelper) newReceiverCall(ctx any, function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{Function: function, Target: target, Args: args}} - return exprNode +func (p *parserHelper) newReceiverCall(ctx any, function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return p.exprFactory.NewMemberCall(p.newID(ctx), function, target, args...) } -func (p *parserHelper) newList(ctx any, elements []*exprpb.Expr, optionals ...int32) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - Elements: elements, - OptionalIndices: optionals, - }} - return exprNode +func (p *parserHelper) newList(ctx any, elements []ast.Expr, optionals ...int32) ast.Expr { + return p.exprFactory.NewList(p.newID(ctx), elements, optionals) } -func (p *parserHelper) newMap(ctx any, entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{Entries: entries}} - return exprNode +func (p *parserHelper) newMap(ctx any, entries ...ast.EntryExpr) ast.Expr { + return p.exprFactory.NewMap(p.newID(ctx), entries) } -func (p *parserHelper) newMapEntry(entryID int64, key *exprpb.Expr, value *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { - return &exprpb.Expr_CreateStruct_Entry{ - Id: entryID, - KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{MapKey: key}, - Value: value, - OptionalEntry: optional, - } +func (p *parserHelper) newMapEntry(entryID int64, key ast.Expr, value ast.Expr, optional bool) ast.EntryExpr { + return p.exprFactory.NewMapEntry(entryID, key, value, optional) } -func (p *parserHelper) 
newObject(ctx any, typeName string, entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{ - MessageName: typeName, - Entries: entries, - }, - } - return exprNode +func (p *parserHelper) newObject(ctx any, typeName string, fields ...ast.EntryExpr) ast.Expr { + return p.exprFactory.NewStruct(p.newID(ctx), typeName, fields) } -func (p *parserHelper) newObjectField(fieldID int64, field string, value *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { - return &exprpb.Expr_CreateStruct_Entry{ - Id: fieldID, - KeyKind: &exprpb.Expr_CreateStruct_Entry_FieldKey{FieldKey: field}, - Value: value, - OptionalEntry: optional, - } +func (p *parserHelper) newObjectField(fieldID int64, field string, value ast.Expr, optional bool) ast.EntryExpr { + return p.exprFactory.NewStructField(fieldID, field, value, optional) } -func (p *parserHelper) newComprehension(ctx any, iterVar string, - iterRange *exprpb.Expr, +func (p *parserHelper) newComprehension(ctx any, + iterRange ast.Expr, + iterVar string, accuVar string, - accuInit *exprpb.Expr, - condition *exprpb.Expr, - step *exprpb.Expr, - result *exprpb.Expr) *exprpb.Expr { - exprNode := p.newExpr(ctx) - exprNode.ExprKind = &exprpb.Expr_ComprehensionExpr{ - ComprehensionExpr: &exprpb.Expr_Comprehension{ - AccuVar: accuVar, - AccuInit: accuInit, - IterVar: iterVar, - IterRange: iterRange, - LoopCondition: condition, - LoopStep: step, - Result: result}} - return exprNode -} - -func (p *parserHelper) newExpr(ctx any) *exprpb.Expr { - id, isID := ctx.(int64) - if isID { - return &exprpb.Expr{Id: id} + accuInit ast.Expr, + condition ast.Expr, + step ast.Expr, + result ast.Expr) ast.Expr { + return p.exprFactory.NewComprehension( + p.newID(ctx), iterRange, iterVar, accuVar, accuInit, condition, step, result) +} + +func (p *parserHelper) newID(ctx any) int64 { + if id, isID := ctx.(int64); isID { + return id } - 
return &exprpb.Expr{Id: p.id(ctx)} + return p.id(ctx) +} + +func (p *parserHelper) newExpr(ctx any) ast.Expr { + return p.exprFactory.NewUnspecifiedExpr(p.newID(ctx)) } func (p *parserHelper) id(ctx any) int64 { - var location common.Location + var offset ast.OffsetRange switch c := ctx.(type) { case antlr.ParserRuleContext: - token := c.GetStart() - location = p.source.NewLocation(token.GetLine(), token.GetColumn()) + start, stop := c.GetStart(), c.GetStop() + if stop == nil { + stop = start + } + offset.Start = p.sourceInfo.ComputeOffset(int32(start.GetLine()), int32(start.GetColumn())) + offset.Stop = p.sourceInfo.ComputeOffset(int32(stop.GetLine()), int32(stop.GetColumn())) case antlr.Token: - token := c - location = p.source.NewLocation(token.GetLine(), token.GetColumn()) + offset.Start = p.sourceInfo.ComputeOffset(int32(c.GetLine()), int32(c.GetColumn())) + offset.Stop = offset.Start case common.Location: - location = c + offset.Start = p.sourceInfo.ComputeOffset(int32(c.Line()), int32(c.Column())) + offset.Stop = offset.Start + case ast.OffsetRange: + offset = c default: // This should only happen if the ctx is nil return -1 } id := p.nextID - p.positions[id], _ = p.source.LocationOffset(location) + p.sourceInfo.SetOffsetRange(id, offset) p.nextID++ return id } func (p *parserHelper) getLocation(id int64) common.Location { - characterOffset := p.positions[id] - location, _ := p.source.OffsetLocation(characterOffset) - return location + return p.sourceInfo.GetStartLocation(id) } // buildMacroCallArg iterates the expression and returns a new expression // where all macros have been replaced by their IDs in MacroCalls -func (p *parserHelper) buildMacroCallArg(expr *exprpb.Expr) *exprpb.Expr { - if _, found := p.macroCalls[expr.GetId()]; found { - return &exprpb.Expr{Id: expr.GetId()} +func (p *parserHelper) buildMacroCallArg(expr ast.Expr) ast.Expr { + if _, found := p.sourceInfo.GetMacroCall(expr.ID()); found { + return 
p.exprFactory.NewUnspecifiedExpr(expr.ID()) } - switch expr.GetExprKind().(type) { - case *exprpb.Expr_CallExpr: + switch expr.Kind() { + case ast.CallKind: // Iterate the AST from `expr` recursively looking for macros. Because we are at most // starting from the top level macro, this recursion is bounded by the size of the AST. This // means that the depth check on the AST during parsing will catch recursion overflows // before we get to here. - macroTarget := expr.GetCallExpr().GetTarget() - if macroTarget != nil { - macroTarget = p.buildMacroCallArg(macroTarget) - } - macroArgs := make([]*exprpb.Expr, len(expr.GetCallExpr().GetArgs())) - for index, arg := range expr.GetCallExpr().GetArgs() { + call := expr.AsCall() + macroArgs := make([]ast.Expr, len(call.Args())) + for index, arg := range call.Args() { macroArgs[index] = p.buildMacroCallArg(arg) } - return &exprpb.Expr{ - Id: expr.GetId(), - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Target: macroTarget, - Function: expr.GetCallExpr().GetFunction(), - Args: macroArgs, - }, - }, + if !call.IsMemberFunction() { + return p.exprFactory.NewCall(expr.ID(), call.FunctionName(), macroArgs...) } - case *exprpb.Expr_ListExpr: - listExpr := expr.GetListExpr() - macroListArgs := make([]*exprpb.Expr, len(listExpr.GetElements())) - for i, elem := range listExpr.GetElements() { + macroTarget := p.buildMacroCallArg(call.Target()) + return p.exprFactory.NewMemberCall(expr.ID(), call.FunctionName(), macroTarget, macroArgs...) 
+ case ast.ListKind: + list := expr.AsList() + macroListArgs := make([]ast.Expr, list.Size()) + for i, elem := range list.Elements() { macroListArgs[i] = p.buildMacroCallArg(elem) } - return &exprpb.Expr{ - Id: expr.GetId(), - ExprKind: &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - Elements: macroListArgs, - OptionalIndices: listExpr.GetOptionalIndices(), - }, - }, - } + return p.exprFactory.NewList(expr.ID(), macroListArgs, list.OptionalIndices()) } - return expr } // addMacroCall adds the macro the the MacroCalls map in source info. If a macro has args/subargs/target // that are macros, their ID will be stored instead for later self-lookups. -func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) { - macroTarget := target - if target != nil { - if _, found := p.macroCalls[target.GetId()]; found { - macroTarget = &exprpb.Expr{Id: target.GetId()} - } else { - macroTarget = p.buildMacroCallArg(target) - } - } - - macroArgs := make([]*exprpb.Expr, len(args)) +func (p *parserHelper) addMacroCall(exprID int64, function string, target ast.Expr, args ...ast.Expr) { + macroArgs := make([]ast.Expr, len(args)) for index, arg := range args { macroArgs[index] = p.buildMacroCallArg(arg) } - - p.macroCalls[exprID] = &exprpb.Expr{ - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Target: macroTarget, - Function: function, - Args: macroArgs, - }, - }, + if target == nil { + p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewCall(0, function, macroArgs...)) + return } + macroTarget := target + if _, found := p.sourceInfo.GetMacroCall(target.ID()); found { + macroTarget = p.exprFactory.NewUnspecifiedExpr(target.ID()) + } else { + macroTarget = p.buildMacroCallArg(target) + } + p.sourceInfo.SetMacroCall(exprID, p.exprFactory.NewMemberCall(0, function, macroTarget, macroArgs...)) } // logicManager compacts logical trees into a more efficient structure which is semantically @@ -309,71 +234,71 @@ 
func (p *parserHelper) addMacroCall(exprID int64, function string, target *exprp // controversial choice as it alters the traditional order of execution assumptions present in most // expressions. type logicManager struct { - helper *parserHelper + exprFactory ast.ExprFactory function string - terms []*exprpb.Expr + terms []ast.Expr ops []int64 variadicASTs bool } // newVariadicLogicManager creates a logic manager instance bound to a specific function and its first term. -func newVariadicLogicManager(h *parserHelper, function string, term *exprpb.Expr) *logicManager { +func newVariadicLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager { return &logicManager{ - helper: h, + exprFactory: fac, function: function, - terms: []*exprpb.Expr{term}, + terms: []ast.Expr{term}, ops: []int64{}, variadicASTs: true, } } // newBalancingLogicManager creates a logic manager instance bound to a specific function and its first term. -func newBalancingLogicManager(h *parserHelper, function string, term *exprpb.Expr) *logicManager { +func newBalancingLogicManager(fac ast.ExprFactory, function string, term ast.Expr) *logicManager { return &logicManager{ - helper: h, + exprFactory: fac, function: function, - terms: []*exprpb.Expr{term}, + terms: []ast.Expr{term}, ops: []int64{}, variadicASTs: false, } } // addTerm adds an operation identifier and term to the set of terms to be balanced. -func (l *logicManager) addTerm(op int64, term *exprpb.Expr) { +func (l *logicManager) addTerm(op int64, term ast.Expr) { l.terms = append(l.terms, term) l.ops = append(l.ops, op) } // toExpr renders the logic graph into an Expr value, either balancing a tree of logical // operations or creating a variadic representation of the logical operator. -func (l *logicManager) toExpr() *exprpb.Expr { +func (l *logicManager) toExpr() ast.Expr { if len(l.terms) == 1 { return l.terms[0] } if l.variadicASTs { - return l.helper.newGlobalCall(l.ops[0], l.function, l.terms...) 
+ return l.exprFactory.NewCall(l.ops[0], l.function, l.terms...) } return l.balancedTree(0, len(l.ops)-1) } // balancedTree recursively balances the terms provided to a commutative operator. -func (l *logicManager) balancedTree(lo, hi int) *exprpb.Expr { +func (l *logicManager) balancedTree(lo, hi int) ast.Expr { mid := (lo + hi + 1) / 2 - var left *exprpb.Expr + var left ast.Expr if mid == lo { left = l.terms[mid] } else { left = l.balancedTree(lo, mid-1) } - var right *exprpb.Expr + var right ast.Expr if mid == hi { right = l.terms[mid+1] } else { right = l.balancedTree(mid+1, hi) } - return l.helper.newGlobalCall(l.ops[mid], l.function, left, right) + return l.exprFactory.NewCall(l.ops[mid], l.function, left, right) } type exprHelper struct { @@ -387,202 +312,151 @@ func (e *exprHelper) nextMacroID() int64 { // Copy implements the ExprHelper interface method by producing a copy of the input Expr value // with a fresh set of numeric identifiers the Expr and all its descendants. -func (e *exprHelper) Copy(expr *exprpb.Expr) *exprpb.Expr { - copy := e.parserHelper.newExpr(e.parserHelper.getLocation(expr.GetId())) - switch expr.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - copy.ExprKind = &exprpb.Expr_ConstExpr{ConstExpr: expr.GetConstExpr()} - case *exprpb.Expr_IdentExpr: - copy.ExprKind = &exprpb.Expr_IdentExpr{IdentExpr: expr.GetIdentExpr()} - case *exprpb.Expr_SelectExpr: - op := expr.GetSelectExpr().GetOperand() - copy.ExprKind = &exprpb.Expr_SelectExpr{SelectExpr: &exprpb.Expr_Select{ - Operand: e.Copy(op), - Field: expr.GetSelectExpr().GetField(), - TestOnly: expr.GetSelectExpr().GetTestOnly(), - }} - case *exprpb.Expr_CallExpr: - call := expr.GetCallExpr() - target := call.GetTarget() - if target != nil { - target = e.Copy(target) +func (e *exprHelper) Copy(expr ast.Expr) ast.Expr { + offsetRange, _ := e.parserHelper.sourceInfo.GetOffsetRange(expr.ID()) + copyID := e.parserHelper.newID(offsetRange) + switch expr.Kind() { + case ast.LiteralKind: + 
return e.exprFactory.NewLiteral(copyID, expr.AsLiteral()) + case ast.IdentKind: + return e.exprFactory.NewIdent(copyID, expr.AsIdent()) + case ast.SelectKind: + sel := expr.AsSelect() + op := e.Copy(sel.Operand()) + if sel.IsTestOnly() { + return e.exprFactory.NewPresenceTest(copyID, op, sel.FieldName()) } - args := call.GetArgs() - argsCopy := make([]*exprpb.Expr, len(args)) + return e.exprFactory.NewSelect(copyID, op, sel.FieldName()) + case ast.CallKind: + call := expr.AsCall() + args := call.Args() + argsCopy := make([]ast.Expr, len(args)) for i, arg := range args { argsCopy[i] = e.Copy(arg) } - copy.ExprKind = &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Function: call.GetFunction(), - Target: target, - Args: argsCopy, - }, + if !call.IsMemberFunction() { + return e.exprFactory.NewCall(copyID, call.FunctionName(), argsCopy...) } - case *exprpb.Expr_ListExpr: - elems := expr.GetListExpr().GetElements() - elemsCopy := make([]*exprpb.Expr, len(elems)) + return e.exprFactory.NewMemberCall(copyID, call.FunctionName(), e.Copy(call.Target()), argsCopy...) 
+ case ast.ListKind: + list := expr.AsList() + elems := list.Elements() + elemsCopy := make([]ast.Expr, len(elems)) for i, elem := range elems { elemsCopy[i] = e.Copy(elem) } - copy.ExprKind = &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{Elements: elemsCopy}, - } - case *exprpb.Expr_StructExpr: - entries := expr.GetStructExpr().GetEntries() - entriesCopy := make([]*exprpb.Expr_CreateStruct_Entry, len(entries)) - for i, entry := range entries { - entryCopy := &exprpb.Expr_CreateStruct_Entry{} - entryCopy.Id = e.nextMacroID() - switch entry.GetKeyKind().(type) { - case *exprpb.Expr_CreateStruct_Entry_FieldKey: - entryCopy.KeyKind = &exprpb.Expr_CreateStruct_Entry_FieldKey{ - FieldKey: entry.GetFieldKey(), - } - case *exprpb.Expr_CreateStruct_Entry_MapKey: - entryCopy.KeyKind = &exprpb.Expr_CreateStruct_Entry_MapKey{ - MapKey: e.Copy(entry.GetMapKey()), - } - } - entryCopy.Value = e.Copy(entry.GetValue()) - entriesCopy[i] = entryCopy + return e.exprFactory.NewList(copyID, elemsCopy, list.OptionalIndices()) + case ast.MapKind: + m := expr.AsMap() + entries := m.Entries() + entriesCopy := make([]ast.EntryExpr, len(entries)) + for i, en := range entries { + entry := en.AsMapEntry() + entryID := e.nextMacroID() + entriesCopy[i] = e.exprFactory.NewMapEntry(entryID, + e.Copy(entry.Key()), e.Copy(entry.Value()), entry.IsOptional()) } - copy.ExprKind = &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{ - MessageName: expr.GetStructExpr().GetMessageName(), - Entries: entriesCopy, - }, - } - case *exprpb.Expr_ComprehensionExpr: - iterRange := e.Copy(expr.GetComprehensionExpr().GetIterRange()) - accuInit := e.Copy(expr.GetComprehensionExpr().GetAccuInit()) - cond := e.Copy(expr.GetComprehensionExpr().GetLoopCondition()) - step := e.Copy(expr.GetComprehensionExpr().GetLoopStep()) - result := e.Copy(expr.GetComprehensionExpr().GetResult()) - copy.ExprKind = &exprpb.Expr_ComprehensionExpr{ - ComprehensionExpr: &exprpb.Expr_Comprehension{ - IterRange: 
iterRange, - IterVar: expr.GetComprehensionExpr().GetIterVar(), - AccuInit: accuInit, - AccuVar: expr.GetComprehensionExpr().GetAccuVar(), - LoopCondition: cond, - LoopStep: step, - Result: result, - }, + return e.exprFactory.NewMap(copyID, entriesCopy) + case ast.StructKind: + s := expr.AsStruct() + fields := s.Fields() + fieldsCopy := make([]ast.EntryExpr, len(fields)) + for i, f := range fields { + field := f.AsStructField() + fieldID := e.nextMacroID() + fieldsCopy[i] = e.exprFactory.NewStructField(fieldID, + field.Name(), e.Copy(field.Value()), field.IsOptional()) } + return e.exprFactory.NewStruct(copyID, s.TypeName(), fieldsCopy) + case ast.ComprehensionKind: + compre := expr.AsComprehension() + iterRange := e.Copy(compre.IterRange()) + accuInit := e.Copy(compre.AccuInit()) + cond := e.Copy(compre.LoopCondition()) + step := e.Copy(compre.LoopStep()) + result := e.Copy(compre.Result()) + return e.exprFactory.NewComprehension(copyID, + iterRange, compre.IterVar(), compre.AccuVar(), accuInit, cond, step, result) } - return copy -} - -// LiteralBool implements the ExprHelper interface method. -func (e *exprHelper) LiteralBool(value bool) *exprpb.Expr { - return e.parserHelper.newLiteralBool(e.nextMacroID(), value) -} - -// LiteralBytes implements the ExprHelper interface method. -func (e *exprHelper) LiteralBytes(value []byte) *exprpb.Expr { - return e.parserHelper.newLiteralBytes(e.nextMacroID(), value) -} - -// LiteralDouble implements the ExprHelper interface method. -func (e *exprHelper) LiteralDouble(value float64) *exprpb.Expr { - return e.parserHelper.newLiteralDouble(e.nextMacroID(), value) -} - -// LiteralInt implements the ExprHelper interface method. -func (e *exprHelper) LiteralInt(value int64) *exprpb.Expr { - return e.parserHelper.newLiteralInt(e.nextMacroID(), value) + return e.exprFactory.NewUnspecifiedExpr(copyID) } -// LiteralString implements the ExprHelper interface method. 
-func (e *exprHelper) LiteralString(value string) *exprpb.Expr { - return e.parserHelper.newLiteralString(e.nextMacroID(), value) -} - -// LiteralUint implements the ExprHelper interface method. -func (e *exprHelper) LiteralUint(value uint64) *exprpb.Expr { - return e.parserHelper.newLiteralUint(e.nextMacroID(), value) +// NewLiteral implements the ExprHelper interface method. +func (e *exprHelper) NewLiteral(value ref.Val) ast.Expr { + return e.exprFactory.NewLiteral(e.nextMacroID(), value) } // NewList implements the ExprHelper interface method. -func (e *exprHelper) NewList(elems ...*exprpb.Expr) *exprpb.Expr { - return e.parserHelper.newList(e.nextMacroID(), elems) +func (e *exprHelper) NewList(elems ...ast.Expr) ast.Expr { + return e.exprFactory.NewList(e.nextMacroID(), elems, []int32{}) } // NewMap implements the ExprHelper interface method. -func (e *exprHelper) NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { - return e.parserHelper.newMap(e.nextMacroID(), entries...) +func (e *exprHelper) NewMap(entries ...ast.EntryExpr) ast.Expr { + return e.exprFactory.NewMap(e.nextMacroID(), entries) } // NewMapEntry implements the ExprHelper interface method. -func (e *exprHelper) NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { - return e.parserHelper.newMapEntry(e.nextMacroID(), key, val, optional) +func (e *exprHelper) NewMapEntry(key ast.Expr, val ast.Expr, optional bool) ast.EntryExpr { + return e.exprFactory.NewMapEntry(e.nextMacroID(), key, val, optional) } -// NewObject implements the ExprHelper interface method. -func (e *exprHelper) NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr { - return e.parserHelper.newObject(e.nextMacroID(), typeName, fieldInits...) +// NewStruct implements the ExprHelper interface method. 
+func (e *exprHelper) NewStruct(typeName string, fieldInits ...ast.EntryExpr) ast.Expr { + return e.exprFactory.NewStruct(e.nextMacroID(), typeName, fieldInits) } -// NewObjectFieldInit implements the ExprHelper interface method. -func (e *exprHelper) NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry { - return e.parserHelper.newObjectField(e.nextMacroID(), field, init, optional) +// NewStructField implements the ExprHelper interface method. +func (e *exprHelper) NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr { + return e.exprFactory.NewStructField(e.nextMacroID(), field, init, optional) } -// Fold implements the ExprHelper interface method. -func (e *exprHelper) Fold(iterVar string, - iterRange *exprpb.Expr, +// NewComprehension implements the ExprHelper interface method. +func (e *exprHelper) NewComprehension( + iterRange ast.Expr, + iterVar string, accuVar string, - accuInit *exprpb.Expr, - condition *exprpb.Expr, - step *exprpb.Expr, - result *exprpb.Expr) *exprpb.Expr { - return e.parserHelper.newComprehension( - e.nextMacroID(), iterVar, iterRange, accuVar, accuInit, condition, step, result) + accuInit ast.Expr, + condition ast.Expr, + step ast.Expr, + result ast.Expr) ast.Expr { + return e.exprFactory.NewComprehension( + e.nextMacroID(), iterRange, iterVar, accuVar, accuInit, condition, step, result) } -// Ident implements the ExprHelper interface method. -func (e *exprHelper) Ident(name string) *exprpb.Expr { - return e.parserHelper.newIdent(e.nextMacroID(), name) +// NewIdent implements the ExprHelper interface method. +func (e *exprHelper) NewIdent(name string) ast.Expr { + return e.exprFactory.NewIdent(e.nextMacroID(), name) } -// AccuIdent implements the ExprHelper interface method. -func (e *exprHelper) AccuIdent() *exprpb.Expr { - return e.parserHelper.newIdent(e.nextMacroID(), AccumulatorName) +// NewAccuIdent implements the ExprHelper interface method. 
+func (e *exprHelper) NewAccuIdent() ast.Expr { + return e.exprFactory.NewAccuIdent(e.nextMacroID()) } -// GlobalCall implements the ExprHelper interface method. -func (e *exprHelper) GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr { - return e.parserHelper.newGlobalCall(e.nextMacroID(), function, args...) +// NewGlobalCall implements the ExprHelper interface method. +func (e *exprHelper) NewCall(function string, args ...ast.Expr) ast.Expr { + return e.exprFactory.NewCall(e.nextMacroID(), function, args...) } -// ReceiverCall implements the ExprHelper interface method. -func (e *exprHelper) ReceiverCall(function string, - target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr { - return e.parserHelper.newReceiverCall(e.nextMacroID(), function, target, args...) +// NewMemberCall implements the ExprHelper interface method. +func (e *exprHelper) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return e.exprFactory.NewMemberCall(e.nextMacroID(), function, target, args...) } -// PresenceTest implements the ExprHelper interface method. -func (e *exprHelper) PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr { - return e.parserHelper.newPresenceTest(e.nextMacroID(), operand, field) +// NewPresenceTest implements the ExprHelper interface method. +func (e *exprHelper) NewPresenceTest(operand ast.Expr, field string) ast.Expr { + return e.exprFactory.NewPresenceTest(e.nextMacroID(), operand, field) } -// Select implements the ExprHelper interface method. -func (e *exprHelper) Select(operand *exprpb.Expr, field string) *exprpb.Expr { - return e.parserHelper.newSelect(e.nextMacroID(), operand, field) +// NewSelect implements the ExprHelper interface method. +func (e *exprHelper) NewSelect(operand ast.Expr, field string) ast.Expr { + return e.exprFactory.NewSelect(e.nextMacroID(), operand, field) } // OffsetLocation implements the ExprHelper interface method. 
func (e *exprHelper) OffsetLocation(exprID int64) common.Location { - offset, found := e.parserHelper.positions[exprID] - if !found { - return common.NoLocation - } - location, found := e.parserHelper.source.OffsetLocation(offset) - if !found { - return common.NoLocation - } - return location + return e.parserHelper.sourceInfo.GetStartLocation(exprID) } // NewError associates an error message with a given expression id, populating the source offset location of the error if possible. diff --git a/parser/helper_test.go b/parser/helper_test.go index f57c851c2..d6602bc7b 100644 --- a/parser/helper_test.go +++ b/parser/helper_test.go @@ -21,8 +21,6 @@ import ( "github.com/google/cel-go/common/ast" "google.golang.org/protobuf/proto" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) func TestExprHelperCopy(t *testing.T) { @@ -34,9 +32,7 @@ func TestExprHelperCopy(t *testing.T) { Macros( MapMacro, NewGlobalMacro("noop", 1, - func(eh ExprHelper, - target *exprpb.Expr, - args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { + func(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { return eh.Copy(args[0]), nil }, ), diff --git a/parser/macro.go b/parser/macro.go index 6066e8ef4..5b1775bed 100644 --- a/parser/macro.go +++ b/parser/macro.go @@ -18,9 +18,10 @@ import ( "fmt" "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" ) // NewGlobalMacro creates a Macro for a global function with the specified arg count. @@ -142,58 +143,38 @@ func makeVarArgMacroKey(name string, receiverStyle bool) string { // and produces as output an Expr ast node. // // Note: when the Macro.IsReceiverStyle() method returns true, the target argument will be nil. 
-type MacroExpander func(eh ExprHelper, - target *exprpb.Expr, - args []*exprpb.Expr) (*exprpb.Expr, *common.Error) +type MacroExpander func(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) -// ExprHelper assists with the manipulation of proto-based Expr values in a manner which is -// consistent with the source position and expression id generation code leveraged by both -// the parser and type-checker. +// ExprHelper assists with the creation of Expr values in a manner which is consistent +// the internal semantics and id generation behaviors of the parser and checker libraries. type ExprHelper interface { // Copy the input expression with a brand new set of identifiers. - Copy(*exprpb.Expr) *exprpb.Expr - - // LiteralBool creates an Expr value for a bool literal. - LiteralBool(value bool) *exprpb.Expr - - // LiteralBytes creates an Expr value for a byte literal. - LiteralBytes(value []byte) *exprpb.Expr - - // LiteralDouble creates an Expr value for double literal. - LiteralDouble(value float64) *exprpb.Expr + Copy(ast.Expr) ast.Expr - // LiteralInt creates an Expr value for an int literal. - LiteralInt(value int64) *exprpb.Expr + // Literal creates an Expr value for a scalar literal value. + NewLiteral(value ref.Val) ast.Expr - // LiteralString creates am Expr value for a string literal. - LiteralString(value string) *exprpb.Expr - - // LiteralUint creates an Expr value for a uint literal. - LiteralUint(value uint64) *exprpb.Expr - - // NewList creates a CreateList instruction where the list is comprised of the optional set - // of elements provided as arguments. - NewList(elems ...*exprpb.Expr) *exprpb.Expr + // NewList creates a list literal instruction with an optional set of elements. + NewList(elems ...ast.Expr) ast.Expr // NewMap creates a CreateStruct instruction for a map where the map is comprised of the // optional set of key, value entries. 
- NewMap(entries ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + NewMap(entries ...ast.EntryExpr) ast.Expr // NewMapEntry creates a Map Entry for the key, value pair. - NewMapEntry(key *exprpb.Expr, val *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + NewMapEntry(key ast.Expr, val ast.Expr, optional bool) ast.EntryExpr - // NewObject creates a CreateStruct instruction for an object with a given type name and - // optional set of field initializers. - NewObject(typeName string, fieldInits ...*exprpb.Expr_CreateStruct_Entry) *exprpb.Expr + // NewStruct creates a struct literal expression with an optional set of field initializers. + NewStruct(typeName string, fieldInits ...ast.EntryExpr) ast.Expr - // NewObjectFieldInit creates a new Object field initializer from the field name and value. - NewObjectFieldInit(field string, init *exprpb.Expr, optional bool) *exprpb.Expr_CreateStruct_Entry + // NewStructField creates a new struct field initializer from the field name and value. + NewStructField(field string, init ast.Expr, optional bool) ast.EntryExpr - // Fold creates a fold comprehension instruction. + // NewComprehension creates a new comprehension instruction. // - // - iterVar is the iteration variable name. // - iterRange represents the expression that resolves to a list or map where the elements or // keys (respectively) will be iterated over. + // - iterVar is the iteration variable name. // - accuVar is the accumulation variable name, typically parser.AccumulatorName. // - accuInit is the initial expression whose value will be set for the accuVar prior to // folding. @@ -204,31 +185,31 @@ type ExprHelper interface { // The accuVar should not shadow variable names that you would like to reference within the // environment in the step and condition expressions. Presently, the name __result__ is commonly // used by built-in macros but this may change in the future. 
- Fold(iterVar string, - iterRange *exprpb.Expr, + NewComprehension(iterRange ast.Expr, + iterVar string, accuVar string, - accuInit *exprpb.Expr, - condition *exprpb.Expr, - step *exprpb.Expr, - result *exprpb.Expr) *exprpb.Expr + accuInit ast.Expr, + condition ast.Expr, + step ast.Expr, + result ast.Expr) ast.Expr - // Ident creates an identifier Expr value. - Ident(name string) *exprpb.Expr + // NewIdent creates an identifier Expr value. + NewIdent(name string) ast.Expr - // AccuIdent returns an accumulator identifier for use with comprehension results. - AccuIdent() *exprpb.Expr + // NewAccuIdent returns an accumulator identifier for use with comprehension results. + NewAccuIdent() ast.Expr - // GlobalCall creates a function call Expr value for a global (free) function. - GlobalCall(function string, args ...*exprpb.Expr) *exprpb.Expr + // NewCall creates a function call Expr value for a global (free) function. + NewCall(function string, args ...ast.Expr) ast.Expr - // ReceiverCall creates a function call Expr value for a receiver-style function. - ReceiverCall(function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr + // NewMemberCall creates a function call Expr value for a receiver-style function. + NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr - // PresenceTest creates a Select TestOnly Expr value for modelling has() semantics. - PresenceTest(operand *exprpb.Expr, field string) *exprpb.Expr + // NewPresenceTest creates a Select TestOnly Expr value for modelling has() semantics. + NewPresenceTest(operand ast.Expr, field string) ast.Expr - // Select create a field traversal Expr value. - Select(operand *exprpb.Expr, field string) *exprpb.Expr + // NewSelect create a field traversal Expr value. + NewSelect(operand ast.Expr, field string) ast.Expr // OffsetLocation returns the Location of the expression identifier. 
OffsetLocation(exprID int64) common.Location @@ -296,21 +277,21 @@ const ( // MakeAll expands the input call arguments into a comprehension that returns true if all of the // elements in the range match the predicate expressions: // .all(, ) -func MakeAll(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeAll(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { return makeQuantifier(quantifierAll, eh, target, args) } // MakeExists expands the input call arguments into a comprehension that returns true if any of the // elements in the range match the predicate expressions: // .exists(, ) -func MakeExists(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeExists(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { return makeQuantifier(quantifierExists, eh, target, args) } // MakeExistsOne expands the input call arguments into a comprehension that returns true if exactly // one of the elements in the range match the predicate expressions: // .exists_one(, ) -func MakeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeExistsOne(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { return makeQuantifier(quantifierExistsOne, eh, target, args) } @@ -324,14 +305,14 @@ func MakeExistsOne(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*ex // // In the second form only iterVar values which return true when provided to the predicate expression // are transformed. 
-func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeMap(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - return nil, eh.NewError(args[0].GetId(), "argument is not an identifier") + return nil, eh.NewError(args[0].ID(), "argument is not an identifier") } - var fn *exprpb.Expr - var filter *exprpb.Expr + var fn ast.Expr + var filter ast.Expr if len(args) == 3 { filter = args[1] @@ -341,84 +322,85 @@ func MakeMap(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.E fn = args[1] } - accuExpr := eh.Ident(AccumulatorName) + accuExpr := eh.NewAccuIdent() init := eh.NewList() - condition := eh.LiteralBool(true) - step := eh.GlobalCall(operators.Add, accuExpr, eh.NewList(fn)) + condition := eh.NewLiteral(types.True) + step := eh.NewCall(operators.Add, accuExpr, eh.NewList(fn)) if filter != nil { - step = eh.GlobalCall(operators.Conditional, filter, step, accuExpr) + step = eh.NewCall(operators.Conditional, filter, step, accuExpr) } - return eh.Fold(v, target, AccumulatorName, init, condition, step, accuExpr), nil + return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, accuExpr), nil } // MakeFilter expands the input call arguments into a comprehension which produces a list which contains // only elements which match the provided predicate expression: // .filter(, ) -func MakeFilter(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func MakeFilter(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - return nil, eh.NewError(args[0].GetId(), "argument is not an identifier") + return nil, eh.NewError(args[0].ID(), "argument is not an identifier") } filter := args[1] - accuExpr := eh.Ident(AccumulatorName) + accuExpr := eh.NewAccuIdent() init := eh.NewList() - condition := eh.LiteralBool(true) - 
step := eh.GlobalCall(operators.Add, accuExpr, eh.NewList(args[0])) - step = eh.GlobalCall(operators.Conditional, filter, step, accuExpr) - return eh.Fold(v, target, AccumulatorName, init, condition, step, accuExpr), nil + condition := eh.NewLiteral(types.True) + step := eh.NewCall(operators.Add, accuExpr, eh.NewList(args[0])) + step = eh.NewCall(operators.Conditional, filter, step, accuExpr) + return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, accuExpr), nil } // MakeHas expands the input call arguments into a presence test, e.g. has(.field) -func MakeHas(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { - if s, ok := args[0].ExprKind.(*exprpb.Expr_SelectExpr); ok { - return eh.PresenceTest(s.SelectExpr.GetOperand(), s.SelectExpr.GetField()), nil +func MakeHas(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { + if args[0].Kind() == ast.SelectKind { + s := args[0].AsSelect() + return eh.NewPresenceTest(s.Operand(), s.FieldName()), nil } - return nil, eh.NewError(args[0].GetId(), "invalid argument to has() macro") + return nil, eh.NewError(args[0].ID(), "invalid argument to has() macro") } -func makeQuantifier(kind quantifierKind, eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { +func makeQuantifier(kind quantifierKind, eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { v, found := extractIdent(args[0]) if !found { - return nil, eh.NewError(args[0].GetId(), "argument must be a simple name") + return nil, eh.NewError(args[0].ID(), "argument must be a simple name") } - var init *exprpb.Expr - var condition *exprpb.Expr - var step *exprpb.Expr - var result *exprpb.Expr + var init ast.Expr + var condition ast.Expr + var step ast.Expr + var result ast.Expr switch kind { case quantifierAll: - init = eh.LiteralBool(true) - condition = eh.GlobalCall(operators.NotStrictlyFalse, eh.AccuIdent()) - step = 
eh.GlobalCall(operators.LogicalAnd, eh.AccuIdent(), args[1]) - result = eh.AccuIdent() + init = eh.NewLiteral(types.True) + condition = eh.NewCall(operators.NotStrictlyFalse, eh.NewAccuIdent()) + step = eh.NewCall(operators.LogicalAnd, eh.NewAccuIdent(), args[1]) + result = eh.NewAccuIdent() case quantifierExists: - init = eh.LiteralBool(false) - condition = eh.GlobalCall( + init = eh.NewLiteral(types.False) + condition = eh.NewCall( operators.NotStrictlyFalse, - eh.GlobalCall(operators.LogicalNot, eh.AccuIdent())) - step = eh.GlobalCall(operators.LogicalOr, eh.AccuIdent(), args[1]) - result = eh.AccuIdent() + eh.NewCall(operators.LogicalNot, eh.NewAccuIdent())) + step = eh.NewCall(operators.LogicalOr, eh.NewAccuIdent(), args[1]) + result = eh.NewAccuIdent() case quantifierExistsOne: - zeroExpr := eh.LiteralInt(0) - oneExpr := eh.LiteralInt(1) + zeroExpr := eh.NewLiteral(types.Int(0)) + oneExpr := eh.NewLiteral(types.Int(1)) init = zeroExpr - condition = eh.LiteralBool(true) - step = eh.GlobalCall(operators.Conditional, args[1], - eh.GlobalCall(operators.Add, eh.AccuIdent(), oneExpr), eh.AccuIdent()) - result = eh.GlobalCall(operators.Equals, eh.AccuIdent(), oneExpr) + condition = eh.NewLiteral(types.True) + step = eh.NewCall(operators.Conditional, args[1], + eh.NewCall(operators.Add, eh.NewAccuIdent(), oneExpr), eh.NewAccuIdent()) + result = eh.NewCall(operators.Equals, eh.NewAccuIdent(), oneExpr) default: - return nil, eh.NewError(args[0].GetId(), fmt.Sprintf("unrecognized quantifier '%v'", kind)) + return nil, eh.NewError(args[0].ID(), fmt.Sprintf("unrecognized quantifier '%v'", kind)) } - return eh.Fold(v, target, AccumulatorName, init, condition, step, result), nil + return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, result), nil } -func extractIdent(e *exprpb.Expr) (string, bool) { - switch e.ExprKind.(type) { - case *exprpb.Expr_IdentExpr: - return e.GetIdentExpr().GetName(), true +func extractIdent(e ast.Expr) (string, bool) { + 
switch e.Kind() { + case ast.IdentKind: + return e.AsIdent(), true } return "", false } diff --git a/parser/parser.go b/parser/parser.go index 292b64cb6..c408bd355 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -29,10 +29,8 @@ import ( "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/runes" + "github.com/google/cel-go/common/types" "github.com/google/cel-go/parser/gen" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - structpb "google.golang.org/protobuf/types/known/structpb" ) // Parser encapsulates the context necessary to perform parsing for different expressions. @@ -91,9 +89,11 @@ func mustNewParser(opts ...Option) *Parser { // Parse parses the expression represented by source and returns the result. func (p *Parser) Parse(source common.Source) (*ast.AST, *common.Errors) { errs := common.NewErrors(source) + fac := ast.NewExprFactory() impl := parser{ errors: &parseErrors{errs}, - helper: newParserHelper(source), + exprFactory: fac, + helper: newParserHelper(source, fac), macros: p.macros, maxRecursionDepth: p.maxRecursionDepth, errorReportingLimit: p.errorReportingLimit, @@ -107,25 +107,15 @@ func (p *Parser) Parse(source common.Source) (*ast.AST, *common.Errors) { if !ok { buf = runes.NewBuffer(source.Content()) } - var e *exprpb.Expr + var out ast.Expr if buf.Len() > p.expressionSizeCodePointLimit { - e = impl.reportError(common.NoLocation, + out = impl.reportError(common.NoLocation, "expression code point size exceeds limit: size: %d, limit %d", buf.Len(), p.expressionSizeCodePointLimit) } else { - e = impl.parse(buf, source.Description()) - } - sourceInfo, err := ast.ProtoToSourceInfo(impl.helper.getSourceInfo()) - if err != nil { - impl.reportError(common.NoLocation, - "failed to convert proto source info: %v", impl.helper.getSourceInfo()) - } - out, err := ast.ProtoToExpr(e) - if err != nil { - impl.reportError(common.NoLocation, - "failed to convert proto 
expression: %v", e) + out = impl.parse(buf, source.Description()) } - return ast.NewAST(out, sourceInfo), errs + return ast.NewAST(out, impl.helper.getSourceInfo()), errs } // reservedIds are not legal to use as variables. We exclude them post-parse, as they *are* valid @@ -295,6 +285,7 @@ var _ antlr.ErrorStrategy = &recoveryLimitErrorStrategy{} type parser struct { gen.BaseCELVisitor errors *parseErrors + exprFactory ast.ExprFactory helper *parserHelper macros map[string]Macro recursionDepth int @@ -328,7 +319,7 @@ var ( } ) -func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr { +func (p *parser) parse(expr runes.Buffer, desc string) ast.Expr { // TODO: get rid of these pools once https://github.com/antlr/antlr4/pull/3571 is in a release lexer := lexerPool.Get().(*gen.CELLexer) prsr := parserPool.Get().(*gen.CELParser) @@ -381,7 +372,7 @@ func (p *parser) parse(expr runes.Buffer, desc string) *exprpb.Expr { } }() - return p.Visit(prsr.Start()).(*exprpb.Expr) + return p.Visit(prsr.Start()).(ast.Expr) } // Visitor implementations. @@ -478,26 +469,26 @@ func (p *parser) VisitStart(ctx *gen.StartContext) any { // Visit a parse tree produced by CELParser#expr. func (p *parser) VisitExpr(ctx *gen.ExprContext) any { - result := p.Visit(ctx.GetE()).(*exprpb.Expr) + result := p.Visit(ctx.GetE()).(ast.Expr) if ctx.GetOp() == nil { return result } opID := p.helper.id(ctx.GetOp()) - ifTrue := p.Visit(ctx.GetE1()).(*exprpb.Expr) - ifFalse := p.Visit(ctx.GetE2()).(*exprpb.Expr) + ifTrue := p.Visit(ctx.GetE1()).(ast.Expr) + ifFalse := p.Visit(ctx.GetE2()).(ast.Expr) return p.globalCallOrMacro(opID, operators.Conditional, result, ifTrue, ifFalse) } // Visit a parse tree produced by CELParser#conditionalOr. 
func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any { - result := p.Visit(ctx.GetE()).(*exprpb.Expr) + result := p.Visit(ctx.GetE()).(ast.Expr) l := p.newLogicManager(operators.LogicalOr, result) rest := ctx.GetE1() for i, op := range ctx.GetOps() { if i >= len(rest) { return p.reportError(ctx, "unexpected character, wanted '||'") } - next := p.Visit(rest[i]).(*exprpb.Expr) + next := p.Visit(rest[i]).(ast.Expr) opID := p.helper.id(op) l.addTerm(opID, next) } @@ -506,14 +497,14 @@ func (p *parser) VisitConditionalOr(ctx *gen.ConditionalOrContext) any { // Visit a parse tree produced by CELParser#conditionalAnd. func (p *parser) VisitConditionalAnd(ctx *gen.ConditionalAndContext) any { - result := p.Visit(ctx.GetE()).(*exprpb.Expr) + result := p.Visit(ctx.GetE()).(ast.Expr) l := p.newLogicManager(operators.LogicalAnd, result) rest := ctx.GetE1() for i, op := range ctx.GetOps() { if i >= len(rest) { return p.reportError(ctx, "unexpected character, wanted '&&'") } - next := p.Visit(rest[i]).(*exprpb.Expr) + next := p.Visit(rest[i]).(ast.Expr) opID := p.helper.id(op) l.addTerm(opID, next) } @@ -527,9 +518,9 @@ func (p *parser) VisitRelation(ctx *gen.RelationContext) any { opText = ctx.GetOp().GetText() } if op, found := operators.Find(opText); found { - lhs := p.Visit(ctx.Relation(0)).(*exprpb.Expr) + lhs := p.Visit(ctx.Relation(0)).(ast.Expr) opID := p.helper.id(ctx.GetOp()) - rhs := p.Visit(ctx.Relation(1)).(*exprpb.Expr) + rhs := p.Visit(ctx.Relation(1)).(ast.Expr) return p.globalCallOrMacro(opID, op, lhs, rhs) } return p.reportError(ctx, "operator not found") @@ -542,9 +533,9 @@ func (p *parser) VisitCalc(ctx *gen.CalcContext) any { opText = ctx.GetOp().GetText() } if op, found := operators.Find(opText); found { - lhs := p.Visit(ctx.Calc(0)).(*exprpb.Expr) + lhs := p.Visit(ctx.Calc(0)).(ast.Expr) opID := p.helper.id(ctx.GetOp()) - rhs := p.Visit(ctx.Calc(1)).(*exprpb.Expr) + rhs := p.Visit(ctx.Calc(1)).(ast.Expr) return p.globalCallOrMacro(opID, 
op, lhs, rhs) } return p.reportError(ctx, "operator not found") @@ -560,7 +551,7 @@ func (p *parser) VisitLogicalNot(ctx *gen.LogicalNotContext) any { return p.Visit(ctx.Member()) } opID := p.helper.id(ctx.GetOps()[0]) - target := p.Visit(ctx.Member()).(*exprpb.Expr) + target := p.Visit(ctx.Member()).(ast.Expr) return p.globalCallOrMacro(opID, operators.LogicalNot, target) } @@ -569,13 +560,13 @@ func (p *parser) VisitNegate(ctx *gen.NegateContext) any { return p.Visit(ctx.Member()) } opID := p.helper.id(ctx.GetOps()[0]) - target := p.Visit(ctx.Member()).(*exprpb.Expr) + target := p.Visit(ctx.Member()).(ast.Expr) return p.globalCallOrMacro(opID, operators.Negate, target) } // VisitSelect visits a parse tree produced by CELParser#Select. func (p *parser) VisitSelect(ctx *gen.SelectContext) any { - operand := p.Visit(ctx.Member()).(*exprpb.Expr) + operand := p.Visit(ctx.Member()).(ast.Expr) // Handle the error case where no valid identifier is specified. if ctx.GetId() == nil || ctx.GetOp() == nil { return p.helper.newExpr(ctx) @@ -596,7 +587,7 @@ func (p *parser) VisitSelect(ctx *gen.SelectContext) any { // VisitMemberCall visits a parse tree produced by CELParser#MemberCall. func (p *parser) VisitMemberCall(ctx *gen.MemberCallContext) any { - operand := p.Visit(ctx.Member()).(*exprpb.Expr) + operand := p.Visit(ctx.Member()).(ast.Expr) // Handle the error case where no valid identifier is specified. if ctx.GetId() == nil { return p.helper.newExpr(ctx) @@ -608,13 +599,13 @@ func (p *parser) VisitMemberCall(ctx *gen.MemberCallContext) any { // Visit a parse tree produced by CELParser#Index. func (p *parser) VisitIndex(ctx *gen.IndexContext) any { - target := p.Visit(ctx.Member()).(*exprpb.Expr) + target := p.Visit(ctx.Member()).(ast.Expr) // Handle the error case where no valid identifier is specified. 
if ctx.GetOp() == nil { return p.helper.newExpr(ctx) } opID := p.helper.id(ctx.GetOp()) - index := p.Visit(ctx.GetIndex()).(*exprpb.Expr) + index := p.Visit(ctx.GetIndex()).(ast.Expr) operator := operators.Index if ctx.GetOpt() != nil { if !p.enableOptionalSyntax { @@ -638,7 +629,7 @@ func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) any { messageName = "." + messageName } objID := p.helper.id(ctx.GetOp()) - entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry) + entries := p.VisitIFieldInitializerList(ctx.GetEntries()).([]ast.EntryExpr) return p.helper.newObject(objID, messageName, entries...) } @@ -646,16 +637,16 @@ func (p *parser) VisitCreateMessage(ctx *gen.CreateMessageContext) any { func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext) any { if ctx == nil || ctx.GetFields() == nil { // This is the result of a syntax error handled elswhere, return empty. - return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } - result := make([]*exprpb.Expr_CreateStruct_Entry, len(ctx.GetFields())) + result := make([]ast.EntryExpr, len(ctx.GetFields())) cols := ctx.GetCols() vals := ctx.GetValues() for i, f := range ctx.GetFields() { if i >= len(cols) || i >= len(vals) { // This is the result of a syntax error detected elsewhere. - return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } initID := p.helper.id(cols[i]) optField := f.(*gen.OptFieldContext) @@ -667,10 +658,10 @@ func (p *parser) VisitIFieldInitializerList(ctx gen.IFieldInitializerListContext // The field may be empty due to a prior error. 
id := optField.IDENTIFIER() if id == nil { - return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } fieldName := id.GetText() - value := p.Visit(vals[i]).(*exprpb.Expr) + value := p.Visit(vals[i]).(ast.Expr) field := p.helper.newObjectField(initID, fieldName, value, optional) result[i] = field } @@ -710,9 +701,9 @@ func (p *parser) VisitCreateList(ctx *gen.CreateListContext) any { // Visit a parse tree produced by CELParser#CreateStruct. func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) any { structID := p.helper.id(ctx.GetOp()) - entries := []*exprpb.Expr_CreateStruct_Entry{} + entries := []ast.EntryExpr{} if ctx.GetEntries() != nil { - entries = p.Visit(ctx.GetEntries()).([]*exprpb.Expr_CreateStruct_Entry) + entries = p.Visit(ctx.GetEntries()).([]ast.EntryExpr) } return p.helper.newMap(structID, entries...) } @@ -721,17 +712,17 @@ func (p *parser) VisitCreateStruct(ctx *gen.CreateStructContext) any { func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) any { if ctx == nil || ctx.GetKeys() == nil { // This is the result of a syntax error handled elswhere, return empty. - return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } - result := make([]*exprpb.Expr_CreateStruct_Entry, len(ctx.GetCols())) + result := make([]ast.EntryExpr, len(ctx.GetCols())) keys := ctx.GetKeys() vals := ctx.GetValues() for i, col := range ctx.GetCols() { colID := p.helper.id(col) if i >= len(keys) || i >= len(vals) { // This is the result of a syntax error detected elsewhere. 
- return []*exprpb.Expr_CreateStruct_Entry{} + return []ast.EntryExpr{} } optKey := keys[i] optional := optKey.GetOpt() != nil @@ -739,8 +730,8 @@ func (p *parser) VisitMapInitializerList(ctx *gen.MapInitializerListContext) any p.reportError(optKey, "unsupported syntax '?'") continue } - key := p.Visit(optKey.GetE()).(*exprpb.Expr) - value := p.Visit(vals[i]).(*exprpb.Expr) + key := p.Visit(optKey.GetE()).(ast.Expr) + value := p.Visit(vals[i]).(ast.Expr) entry := p.helper.newMapEntry(colID, key, value, optional) result[i] = entry } @@ -820,30 +811,27 @@ func (p *parser) VisitBoolFalse(ctx *gen.BoolFalseContext) any { // Visit a parse tree produced by CELParser#Null. func (p *parser) VisitNull(ctx *gen.NullContext) any { - return p.helper.newLiteral(ctx, - &exprpb.Constant{ - ConstantKind: &exprpb.Constant_NullValue{ - NullValue: structpb.NullValue_NULL_VALUE}}) + return p.helper.exprFactory.NewLiteral(p.helper.newID(ctx), types.NullValue) } -func (p *parser) visitExprList(ctx gen.IExprListContext) []*exprpb.Expr { +func (p *parser) visitExprList(ctx gen.IExprListContext) []ast.Expr { if ctx == nil { - return []*exprpb.Expr{} + return []ast.Expr{} } return p.visitSlice(ctx.GetE()) } -func (p *parser) visitListInit(ctx gen.IListInitContext) ([]*exprpb.Expr, []int32) { +func (p *parser) visitListInit(ctx gen.IListInitContext) ([]ast.Expr, []int32) { if ctx == nil { - return []*exprpb.Expr{}, []int32{} + return []ast.Expr{}, []int32{} } elements := ctx.GetElems() - result := make([]*exprpb.Expr, len(elements)) + result := make([]ast.Expr, len(elements)) optionals := []int32{} for i, e := range elements { - ex := p.Visit(e.GetE()).(*exprpb.Expr) + ex := p.Visit(e.GetE()).(ast.Expr) if ex == nil { - return []*exprpb.Expr{}, []int32{} + return []ast.Expr{}, []int32{} } result[i] = ex if e.GetOpt() != nil { @@ -857,13 +845,13 @@ func (p *parser) visitListInit(ctx gen.IListInitContext) ([]*exprpb.Expr, []int3 return result, optionals } -func (p *parser) 
visitSlice(expressions []gen.IExprContext) []*exprpb.Expr { +func (p *parser) visitSlice(expressions []gen.IExprContext) []ast.Expr { if expressions == nil { - return []*exprpb.Expr{} + return []ast.Expr{} } - result := make([]*exprpb.Expr, len(expressions)) + result := make([]ast.Expr, len(expressions)) for i, e := range expressions { - ex := p.Visit(e).(*exprpb.Expr) + ex := p.Visit(e).(ast.Expr) result[i] = ex } return result @@ -878,24 +866,24 @@ func (p *parser) unquote(ctx any, value string, isBytes bool) string { return text } -func (p *parser) newLogicManager(function string, term *exprpb.Expr) *logicManager { +func (p *parser) newLogicManager(function string, term ast.Expr) *logicManager { if p.enableVariadicOperatorASTs { - return newVariadicLogicManager(p.helper, function, term) + return newVariadicLogicManager(p.exprFactory, function, term) } - return newBalancingLogicManager(p.helper, function, term) + return newBalancingLogicManager(p.exprFactory, function, term) } -func (p *parser) reportError(ctx any, format string, args ...any) *exprpb.Expr { +func (p *parser) reportError(ctx any, format string, args ...any) ast.Expr { var location common.Location err := p.helper.newExpr(ctx) switch c := ctx.(type) { case common.Location: location = c case antlr.Token, antlr.ParserRuleContext: - location = p.helper.getLocation(err.GetId()) + location = p.helper.getLocation(err.ID()) } // Provide arguments to the report error. - p.errors.reportErrorAtID(err.GetId(), location, format, args...) + p.errors.reportErrorAtID(err.ID(), location, format, args...) 
return err } @@ -932,21 +920,21 @@ func (p *parser) ReportContextSensitivity(recognizer antlr.Parser, dfa *antlr.DF // Intentional } -func (p *parser) globalCallOrMacro(exprID int64, function string, args ...*exprpb.Expr) *exprpb.Expr { +func (p *parser) globalCallOrMacro(exprID int64, function string, args ...ast.Expr) ast.Expr { if expr, found := p.expandMacro(exprID, function, nil, args...); found { return expr } return p.helper.newGlobalCall(exprID, function, args...) } -func (p *parser) receiverCallOrMacro(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) *exprpb.Expr { +func (p *parser) receiverCallOrMacro(exprID int64, function string, target ast.Expr, args ...ast.Expr) ast.Expr { if expr, found := p.expandMacro(exprID, function, target, args...); found { return expr } return p.helper.newReceiverCall(exprID, function, target, args...) } -func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr, args ...*exprpb.Expr) (*exprpb.Expr, bool) { +func (p *parser) expandMacro(exprID int64, function string, target ast.Expr, args ...ast.Expr) (ast.Expr, bool) { macro, found := p.macros[makeMacroKey(function, len(args), target != nil)] if !found { macro, found = p.macros[makeVarArgMacroKey(function, target != nil)] @@ -972,7 +960,7 @@ func (p *parser) expandMacro(exprID int64, function string, target *exprpb.Expr, return nil, false } if p.populateMacroCalls { - p.helper.addMacroCall(expr.GetId(), function, target, args...) + p.helper.addMacroCall(expr.ID(), function, target, args...) 
} return expr, true } diff --git a/parser/parser_test.go b/parser/parser_test.go index d34247d90..8b4964cb0 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -26,8 +26,6 @@ import ( "github.com/google/cel-go/common/debug" "github.com/google/cel-go/common/types" "github.com/google/cel-go/test" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) var testCases = []testInfo{ @@ -1708,7 +1706,7 @@ var testCases = []testInfo{ I: `noop_macro(123)`, Opts: []Option{ Macros(NewGlobalVarArgMacro("noop_macro", - func(eh ExprHelper, target *exprpb.Expr, args []*exprpb.Expr) (*exprpb.Expr, *common.Error) { + func(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common.Error) { return nil, nil })), }, From 04536922515e3de9de498f689b40f70d170b518b Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 18 Aug 2023 12:05:31 -0400 Subject: [PATCH 17/31] Migrate the checker.Coster to the ast.Expr (#798) Migrate Cost calculations to Go-native Expr. Note, this is a breaking changes as the type of the checker.AstNode has been modified to return a Go-native Expr value rather than its protobuf equivalent. --- checker/cost.go | 184 ++++++++++++++++++++++-------------------------- 1 file changed, 83 insertions(+), 101 deletions(-) diff --git a/checker/cost.go b/checker/cost.go index 06c57c1e9..b6109d918 100644 --- a/checker/cost.go +++ b/checker/cost.go @@ -22,8 +22,6 @@ import ( "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" "github.com/google/cel-go/parser" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" ) // WARNING: Any changes to cost calculations in this file require a corresponding change in interpreter/runtimecost.go @@ -58,7 +56,7 @@ type AstNode interface { // Type returns the deduced type of the AstNode. Type() *types.Type // Expr returns the expression of the AstNode. 
- Expr() *exprpb.Expr + Expr() ast.Expr // ComputedSize returns a size estimate of the AstNode derived from information available in the CEL expression. // For constants and inline list and map declarations, the exact size is returned. For concatenated list, strings // and bytes, the size is derived from the size estimates of the operands. nil is returned if there is no @@ -69,7 +67,7 @@ type AstNode interface { type astNode struct { path []string t *types.Type - expr *exprpb.Expr + expr ast.Expr derivedSize *SizeEstimate } @@ -81,7 +79,7 @@ func (e astNode) Type() *types.Type { return e.t } -func (e astNode) Expr() *exprpb.Expr { +func (e astNode) Expr() ast.Expr { return e.expr } @@ -90,29 +88,27 @@ func (e astNode) ComputedSize() *SizeEstimate { return e.derivedSize } var v uint64 - switch ek := e.expr.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: - switch ck := ek.ConstExpr.GetConstantKind().(type) { - case *exprpb.Constant_StringValue: + switch e.expr.Kind() { + case ast.LiteralKind: + switch ck := e.expr.AsLiteral().(type) { + case types.String: // converting to runes here is an O(n) operation, but // this is consistent with how size is computed at runtime, // and how the language definition defines string size - v = uint64(len([]rune(ck.StringValue))) - case *exprpb.Constant_BytesValue: - v = uint64(len(ck.BytesValue)) - case *exprpb.Constant_BoolValue, *exprpb.Constant_DoubleValue, *exprpb.Constant_DurationValue, - *exprpb.Constant_Int64Value, *exprpb.Constant_TimestampValue, *exprpb.Constant_Uint64Value, - *exprpb.Constant_NullValue: + v = uint64(len([]rune(ck))) + case types.Bytes: + v = uint64(len(ck)) + case types.Bool, types.Double, types.Duration, + types.Int, types.Timestamp, types.Uint, + types.Null: v = uint64(1) default: return nil } - case *exprpb.Expr_ListExpr: - v = uint64(len(ek.ListExpr.GetElements())) - case *exprpb.Expr_StructExpr: - if ek.StructExpr.GetMessageName() == "" { - v = uint64(len(ek.StructExpr.GetEntries())) - } + case 
ast.ListKind: + v = uint64(e.expr.AsList().Size()) + case ast.MapKind: + v = uint64(e.expr.AsMap().Size()) default: return nil } @@ -270,8 +266,8 @@ type coster struct { // Use a stack of iterVar -> iterRange Expr Ids to handle shadowed variable names. type iterRangeScopes map[string][]int64 -func (vs iterRangeScopes) push(varName string, expr *exprpb.Expr) { - vs[varName] = append(vs[varName], expr.GetId()) +func (vs iterRangeScopes) push(varName string, expr ast.Expr) { + vs[varName] = append(vs[varName], expr.ID()) } func (vs iterRangeScopes) pop(varName string) { @@ -319,32 +315,30 @@ func Cost(checked *ast.AST, estimator CostEstimator, opts ...CostOption) (CostEs return CostEstimate{}, err } } - epb, err := ast.ExprToProto(checked.Expr()) - if err != nil { - return CostEstimate{}, err - } - return c.cost(epb), nil + return c.cost(checked.Expr()), nil } -func (c *coster) cost(e *exprpb.Expr) CostEstimate { +func (c *coster) cost(e ast.Expr) CostEstimate { if e == nil { return CostEstimate{} } var cost CostEstimate - switch e.GetExprKind().(type) { - case *exprpb.Expr_ConstExpr: + switch e.Kind() { + case ast.LiteralKind: cost = constCost - case *exprpb.Expr_IdentExpr: + case ast.IdentKind: cost = c.costIdent(e) - case *exprpb.Expr_SelectExpr: + case ast.SelectKind: cost = c.costSelect(e) - case *exprpb.Expr_CallExpr: + case ast.CallKind: cost = c.costCall(e) - case *exprpb.Expr_ListExpr: + case ast.ListKind: cost = c.costCreateList(e) - case *exprpb.Expr_StructExpr: + case ast.MapKind: + cost = c.costCreateMap(e) + case ast.StructKind: cost = c.costCreateStruct(e) - case *exprpb.Expr_ComprehensionExpr: + case ast.ComprehensionKind: cost = c.costComprehension(e) default: return CostEstimate{} @@ -352,11 +346,10 @@ func (c *coster) cost(e *exprpb.Expr) CostEstimate { return cost } -func (c *coster) costIdent(e *exprpb.Expr) CostEstimate { - identExpr := e.GetIdentExpr() - +func (c *coster) costIdent(e ast.Expr) CostEstimate { + identName := e.AsIdent() // build 
and track the field path - if iterRange, ok := c.iterRanges.peek(identExpr.GetName()); ok { + if iterRange, ok := c.iterRanges.peek(identName); ok { switch c.checkedAST.GetType(iterRange).Kind() { case types.ListKind: c.addPath(e, append(c.exprPath[iterRange], "@items")) @@ -364,41 +357,40 @@ func (c *coster) costIdent(e *exprpb.Expr) CostEstimate { c.addPath(e, append(c.exprPath[iterRange], "@keys")) } } else { - c.addPath(e, []string{identExpr.GetName()}) + c.addPath(e, []string{identName}) } return selectAndIdentCost } -func (c *coster) costSelect(e *exprpb.Expr) CostEstimate { - sel := e.GetSelectExpr() +func (c *coster) costSelect(e ast.Expr) CostEstimate { + sel := e.AsSelect() var sum CostEstimate - if sel.GetTestOnly() { + if sel.IsTestOnly() { // recurse, but do not add any cost // this is equivalent to how evalTestOnly increments the runtime cost counter // but does not add any additional cost for the qualifier, except here we do // the reverse (ident adds cost) sum = sum.Add(c.presenceTestCost) - sum = sum.Add(c.cost(sel.GetOperand())) + sum = sum.Add(c.cost(sel.Operand())) return sum } - sum = sum.Add(c.cost(sel.GetOperand())) - targetType := c.getType(sel.GetOperand()) + sum = sum.Add(c.cost(sel.Operand())) + targetType := c.getType(sel.Operand()) switch targetType.Kind() { case types.MapKind, types.StructKind, types.TypeParamKind: sum = sum.Add(selectAndIdentCost) } // build and track the field path - c.addPath(e, append(c.getPath(sel.GetOperand()), sel.GetField())) + c.addPath(e, append(c.getPath(sel.Operand()), sel.FieldName())) return sum } -func (c *coster) costCall(e *exprpb.Expr) CostEstimate { - call := e.GetCallExpr() - target := call.GetTarget() - args := call.GetArgs() +func (c *coster) costCall(e ast.Expr) CostEstimate { + call := e.AsCall() + args := call.Args() var sum CostEstimate @@ -409,22 +401,20 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate { argTypes[i] = c.newAstNode(arg) } - overloadIDs := 
c.checkedAST.GetOverloadIDs(e.GetId()) + overloadIDs := c.checkedAST.GetOverloadIDs(e.ID()) if len(overloadIDs) == 0 { return CostEstimate{} } var targetType AstNode - if target != nil { - if call.Target != nil { - sum = sum.Add(c.cost(call.GetTarget())) - targetType = c.newAstNode(call.GetTarget()) - } + if call.IsMemberFunction() { + sum = sum.Add(c.cost(call.Target())) + targetType = c.newAstNode(call.Target()) } // Pick a cost estimate range that covers all the overload cost estimation ranges fnCost := CostEstimate{Min: uint64(math.MaxUint64), Max: 0} var resultSize *SizeEstimate for _, overload := range overloadIDs { - overloadCost := c.functionCost(call.GetFunction(), overload, &targetType, argTypes, argCosts) + overloadCost := c.functionCost(call.FunctionName(), overload, &targetType, argTypes, argCosts) fnCost = fnCost.Union(overloadCost.CostEstimate) if overloadCost.ResultSize != nil { if resultSize == nil { @@ -447,62 +437,54 @@ func (c *coster) costCall(e *exprpb.Expr) CostEstimate { } } if resultSize != nil { - c.computedSizes[e.GetId()] = *resultSize + c.computedSizes[e.ID()] = *resultSize } return sum.Add(fnCost) } -func (c *coster) costCreateList(e *exprpb.Expr) CostEstimate { - create := e.GetListExpr() +func (c *coster) costCreateList(e ast.Expr) CostEstimate { + create := e.AsList() var sum CostEstimate - for _, e := range create.GetElements() { + for _, e := range create.Elements() { sum = sum.Add(c.cost(e)) } return sum.Add(createListBaseCost) } -func (c *coster) costCreateStruct(e *exprpb.Expr) CostEstimate { - str := e.GetStructExpr() - if str.MessageName != "" { - return c.costCreateMessage(e) - } - return c.costCreateMap(e) -} - -func (c *coster) costCreateMap(e *exprpb.Expr) CostEstimate { - mapVal := e.GetStructExpr() +func (c *coster) costCreateMap(e ast.Expr) CostEstimate { + mapVal := e.AsMap() var sum CostEstimate - for _, ent := range mapVal.GetEntries() { - key := ent.GetMapKey() - sum = sum.Add(c.cost(key)) - - sum = 
sum.Add(c.cost(ent.GetValue())) + for _, ent := range mapVal.Entries() { + entry := ent.AsMapEntry() + sum = sum.Add(c.cost(entry.Key())) + sum = sum.Add(c.cost(entry.Value())) } return sum.Add(createMapBaseCost) } -func (c *coster) costCreateMessage(e *exprpb.Expr) CostEstimate { - msgVal := e.GetStructExpr() +func (c *coster) costCreateStruct(e ast.Expr) CostEstimate { + msgVal := e.AsStruct() var sum CostEstimate - for _, ent := range msgVal.GetEntries() { - sum = sum.Add(c.cost(ent.GetValue())) + for _, ent := range msgVal.Fields() { + field := ent.AsStructField() + sum = sum.Add(c.cost(field.Value())) } return sum.Add(createMessageBaseCost) } -func (c *coster) costComprehension(e *exprpb.Expr) CostEstimate { - comp := e.GetComprehensionExpr() +func (c *coster) costComprehension(e ast.Expr) CostEstimate { + comp := e.AsComprehension() var sum CostEstimate - sum = sum.Add(c.cost(comp.GetIterRange())) - sum = sum.Add(c.cost(comp.GetAccuInit())) + sum = sum.Add(c.cost(comp.IterRange())) + sum = sum.Add(c.cost(comp.AccuInit())) // Track the iterRange of each IterVar for field path construction - c.iterRanges.push(comp.GetIterVar(), comp.GetIterRange()) - loopCost := c.cost(comp.GetLoopCondition()) - stepCost := c.cost(comp.GetLoopStep()) - c.iterRanges.pop(comp.GetIterVar()) - sum = sum.Add(c.cost(comp.Result)) - rangeCnt := c.sizeEstimate(c.newAstNode(comp.GetIterRange())) + c.iterRanges.push(comp.IterVar(), comp.IterRange()) + loopCost := c.cost(comp.LoopCondition()) + stepCost := c.cost(comp.LoopStep()) + c.iterRanges.pop(comp.IterVar()) + sum = sum.Add(c.cost(comp.Result())) + rangeCnt := c.sizeEstimate(c.newAstNode(comp.IterRange())) rangeCost := rangeCnt.MultiplyByCost(stepCost.Add(loopCost)) sum = sum.Add(rangeCost) @@ -647,26 +629,26 @@ func (c *coster) functionCost(function, overloadID string, target *AstNode, args return CallEstimate{CostEstimate: CostEstimate{Min: 1, Max: 1}.Add(argCostSum())} } -func (c *coster) getType(e *exprpb.Expr) *types.Type { - 
return c.checkedAST.GetType(e.GetId()) +func (c *coster) getType(e ast.Expr) *types.Type { + return c.checkedAST.GetType(e.ID()) } -func (c *coster) getPath(e *exprpb.Expr) []string { - return c.exprPath[e.GetId()] +func (c *coster) getPath(e ast.Expr) []string { + return c.exprPath[e.ID()] } -func (c *coster) addPath(e *exprpb.Expr, path []string) { - c.exprPath[e.GetId()] = path +func (c *coster) addPath(e ast.Expr, path []string) { + c.exprPath[e.ID()] = path } -func (c *coster) newAstNode(e *exprpb.Expr) *astNode { +func (c *coster) newAstNode(e ast.Expr) *astNode { path := c.getPath(e) if len(path) > 0 && path[0] == parser.AccumulatorName { // only provide paths to root vars; omit accumulator vars path = nil } var derivedSize *SizeEstimate - if size, ok := c.computedSizes[e.GetId()]; ok { + if size, ok := c.computedSizes[e.ID()]; ok { derivedSize = &size } return &astNode{ From b6d0e04a321ce2543bb43204c7ee578888b86c4d Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 18 Aug 2023 12:06:17 -0400 Subject: [PATCH 18/31] Migrate the interpreter.PruneAst to the Go-native Expr (#799) Migrate pruner to Go-native Expr --- interpreter/prune.go | 497 ++++++++++++++++---------------------- interpreter/prune_test.go | 4 +- 2 files changed, 208 insertions(+), 293 deletions(-) diff --git a/interpreter/prune.go b/interpreter/prune.go index 504e5bf6d..410d80dc4 100644 --- a/interpreter/prune.go +++ b/interpreter/prune.go @@ -21,14 +21,12 @@ import ( "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" "github.com/google/cel-go/common/types/traits" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - structpb "google.golang.org/protobuf/types/known/structpb" ) type astPruner struct { - expr *exprpb.Expr - macroCalls map[int64]*exprpb.Expr + ast.ExprFactory + expr ast.Expr + macroCalls map[int64]ast.Expr state EvalState nextExprID int64 } @@ -74,86 +72,38 @@ func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state 
EvalState) *as v, _ := state.Value(id) pruneState.SetValue(id, v) } - pbExpr, _ := ast.ExprToProto(expr) - pbMacroCalls := make(map[int64]*exprpb.Expr, len(macroCalls)) - for id, call := range macroCalls { - pbMacroCalls[id], _ = ast.ExprToProto(call) - } pruner := &astPruner{ - expr: pbExpr, - macroCalls: pbMacroCalls, - state: pruneState, - nextExprID: getMaxID(pbExpr)} - newPBExpr, _ := pruner.maybePrune(pbExpr) - newExpr, _ := ast.ProtoToExpr(newPBExpr) + ExprFactory: ast.NewExprFactory(), + expr: expr, + macroCalls: macroCalls, + state: pruneState, + nextExprID: getMaxID(expr)} + newExpr, _ := pruner.maybePrune(expr) newInfo := ast.NewSourceInfo(nil) - for id, pbCall := range pruner.macroCalls { - call, _ := ast.ProtoToExpr(pbCall) + for id, call := range pruner.macroCalls { newInfo.SetMacroCall(id, call) } return ast.NewAST(newExpr, newInfo) } -func (p *astPruner) createLiteral(id int64, val *exprpb.Constant) *exprpb.Expr { - return &exprpb.Expr{ - Id: id, - ExprKind: &exprpb.Expr_ConstExpr{ - ConstExpr: val, - }, - } -} - -func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, bool) { +func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) { switch v := val.(type) { - case types.Bool: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BoolValue{BoolValue: bool(v)}}), true - case types.Bytes: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_BytesValue{BytesValue: []byte(v)}}), true - case types.Double: + case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint: p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_DoubleValue{DoubleValue: float64(v)}}), true + return p.NewLiteral(id, val), true case types.Duration: p.state.SetValue(id, val) - durationString := string(v.ConvertToType(types.StringType).(types.String)) - 
return &exprpb.Expr{ - Id: id, - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: &exprpb.Expr_Call{ - Function: overloads.TypeConvertDuration, - Args: []*exprpb.Expr{ - p.createLiteral(p.nextID(), - &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: durationString}}), - }, - }, - }, - }, true - case types.Int: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_Int64Value{Int64Value: int64(v)}}), true - case types.Uint: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_Uint64Value{Uint64Value: uint64(v)}}), true - case types.String: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_StringValue{StringValue: string(v)}}), true - case types.Null: - p.state.SetValue(id, val) - return p.createLiteral(id, - &exprpb.Constant{ConstantKind: &exprpb.Constant_NullValue{NullValue: v.Value().(structpb.NullValue)}}), true + durationString := v.ConvertToType(types.StringType).(types.String) + return p.NewCall(id, overloads.TypeConvertDuration, p.NewLiteral(p.nextID(), durationString)), true + case types.Timestamp: + timestampString := v.ConvertToType(types.StringType).(types.String) + return p.NewCall(id, overloads.TypeConvertTimestamp, p.NewLiteral(p.nextID(), timestampString)), true } // Attempt to build a list literal. 
if list, isList := val.(traits.Lister); isList { sz := list.Size().(types.Int) - elemExprs := make([]*exprpb.Expr, sz) + elemExprs := make([]ast.Expr, sz) for i := types.Int(0); i < sz; i++ { elem := list.Get(i) if types.IsUnknownOrError(elem) { @@ -166,20 +116,13 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo elemExprs[i] = elemExpr } p.state.SetValue(id, val) - return &exprpb.Expr{ - Id: id, - ExprKind: &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - Elements: elemExprs, - }, - }, - }, true + return p.NewList(id, elemExprs, []int32{}), true } // Create a map literal if possible. if mp, isMap := val.(traits.Mapper); isMap { it := mp.Iterator() - entries := make([]*exprpb.Expr_CreateStruct_Entry, mp.Size().(types.Int)) + entries := make([]ast.EntryExpr, mp.Size().(types.Int)) i := 0 for it.HasNext() != types.False { key := it.Next() @@ -195,25 +138,12 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo if !ok { return nil, false } - entry := &exprpb.Expr_CreateStruct_Entry{ - Id: p.nextID(), - KeyKind: &exprpb.Expr_CreateStruct_Entry_MapKey{ - MapKey: keyExpr, - }, - Value: valExpr, - } + entry := p.NewMapEntry(p.nextID(), keyExpr, valExpr, false) entries[i] = entry i++ } p.state.SetValue(id, val) - return &exprpb.Expr{ - Id: id, - ExprKind: &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{ - Entries: entries, - }, - }, - }, true + return p.NewMap(id, entries), true } // TODO(issues/377) To construct message literals, the type provider will need to support @@ -221,215 +151,206 @@ func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (*exprpb.Expr, boo return nil, false } -func (p *astPruner) maybePruneOptional(elem *exprpb.Expr) (*exprpb.Expr, bool) { - elemVal, found := p.value(elem.GetId()) +func (p *astPruner) maybePruneOptional(elem ast.Expr) (ast.Expr, bool) { + elemVal, found := p.value(elem.ID()) if found && elemVal.Type() == types.OptionalType { opt := 
elemVal.(*types.Optional) if !opt.HasValue() { return nil, true } - if newElem, pruned := p.maybeCreateLiteral(elem.GetId(), opt.GetValue()); pruned { + if newElem, pruned := p.maybeCreateLiteral(elem.ID(), opt.GetValue()); pruned { return newElem, true } } return elem, false } -func (p *astPruner) maybePruneIn(node *exprpb.Expr) (*exprpb.Expr, bool) { +func (p *astPruner) maybePruneIn(node ast.Expr) (ast.Expr, bool) { // elem in list - call := node.GetCallExpr() - val, exists := p.maybeValue(call.GetArgs()[1].GetId()) + call := node.AsCall() + val, exists := p.maybeValue(call.Args()[1].ID()) if !exists { return nil, false } if sz, ok := val.(traits.Sizer); ok && sz.Size() == types.IntZero { - return p.maybeCreateLiteral(node.GetId(), types.False) + return p.maybeCreateLiteral(node.ID(), types.False) } return nil, false } -func (p *astPruner) maybePruneLogicalNot(node *exprpb.Expr) (*exprpb.Expr, bool) { - call := node.GetCallExpr() - arg := call.GetArgs()[0] - val, exists := p.maybeValue(arg.GetId()) +func (p *astPruner) maybePruneLogicalNot(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + arg := call.Args()[0] + val, exists := p.maybeValue(arg.ID()) if !exists { return nil, false } if b, ok := val.(types.Bool); ok { - return p.maybeCreateLiteral(node.GetId(), !b) + return p.maybeCreateLiteral(node.ID(), !b) } return nil, false } -func (p *astPruner) maybePruneOr(node *exprpb.Expr) (*exprpb.Expr, bool) { - call := node.GetCallExpr() +func (p *astPruner) maybePruneOr(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. 
- if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { + if v, exists := p.maybeValue(call.Args()[0].ID()); exists { if v == types.True { - return p.maybeCreateLiteral(node.GetId(), types.True) + return p.maybeCreateLiteral(node.ID(), types.True) } - return call.GetArgs()[1], true + return call.Args()[1], true } - if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { + if v, exists := p.maybeValue(call.Args()[1].ID()); exists { if v == types.True { - return p.maybeCreateLiteral(node.GetId(), types.True) + return p.maybeCreateLiteral(node.ID(), types.True) } - return call.GetArgs()[0], true + return call.Args()[0], true } return nil, false } -func (p *astPruner) maybePruneAnd(node *exprpb.Expr) (*exprpb.Expr, bool) { - call := node.GetCallExpr() +func (p *astPruner) maybePruneAnd(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() // We know result is unknown, so we have at least one unknown arg // and if one side is a known value, we know we can ignore it. 
- if v, exists := p.maybeValue(call.GetArgs()[0].GetId()); exists { + if v, exists := p.maybeValue(call.Args()[0].ID()); exists { if v == types.False { - return p.maybeCreateLiteral(node.GetId(), types.False) + return p.maybeCreateLiteral(node.ID(), types.False) } - return call.GetArgs()[1], true + return call.Args()[1], true } - if v, exists := p.maybeValue(call.GetArgs()[1].GetId()); exists { + if v, exists := p.maybeValue(call.Args()[1].ID()); exists { if v == types.False { - return p.maybeCreateLiteral(node.GetId(), types.False) + return p.maybeCreateLiteral(node.ID(), types.False) } - return call.GetArgs()[0], true + return call.Args()[0], true } return nil, false } -func (p *astPruner) maybePruneConditional(node *exprpb.Expr) (*exprpb.Expr, bool) { - call := node.GetCallExpr() - cond, exists := p.maybeValue(call.GetArgs()[0].GetId()) +func (p *astPruner) maybePruneConditional(node ast.Expr) (ast.Expr, bool) { + call := node.AsCall() + cond, exists := p.maybeValue(call.Args()[0].ID()) if !exists { return nil, false } if cond.Value().(bool) { - return call.GetArgs()[1], true + return call.Args()[1], true } - return call.GetArgs()[2], true + return call.Args()[2], true } -func (p *astPruner) maybePruneFunction(node *exprpb.Expr) (*exprpb.Expr, bool) { - if _, exists := p.value(node.GetId()); !exists { +func (p *astPruner) maybePruneFunction(node ast.Expr) (ast.Expr, bool) { + if _, exists := p.value(node.ID()); !exists { return nil, false } - call := node.GetCallExpr() - if call.Function == operators.LogicalOr { + call := node.AsCall() + if call.FunctionName() == operators.LogicalOr { return p.maybePruneOr(node) } - if call.Function == operators.LogicalAnd { + if call.FunctionName() == operators.LogicalAnd { return p.maybePruneAnd(node) } - if call.Function == operators.Conditional { + if call.FunctionName() == operators.Conditional { return p.maybePruneConditional(node) } - if call.Function == operators.In { + if call.FunctionName() == operators.In { return 
p.maybePruneIn(node) } - if call.Function == operators.LogicalNot { + if call.FunctionName() == operators.LogicalNot { return p.maybePruneLogicalNot(node) } return nil, false } -func (p *astPruner) maybePrune(node *exprpb.Expr) (*exprpb.Expr, bool) { +func (p *astPruner) maybePrune(node ast.Expr) (ast.Expr, bool) { return p.prune(node) } -func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { +func (p *astPruner) prune(node ast.Expr) (ast.Expr, bool) { if node == nil { return node, false } - val, valueExists := p.maybeValue(node.GetId()) + val, valueExists := p.maybeValue(node.ID()) if valueExists { - if newNode, ok := p.maybeCreateLiteral(node.GetId(), val); ok { - delete(p.macroCalls, node.GetId()) + if newNode, ok := p.maybeCreateLiteral(node.ID(), val); ok { + delete(p.macroCalls, node.ID()) return newNode, true } } - if macro, found := p.macroCalls[node.GetId()]; found { + if macro, found := p.macroCalls[node.ID()]; found { // Ensure that intermediate values for the comprehension are cleared during pruning - compre := node.GetComprehensionExpr() - if compre != nil { - visit(macro, clearIterVarVisitor(compre.IterVar, p.state)) + if node.Kind() == ast.ComprehensionKind { + compre := node.AsComprehension() + visit(macro, clearIterVarVisitor(compre.IterVar(), p.state)) } // prune the expression in terms of the macro call instead of the expanded form. if newMacro, pruned := p.prune(macro); pruned { - p.macroCalls[node.GetId()] = newMacro + p.macroCalls[node.ID()] = newMacro } } // We have either an unknown/error value, or something we don't want to // transform, or expression was not evaluated. If possible, drill down // more. 
- switch node.GetExprKind().(type) { - case *exprpb.Expr_SelectExpr: - if operand, pruned := p.maybePrune(node.GetSelectExpr().GetOperand()); pruned { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_SelectExpr{ - SelectExpr: &exprpb.Expr_Select{ - Operand: operand, - Field: node.GetSelectExpr().GetField(), - TestOnly: node.GetSelectExpr().GetTestOnly(), - }, - }, - }, true - } - case *exprpb.Expr_CallExpr: - var prunedCall bool - call := node.GetCallExpr() - args := call.GetArgs() - newArgs := make([]*exprpb.Expr, len(args)) - newCall := &exprpb.Expr_Call{ - Function: call.GetFunction(), - Target: call.GetTarget(), - Args: newArgs, - } - for i, arg := range args { - newArgs[i] = arg - if newArg, prunedArg := p.maybePrune(arg); prunedArg { - prunedCall = true - newArgs[i] = newArg + switch node.Kind() { + case ast.SelectKind: + sel := node.AsSelect() + if operand, isPruned := p.maybePrune(sel.Operand()); isPruned { + if sel.IsTestOnly() { + return p.NewPresenceTest(node.ID(), operand, sel.FieldName()), true } + return p.NewSelect(node.ID(), operand, sel.FieldName()), true } - if newTarget, prunedTarget := p.maybePrune(call.GetTarget()); prunedTarget { - prunedCall = true - newCall.Target = newTarget + case ast.CallKind: + argsPruned := false + call := node.AsCall() + args := call.Args() + newArgs := make([]ast.Expr, len(args)) + for i, a := range args { + newArgs[i] = a + if arg, isPruned := p.maybePrune(a); isPruned { + argsPruned = true + newArgs[i] = arg + } } - newNode := &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_CallExpr{ - CallExpr: newCall, - }, + if !call.IsMemberFunction() { + newCall := p.NewCall(node.ID(), call.FunctionName(), newArgs...) 
+ if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { + return prunedCall, true + } + return newCall, argsPruned } - if newExpr, pruned := p.maybePruneFunction(newNode); pruned { - newExpr, _ = p.maybePrune(newExpr) - return newExpr, true + newTarget := call.Target() + targetPruned := false + if prunedTarget, isPruned := p.maybePrune(call.Target()); isPruned { + targetPruned = true + newTarget = prunedTarget } - if prunedCall { - return newNode, true + newCall := p.NewMemberCall(node.ID(), call.FunctionName(), newTarget, newArgs...) + if prunedCall, isPruned := p.maybePruneFunction(newCall); isPruned { + return prunedCall, true } - case *exprpb.Expr_ListExpr: - elems := node.GetListExpr().GetElements() - optIndices := node.GetListExpr().GetOptionalIndices() + return newCall, targetPruned || argsPruned + case ast.ListKind: + l := node.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() optIndexMap := map[int32]bool{} for _, i := range optIndices { optIndexMap[i] = true } newOptIndexMap := make(map[int32]bool, len(optIndexMap)) - newElems := make([]*exprpb.Expr, 0, len(elems)) - var prunedList bool - + newElems := make([]ast.Expr, 0, len(elems)) + var listPruned bool prunedIdx := 0 for i, elem := range elems { _, isOpt := optIndexMap[int32(i)] if isOpt { newElem, pruned := p.maybePruneOptional(elem) if pruned { - prunedList = true + listPruned = true if newElem != nil { newElems = append(newElems, newElem) prunedIdx++ @@ -440,7 +361,7 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { } if newElem, prunedElem := p.maybePrune(elem); prunedElem { newElems = append(newElems, newElem) - prunedList = true + listPruned = true } else { newElems = append(newElems, elem) } @@ -452,76 +373,64 @@ func (p *astPruner) prune(node *exprpb.Expr) (*exprpb.Expr, bool) { optIndices[idx] = i idx++ } - if prunedList { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_ListExpr{ - ListExpr: &exprpb.Expr_CreateList{ - 
Elements: newElems, - OptionalIndices: optIndices, - }, - }, - }, true + if listPruned { + return p.NewList(node.ID(), newElems, optIndices), true } - case *exprpb.Expr_StructExpr: - var prunedStruct bool - entries := node.GetStructExpr().GetEntries() - messageType := node.GetStructExpr().GetMessageName() - newEntries := make([]*exprpb.Expr_CreateStruct_Entry, len(entries)) + case ast.MapKind: + var mapPruned bool + m := node.AsMap() + entries := m.Entries() + newEntries := make([]ast.EntryExpr, len(entries)) for i, entry := range entries { newEntries[i] = entry - newKey, prunedKey := p.maybePrune(entry.GetMapKey()) - newValue, prunedValue := p.maybePrune(entry.GetValue()) - if !prunedKey && !prunedValue { + e := entry.AsMapEntry() + newKey, keyPruned := p.maybePrune(e.Key()) + newValue, valuePruned := p.maybePrune(e.Value()) + if !keyPruned && !valuePruned { continue } - prunedStruct = true - newEntry := &exprpb.Expr_CreateStruct_Entry{ - Value: newValue, - } - if messageType != "" { - newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_FieldKey{ - FieldKey: entry.GetFieldKey(), - } - } else { - newEntry.KeyKind = &exprpb.Expr_CreateStruct_Entry_MapKey{ - MapKey: newKey, - } - } - newEntry.OptionalEntry = entry.GetOptionalEntry() + mapPruned = true + newEntry := p.NewMapEntry(entry.ID(), newKey, newValue, e.IsOptional()) newEntries[i] = newEntry } - if prunedStruct { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_StructExpr{ - StructExpr: &exprpb.Expr_CreateStruct{ - MessageName: messageType, - Entries: newEntries, - }, - }, - }, true + if mapPruned { + return p.NewMap(node.ID(), newEntries), true } - case *exprpb.Expr_ComprehensionExpr: - compre := node.GetComprehensionExpr() + case ast.StructKind: + var structPruned bool + obj := node.AsStruct() + fields := obj.Fields() + newFields := make([]ast.EntryExpr, len(fields)) + for i, field := range fields { + newFields[i] = field + f := field.AsStructField() + newValue, prunedValue := 
p.maybePrune(f.Value()) + if !prunedValue { + continue + } + structPruned = true + newEntry := p.NewStructField(field.ID(), f.Name(), newValue, f.IsOptional()) + newFields[i] = newEntry + } + if structPruned { + return p.NewStruct(node.ID(), obj.TypeName(), newFields), true + } + case ast.ComprehensionKind: + compre := node.AsComprehension() // Only the range of the comprehension is pruned since the state tracking only records // the last iteration of the comprehension and not each step in the evaluation which // means that the any residuals computed in between might be inaccurate. - if newRange, pruned := p.maybePrune(compre.GetIterRange()); pruned { - return &exprpb.Expr{ - Id: node.GetId(), - ExprKind: &exprpb.Expr_ComprehensionExpr{ - ComprehensionExpr: &exprpb.Expr_Comprehension{ - IterVar: compre.GetIterVar(), - IterRange: newRange, - AccuVar: compre.GetAccuVar(), - AccuInit: compre.GetAccuInit(), - LoopCondition: compre.GetLoopCondition(), - LoopStep: compre.GetLoopStep(), - Result: compre.GetResult(), - }, - }, - }, true + if newRange, pruned := p.maybePrune(compre.IterRange()); pruned { + return p.NewComprehension( + node.ID(), + newRange, + compre.IterVar(), + compre.AccuVar(), + compre.AccuInit(), + compre.LoopCondition(), + compre.LoopStep(), + compre.Result(), + ), true } } return node, false @@ -548,12 +457,12 @@ func (p *astPruner) nextID() int64 { type astVisitor struct { // visitEntry is called on every expr node, including those within a map/struct entry. - visitExpr func(expr *exprpb.Expr) + visitExpr func(expr ast.Expr) // visitEntry is called before entering the key, value of a map/struct entry. 
- visitEntry func(entry *exprpb.Expr_CreateStruct_Entry) + visitEntry func(entry ast.EntryExpr) } -func getMaxID(expr *exprpb.Expr) int64 { +func getMaxID(expr ast.Expr) int64 { maxID := int64(1) visit(expr, maxIDVisitor(&maxID)) return maxID @@ -561,10 +470,9 @@ func getMaxID(expr *exprpb.Expr) int64 { func clearIterVarVisitor(varName string, state EvalState) astVisitor { return astVisitor{ - visitExpr: func(e *exprpb.Expr) { - ident := e.GetIdentExpr() - if ident != nil && ident.GetName() == varName { - state.SetValue(e.GetId(), nil) + visitExpr: func(e ast.Expr) { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + state.SetValue(e.ID(), nil) } }, } @@ -572,56 +480,63 @@ func clearIterVarVisitor(varName string, state EvalState) astVisitor { func maxIDVisitor(maxID *int64) astVisitor { return astVisitor{ - visitExpr: func(e *exprpb.Expr) { - if e.GetId() >= *maxID { - *maxID = e.GetId() + 1 + visitExpr: func(e ast.Expr) { + if e.ID() >= *maxID { + *maxID = e.ID() + 1 } }, - visitEntry: func(e *exprpb.Expr_CreateStruct_Entry) { - if e.GetId() >= *maxID { - *maxID = e.GetId() + 1 + visitEntry: func(e ast.EntryExpr) { + if e.ID() >= *maxID { + *maxID = e.ID() + 1 } }, } } -func visit(expr *exprpb.Expr, visitor astVisitor) { - exprs := []*exprpb.Expr{expr} +func visit(expr ast.Expr, visitor astVisitor) { + exprs := []ast.Expr{expr} for len(exprs) != 0 { e := exprs[0] if visitor.visitExpr != nil { visitor.visitExpr(e) } exprs = exprs[1:] - switch e.GetExprKind().(type) { - case *exprpb.Expr_SelectExpr: - exprs = append(exprs, e.GetSelectExpr().GetOperand()) - case *exprpb.Expr_CallExpr: - call := e.GetCallExpr() - if call.GetTarget() != nil { - exprs = append(exprs, call.GetTarget()) + switch e.Kind() { + case ast.SelectKind: + exprs = append(exprs, e.AsSelect().Operand()) + case ast.CallKind: + call := e.AsCall() + if call.Target() != nil { + exprs = append(exprs, call.Target()) } - exprs = append(exprs, call.GetArgs()...) 
- case *exprpb.Expr_ComprehensionExpr: - compre := e.GetComprehensionExpr() + exprs = append(exprs, call.Args()...) + case ast.ComprehensionKind: + compre := e.AsComprehension() exprs = append(exprs, - compre.GetIterRange(), - compre.GetAccuInit(), - compre.GetLoopCondition(), - compre.GetLoopStep(), - compre.GetResult()) - case *exprpb.Expr_ListExpr: - list := e.GetListExpr() - exprs = append(exprs, list.GetElements()...) - case *exprpb.Expr_StructExpr: - for _, entry := range e.GetStructExpr().GetEntries() { + compre.IterRange(), + compre.AccuInit(), + compre.LoopCondition(), + compre.LoopStep(), + compre.Result()) + case ast.ListKind: + list := e.AsList() + exprs = append(exprs, list.Elements()...) + case ast.MapKind: + for _, entry := range e.AsMap().Entries() { + e := entry.AsMapEntry() if visitor.visitEntry != nil { visitor.visitEntry(entry) } - if entry.GetMapKey() != nil { - exprs = append(exprs, entry.GetMapKey()) + exprs = append(exprs, e.Key()) + exprs = append(exprs, e.Value()) + } + case ast.StructKind: + for _, entry := range e.AsStruct().Fields() { + f := entry.AsStructField() + if visitor.visitEntry != nil { + visitor.visitEntry(entry) } - exprs = append(exprs, entry.GetValue()) + exprs = append(exprs, f.Value()) } } } diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index c32ca0f39..d7bfb750d 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -283,11 +283,11 @@ var testCases = []testInfo{ }, { expr: `[timestamp(0), timestamp(1)]`, - out: `[timestamp(0), timestamp(1)]`, + out: `[timestamp("1970-01-01T00:00:00Z"), timestamp("1970-01-01T00:00:01Z")]`, }, { expr: `{"epoch": timestamp(0)}`, - out: `{"epoch": timestamp(0)}`, + out: `{"epoch": timestamp("1970-01-01T00:00:00Z")}`, }, { in: partialActivation(map[string]any{"x": false}, "y"), From 26aa36760a6ef8b8fbb3aed0792671d7effda264 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 18 Aug 2023 12:50:41 -0400 Subject: [PATCH 19/31] Migrate the 
parser.Unparse to the Go-native Expr (#800) Migrate unparser internals to native Expr type --- parser/unparser.go | 232 +++++++++++++++++++++------------------------ 1 file changed, 108 insertions(+), 124 deletions(-) diff --git a/parser/unparser.go b/parser/unparser.go index 4c4d125b7..91cf72944 100644 --- a/parser/unparser.go +++ b/parser/unparser.go @@ -22,8 +22,7 @@ import ( "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/operators" - - exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/google/cel-go/common/types" ) // Unparse takes an input expression and source position information and generates a human-readable @@ -41,20 +40,13 @@ import ( // This function optionally takes in one or more UnparserOption to alter the unparsing behavior, such as // performing word wrapping on expressions. func Unparse(expr ast.Expr, info *ast.SourceInfo, opts ...UnparserOption) (string, error) { - pbExpr, err := ast.ExprToProto(expr) - if err != nil { - return "", err - } - pbInfo, err := ast.SourceInfoToProto(info) - if err != nil { - return "", err - } unparserOpts := &unparserOption{ wrapOnColumn: defaultWrapOnColumn, wrapAfterColumnLimit: defaultWrapAfterColumnLimit, operatorsToWrapOn: defaultOperatorsToWrapOn, } + var err error for _, opt := range opts { unparserOpts, err = opt(unparserOpts) if err != nil { @@ -63,10 +55,10 @@ func Unparse(expr ast.Expr, info *ast.SourceInfo, opts ...UnparserOption) (strin } un := &unparser{ - info: pbInfo, + info: info, options: unparserOpts, } - err = un.visit(pbExpr) + err = un.visit(expr) if err != nil { return "", err } @@ -76,12 +68,12 @@ func Unparse(expr ast.Expr, info *ast.SourceInfo, opts ...UnparserOption) (strin // unparser visits an expression to reconstruct a human-readable string from an AST. 
type unparser struct { str strings.Builder - info *exprpb.SourceInfo + info *ast.SourceInfo options *unparserOption lastWrappedIndex int } -func (un *unparser) visit(expr *exprpb.Expr) error { +func (un *unparser) visit(expr ast.Expr) error { if expr == nil { return errors.New("unsupported expression") } @@ -89,27 +81,29 @@ func (un *unparser) visit(expr *exprpb.Expr) error { if visited || err != nil { return err } - switch expr.GetExprKind().(type) { - case *exprpb.Expr_CallExpr: + switch expr.Kind() { + case ast.CallKind: return un.visitCall(expr) - case *exprpb.Expr_ConstExpr: + case ast.LiteralKind: return un.visitConst(expr) - case *exprpb.Expr_IdentExpr: + case ast.IdentKind: return un.visitIdent(expr) - case *exprpb.Expr_ListExpr: + case ast.ListKind: return un.visitList(expr) - case *exprpb.Expr_SelectExpr: + case ast.MapKind: + return un.visitStructMap(expr) + case ast.SelectKind: return un.visitSelect(expr) - case *exprpb.Expr_StructExpr: - return un.visitStruct(expr) + case ast.StructKind: + return un.visitStructMsg(expr) default: return fmt.Errorf("unsupported expression: %v", expr) } } -func (un *unparser) visitCall(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - fun := c.GetFunction() +func (un *unparser) visitCall(expr ast.Expr) error { + c := expr.AsCall() + fun := c.FunctionName() switch fun { // ternary operator case operators.Conditional: @@ -149,10 +143,10 @@ func (un *unparser) visitCall(expr *exprpb.Expr) error { } } -func (un *unparser) visitCallBinary(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - fun := c.GetFunction() - args := c.GetArgs() +func (un *unparser) visitCallBinary(expr ast.Expr) error { + c := expr.AsCall() + fun := c.FunctionName() + args := c.Args() lhs := args[0] // add parens if the current operator is lower precedence than the lhs expr operator. 
lhsParen := isComplexOperatorWithRespectTo(fun, lhs) @@ -176,9 +170,9 @@ func (un *unparser) visitCallBinary(expr *exprpb.Expr) error { return un.visitMaybeNested(rhs, rhsParen) } -func (un *unparser) visitCallConditional(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - args := c.GetArgs() +func (un *unparser) visitCallConditional(expr ast.Expr) error { + c := expr.AsCall() + args := c.Args() // add parens if operand is a conditional itself. nested := isSamePrecedence(operators.Conditional, args[0]) || isComplexOperator(args[0]) @@ -204,13 +198,13 @@ func (un *unparser) visitCallConditional(expr *exprpb.Expr) error { return un.visitMaybeNested(args[2], nested) } -func (un *unparser) visitCallFunc(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - fun := c.GetFunction() - args := c.GetArgs() - if c.GetTarget() != nil { - nested := isBinaryOrTernaryOperator(c.GetTarget()) - err := un.visitMaybeNested(c.GetTarget(), nested) +func (un *unparser) visitCallFunc(expr ast.Expr) error { + c := expr.AsCall() + fun := c.FunctionName() + args := c.Args() + if c.IsMemberFunction() { + nested := isBinaryOrTernaryOperator(c.Target()) + err := un.visitMaybeNested(c.Target(), nested) if err != nil { return err } @@ -231,17 +225,17 @@ func (un *unparser) visitCallFunc(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitCallIndex(expr *exprpb.Expr) error { +func (un *unparser) visitCallIndex(expr ast.Expr) error { return un.visitCallIndexInternal(expr, "[") } -func (un *unparser) visitCallOptIndex(expr *exprpb.Expr) error { +func (un *unparser) visitCallOptIndex(expr ast.Expr) error { return un.visitCallIndexInternal(expr, "[?") } -func (un *unparser) visitCallIndexInternal(expr *exprpb.Expr, op string) error { - c := expr.GetCallExpr() - args := c.GetArgs() +func (un *unparser) visitCallIndexInternal(expr ast.Expr, op string) error { + c := expr.AsCall() + args := c.Args() nested := isBinaryOrTernaryOperator(args[0]) err := un.visitMaybeNested(args[0], 
nested) if err != nil { @@ -256,10 +250,10 @@ func (un *unparser) visitCallIndexInternal(expr *exprpb.Expr, op string) error { return nil } -func (un *unparser) visitCallUnary(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - fun := c.GetFunction() - args := c.GetArgs() +func (un *unparser) visitCallUnary(expr ast.Expr) error { + c := expr.AsCall() + fun := c.FunctionName() + args := c.Args() unmangled, found := operators.FindReverse(fun) if !found { return fmt.Errorf("cannot unmangle operator: %s", fun) @@ -269,35 +263,34 @@ func (un *unparser) visitCallUnary(expr *exprpb.Expr) error { return un.visitMaybeNested(args[0], nested) } -func (un *unparser) visitConst(expr *exprpb.Expr) error { - c := expr.GetConstExpr() - switch c.GetConstantKind().(type) { - case *exprpb.Constant_BoolValue: - un.str.WriteString(strconv.FormatBool(c.GetBoolValue())) - case *exprpb.Constant_BytesValue: +func (un *unparser) visitConst(expr ast.Expr) error { + val := expr.AsLiteral() + switch val := val.(type) { + case types.Bool: + un.str.WriteString(strconv.FormatBool(bool(val))) + case types.Bytes: // bytes constants are surrounded with b"" - b := c.GetBytesValue() un.str.WriteString(`b"`) - un.str.WriteString(bytesToOctets(b)) + un.str.WriteString(bytesToOctets([]byte(val))) un.str.WriteString(`"`) - case *exprpb.Constant_DoubleValue: + case types.Double: // represent the float using the minimum required digits - d := strconv.FormatFloat(c.GetDoubleValue(), 'g', -1, 64) + d := strconv.FormatFloat(float64(val), 'g', -1, 64) un.str.WriteString(d) if !strings.Contains(d, ".") { un.str.WriteString(".0") } - case *exprpb.Constant_Int64Value: - i := strconv.FormatInt(c.GetInt64Value(), 10) + case types.Int: + i := strconv.FormatInt(int64(val), 10) un.str.WriteString(i) - case *exprpb.Constant_NullValue: + case types.Null: un.str.WriteString("null") - case *exprpb.Constant_StringValue: + case types.String: // strings will be double quoted with quotes escaped. 
- un.str.WriteString(strconv.Quote(c.GetStringValue())) - case *exprpb.Constant_Uint64Value: + un.str.WriteString(strconv.Quote(string(val))) + case types.Uint: // uint literals have a 'u' suffix. - ui := strconv.FormatUint(c.GetUint64Value(), 10) + ui := strconv.FormatUint(uint64(val), 10) un.str.WriteString(ui) un.str.WriteString("u") default: @@ -306,16 +299,16 @@ func (un *unparser) visitConst(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitIdent(expr *exprpb.Expr) error { - un.str.WriteString(expr.GetIdentExpr().GetName()) +func (un *unparser) visitIdent(expr ast.Expr) error { + un.str.WriteString(expr.AsIdent()) return nil } -func (un *unparser) visitList(expr *exprpb.Expr) error { - l := expr.GetListExpr() - elems := l.GetElements() +func (un *unparser) visitList(expr ast.Expr) error { + l := expr.AsList() + elems := l.Elements() optIndices := make(map[int]bool, len(elems)) - for _, idx := range l.GetOptionalIndices() { + for _, idx := range l.OptionalIndices() { optIndices[int(idx)] = true } un.str.WriteString("[") @@ -335,20 +328,20 @@ func (un *unparser) visitList(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitOptSelect(expr *exprpb.Expr) error { - c := expr.GetCallExpr() - args := c.GetArgs() +func (un *unparser) visitOptSelect(expr ast.Expr) error { + c := expr.AsCall() + args := c.Args() operand := args[0] - field := args[1].GetConstExpr().GetStringValue() - return un.visitSelectInternal(operand, false, ".?", field) + field := args[1].AsLiteral().(types.String) + return un.visitSelectInternal(operand, false, ".?", string(field)) } -func (un *unparser) visitSelect(expr *exprpb.Expr) error { - sel := expr.GetSelectExpr() - return un.visitSelectInternal(sel.GetOperand(), sel.GetTestOnly(), ".", sel.GetField()) +func (un *unparser) visitSelect(expr ast.Expr) error { + sel := expr.AsSelect() + return un.visitSelectInternal(sel.Operand(), sel.IsTestOnly(), ".", sel.FieldName()) } -func (un *unparser) 
visitSelectInternal(operand *exprpb.Expr, testOnly bool, op string, field string) error { +func (un *unparser) visitSelectInternal(operand ast.Expr, testOnly bool, op string, field string) error { // handle the case when the select expression was generated by the has() macro. if testOnly { un.str.WriteString("has(") @@ -366,34 +359,25 @@ func (un *unparser) visitSelectInternal(operand *exprpb.Expr, testOnly bool, op return nil } -func (un *unparser) visitStruct(expr *exprpb.Expr) error { - s := expr.GetStructExpr() - // If the message name is non-empty, then this should be treated as message construction. - if s.GetMessageName() != "" { - return un.visitStructMsg(expr) - } - // Otherwise, build a map. - return un.visitStructMap(expr) -} - -func (un *unparser) visitStructMsg(expr *exprpb.Expr) error { - m := expr.GetStructExpr() - entries := m.GetEntries() - un.str.WriteString(m.GetMessageName()) +func (un *unparser) visitStructMsg(expr ast.Expr) error { + m := expr.AsStruct() + fields := m.Fields() + un.str.WriteString(m.TypeName()) un.str.WriteString("{") - for i, entry := range entries { - f := entry.GetFieldKey() - if entry.GetOptionalEntry() { + for i, f := range fields { + field := f.AsStructField() + f := field.Name() + if field.IsOptional() { un.str.WriteString("?") } un.str.WriteString(f) un.str.WriteString(": ") - v := entry.GetValue() + v := field.Value() err := un.visit(v) if err != nil { return err } - if i < len(entries)-1 { + if i < len(fields)-1 { un.str.WriteString(", ") } } @@ -401,13 +385,14 @@ func (un *unparser) visitStructMsg(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitStructMap(expr *exprpb.Expr) error { - m := expr.GetStructExpr() - entries := m.GetEntries() +func (un *unparser) visitStructMap(expr ast.Expr) error { + m := expr.AsMap() + entries := m.Entries() un.str.WriteString("{") - for i, entry := range entries { - k := entry.GetMapKey() - if entry.GetOptionalEntry() { + for i, e := range entries { + entry := 
e.AsMapEntry() + k := entry.Key() + if entry.IsOptional() { un.str.WriteString("?") } err := un.visit(k) @@ -415,7 +400,7 @@ func (un *unparser) visitStructMap(expr *exprpb.Expr) error { return err } un.str.WriteString(": ") - v := entry.GetValue() + v := entry.Value() err = un.visit(v) if err != nil { return err @@ -428,16 +413,15 @@ func (un *unparser) visitStructMap(expr *exprpb.Expr) error { return nil } -func (un *unparser) visitMaybeMacroCall(expr *exprpb.Expr) (bool, error) { - macroCalls := un.info.GetMacroCalls() - call, found := macroCalls[expr.GetId()] +func (un *unparser) visitMaybeMacroCall(expr ast.Expr) (bool, error) { + call, found := un.info.GetMacroCall(expr.ID()) if !found { return false, nil } return true, un.visit(call) } -func (un *unparser) visitMaybeNested(expr *exprpb.Expr, nested bool) error { +func (un *unparser) visitMaybeNested(expr ast.Expr, nested bool) error { if nested { un.str.WriteString("(") } @@ -461,12 +445,12 @@ func isLeftRecursive(op string) bool { // precedence of the (possible) operation represented in the input Expr. // // If the expr is not a Call, the result is false. -func isSamePrecedence(op string, expr *exprpb.Expr) bool { - if expr.GetCallExpr() == nil { +func isSamePrecedence(op string, expr ast.Expr) bool { + if expr.Kind() != ast.CallKind { return false } - c := expr.GetCallExpr() - other := c.GetFunction() + c := expr.AsCall() + other := c.FunctionName() return operators.Precedence(op) == operators.Precedence(other) } @@ -474,16 +458,16 @@ func isSamePrecedence(op string, expr *exprpb.Expr) bool { // than the (possible) operation represented in the input Expr. // // If the expr is not a Call, the result is false. 
-func isLowerPrecedence(op string, expr *exprpb.Expr) bool { - c := expr.GetCallExpr() - other := c.GetFunction() +func isLowerPrecedence(op string, expr ast.Expr) bool { + c := expr.AsCall() + other := c.FunctionName() return operators.Precedence(op) < operators.Precedence(other) } // Indicates whether the expr is a complex operator, i.e., a call expression // with 2 or more arguments. -func isComplexOperator(expr *exprpb.Expr) bool { - if expr.GetCallExpr() != nil && len(expr.GetCallExpr().GetArgs()) >= 2 { +func isComplexOperator(expr ast.Expr) bool { + if expr.Kind() == ast.CallKind && len(expr.AsCall().Args()) >= 2 { return true } return false @@ -492,19 +476,19 @@ func isComplexOperator(expr *exprpb.Expr) bool { // Indicates whether it is a complex operation compared to another. // expr is *not* considered complex if it is not a call expression or has // less than two arguments, or if it has a higher precedence than op. -func isComplexOperatorWithRespectTo(op string, expr *exprpb.Expr) bool { - if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 { +func isComplexOperatorWithRespectTo(op string, expr ast.Expr) bool { + if expr.Kind() != ast.CallKind || len(expr.AsCall().Args()) < 2 { return false } return isLowerPrecedence(op, expr) } // Indicate whether this is a binary or ternary operator. 
-func isBinaryOrTernaryOperator(expr *exprpb.Expr) bool { - if expr.GetCallExpr() == nil || len(expr.GetCallExpr().GetArgs()) < 2 { +func isBinaryOrTernaryOperator(expr ast.Expr) bool { + if expr.Kind() != ast.CallKind || len(expr.AsCall().Args()) < 2 { return false } - _, isBinaryOp := operators.FindReverseBinaryOperator(expr.GetCallExpr().GetFunction()) + _, isBinaryOp := operators.FindReverseBinaryOperator(expr.AsCall().FunctionName()) return isBinaryOp || isSamePrecedence(operators.Conditional, expr) } From eaebecb272d8b00d75f0851e1e7ebe23b6b595cd Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 18 Aug 2023 18:29:16 -0400 Subject: [PATCH 20/31] FindStructTypeFields support for types.Provider (#814) --- cel/env.go | 7 +++++++ cel/env_test.go | 4 ++++ common/types/provider.go | 21 +++++++++++++++++++++ common/types/provider_test.go | 34 ++++++++++++++++++++++++++++++++++ ext/native.go | 18 ++++++++++++++++++ ext/native_test.go | 35 +++++++++++++++++++++++++++++++++++ 6 files changed, 119 insertions(+) diff --git a/cel/env.go b/cel/env.go index 96b61ad70..473604bc6 100644 --- a/cel/env.go +++ b/cel/env.go @@ -773,6 +773,13 @@ func (p *interopCELTypeProvider) FindStructType(typeName string) (*types.Type, b return nil, false } +// FindStructFieldNames returns an empty set of field for the interop provider. +// +// To inspect the field names, migrate to a `types.Provider` implementation. +func (p *interopCELTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) { + return []string{}, false +} + // FindStructFieldType returns a types.FieldType instance for the given fully-qualified typeName and field // name, if one exists. 
// diff --git a/cel/env_test.go b/cel/env_test.go index 663727961..ed1f836e8 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -385,6 +385,10 @@ func (p *customCELProvider) FindStructType(typeName string) (*types.Type, bool) return p.provider.FindStructType(typeName) } +func (p *customCELProvider) FindStructFieldNames(typeName string) ([]string, bool) { + return p.provider.FindStructFieldNames(typeName) +} + func (p *customCELProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { return p.provider.FindStructFieldType(structType, fieldName) } diff --git a/common/types/provider.go b/common/types/provider.go index e9bda552c..d301aa38a 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -54,6 +54,10 @@ type Provider interface { // Returns false if not found. FindStructType(structType string) (*Type, bool) + // FindStructFieldNames returns thet field names associated with the type, if the type + // is found. + FindStructFieldNames(structType string) ([]string, bool) + // FieldStructFieldType returns the field type for a checked type value. Returns // false if the field could not be found. FindStructFieldType(structType, fieldName string) (*FieldType, bool) @@ -173,6 +177,23 @@ func (p *Registry) FindFieldType(structType, fieldName string) (*ref.FieldType, GetFrom: field.GetFrom}, true } +// FindStructFieldNames returns the set of field names for the given struct type, +// if the type exists in the registry. +func (p *Registry) FindStructFieldNames(structType string) ([]string, bool) { + msgType, found := p.pbdb.DescribeType(structType) + if !found { + return []string{}, false + } + fieldMap := msgType.FieldMap() + fields := make([]string, len(fieldMap)) + idx := 0 + for f := range fieldMap { + fields[idx] = f + idx++ + } + return fields, true +} + // FindStructFieldType returns the field type for a checked type value. Returns // false if the field could not be found. 
func (p *Registry) FindStructFieldType(structType, fieldName string) (*FieldType, bool) { diff --git a/common/types/provider_test.go b/common/types/provider_test.go index 3e93afc54..56f15290c 100644 --- a/common/types/provider_test.go +++ b/common/types/provider_test.go @@ -18,6 +18,7 @@ import ( "bytes" "fmt" "reflect" + "sort" "strings" "testing" "time" @@ -132,6 +133,39 @@ func TestRegistryFindStructType(t *testing.T) { } } +func TestRegistryFindStructFieldNames(t *testing.T) { + reg := newTestRegistry(t, &exprpb.Decl{}, &exprpb.Reference{}) + tests := []struct { + typeName string + fields []string + }{ + { + typeName: "google.api.expr.v1alpha1.Reference", + fields: []string{"name", "overload_id", "value"}, + }, + { + typeName: "google.api.expr.v1alpha1.Decl", + fields: []string{"name", "ident", "function"}, + }, + { + typeName: "invalid.TypeName", + fields: []string{}, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) { + fields, _ := reg.FindStructFieldNames(tc.typeName) + sort.Strings(fields) + sort.Strings(tc.fields) + if !reflect.DeepEqual(fields, tc.fields) { + t.Errorf("got %v, wanted %v", fields, tc.fields) + } + }) + } +} + func TestRegistryFindStructFieldType(t *testing.T) { reg := newTestRegistry(t) err := reg.RegisterDescriptor(proto3pb.GlobalEnum_GOO.Descriptor().ParentFile()) diff --git a/ext/native.go b/ext/native.go index 0b5fc38ca..0c2cd52f9 100644 --- a/ext/native.go +++ b/ext/native.go @@ -151,6 +151,24 @@ func (tp *nativeTypeProvider) FindStructType(typeName string) (*types.Type, bool return tp.baseProvider.FindStructType(typeName) } +// FindStructFieldNames looks up the type definition first from the native types, then from +// the backing provider type set. If found, a set of field names corresponding to the type +// will be returned. 
+func (tp *nativeTypeProvider) FindStructFieldNames(typeName string) ([]string, bool) { + if t, found := tp.nativeTypes[typeName]; found { + fieldCount := t.refType.NumField() + fields := make([]string, fieldCount) + for i := 0; i < fieldCount; i++ { + fields[i] = t.refType.Field(i).Name + } + return fields, true + } + if celTypeFields, found := tp.baseProvider.FindStructFieldNames(typeName); found { + return celTypeFields, true + } + return tp.baseProvider.FindStructFieldNames(typeName) +} + // FindStructFieldType looks up a native type's field definition, and if the type name is not a native // type then proxies to the composed types.Provider func (tp *nativeTypeProvider) FindStructFieldType(typeName, fieldName string) (*types.FieldType, bool) { diff --git a/ext/native_test.go b/ext/native_test.go index 21a678792..ead7bc1c1 100644 --- a/ext/native_test.go +++ b/ext/native_test.go @@ -17,6 +17,7 @@ package ext import ( "fmt" "reflect" + "sort" "strings" "testing" "time" @@ -162,6 +163,40 @@ func TestNativeTypes(t *testing.T) { } } +func TestNativeFindStructFieldNames(t *testing.T) { + env := testNativeEnv(t) + provider := env.CELTypeProvider() + tests := []struct { + typeName string + fields []string + }{ + { + typeName: "ext.TestNestedType", + fields: []string{"NestedListVal", "NestedMapVal"}, + }, + { + typeName: "google.expr.proto3.test.TestAllTypes.NestedMessage", + fields: []string{"bb"}, + }, + { + typeName: "invalid.TypeName", + fields: []string{}, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) { + fields, _ := provider.FindStructFieldNames(tc.typeName) + sort.Strings(fields) + sort.Strings(tc.fields) + if !reflect.DeepEqual(fields, tc.fields) { + t.Errorf("got %v, wanted %v", fields, tc.fields) + } + }) + } +} + func TestNativeTypesStaticErrors(t *testing.T) { var nativeTests = []struct { expr string From 1a6373dee99caa6e41082eb4287260e3a42bb6ee Mon Sep 17 00:00:00 2001 From: Tristan Swadell 
Date: Fri, 18 Aug 2023 22:35:10 -0400 Subject: [PATCH 21/31] Introduce pre-order / post-order visitor pattern (#813) * Introduce pre-order / post-order visitor pattern --- common/ast/ast.go | 25 ++++++ common/ast/ast_test.go | 48 +++++++--- common/ast/conversion.go | 2 +- common/ast/conversion_test.go | 7 ++ common/ast/factory.go | 18 ++++ common/ast/navigable.go | 163 ++++++++++++++++++++++++++++++---- common/ast/navigable_test.go | 106 ++++++++++++++++++++++ 7 files changed, 341 insertions(+), 28 deletions(-) diff --git a/common/ast/ast.go b/common/ast/ast.go index 0a0cbdc77..7610b4676 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -147,6 +147,13 @@ func Copy(a *AST) *AST { return NewCheckedAST(NewAST(e, CopySourceInfo(a.SourceInfo())), typesCopy, refsCopy) } +// MaxID returns the upper-bound, non-inclusive, of ids present within the AST's Expr value. +func MaxID(a *AST) int64 { + visitor := &maxIDVisitor{maxID: 1} + PostOrderVisit(a.Expr(), visitor) + return visitor.maxID + 1 +} + // NewSourceInfo creates a simple SourceInfo object from an input common.Source value. func NewSourceInfo(src common.Source) *SourceInfo { var lineOffsets []int32 @@ -397,3 +404,21 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool { } return true } + +type maxIDVisitor struct { + maxID int64 +} + +// VisitExpr updates the max identifier if the incoming expression id is greater than previously observed. +func (v *maxIDVisitor) VisitExpr(e Expr) { + if v.maxID < e.ID() { + v.maxID = e.ID() + } +} + +// VisitEntryExpr updates the max identifier if the incoming entry id is greater than previously observed. 
+func (v *maxIDVisitor) VisitEntryExpr(e EntryExpr) { + if v.maxID < e.ID() { + v.maxID = e.ID() + } +} diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go index d152b2a2d..bb8f03a04 100644 --- a/common/ast/ast_test.go +++ b/common/ast/ast_test.go @@ -23,6 +23,8 @@ import ( "github.com/google/cel-go/common/ast" "github.com/google/cel-go/common/overloads" "github.com/google/cel-go/common/types" + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" ) func TestASTCopy(t *testing.T) { @@ -40,12 +42,38 @@ func TestASTCopy(t *testing.T) { for _, tst := range tests { checked := mustTypeCheck(t, tst) - copied := ast.Copy(checked) - if !reflect.DeepEqual(copied.Expr(), checked.Expr()) { - t.Errorf("Copy() got expr %v, wanted %v", copied.Expr(), checked.Expr()) + copyChecked := ast.Copy(checked) + if !reflect.DeepEqual(copyChecked.Expr(), checked.Expr()) { + t.Errorf("Copy() got expr %v, wanted %v", copyChecked.Expr(), checked.Expr()) } - if !reflect.DeepEqual(copied.SourceInfo(), checked.SourceInfo()) { - t.Errorf("Copy() got source info %v, wanted %v", copied.SourceInfo(), checked.SourceInfo()) + if !reflect.DeepEqual(copyChecked.SourceInfo(), checked.SourceInfo()) { + t.Errorf("Copy() got source info %v, wanted %v", copyChecked.SourceInfo(), checked.SourceInfo()) + } + copyParsed := ast.Copy(ast.NewAST(checked.Expr(), checked.SourceInfo())) + if !reflect.DeepEqual(copyParsed.Expr(), checked.Expr()) { + t.Errorf("Copy() got expr %v, wanted %v", copyParsed.Expr(), checked.Expr()) + } + if !reflect.DeepEqual(copyParsed.SourceInfo(), checked.SourceInfo()) { + t.Errorf("Copy() got source info %v, wanted %v", copyParsed.SourceInfo(), checked.SourceInfo()) + } + checkedPB, err := ast.ToProto(checked) + if err != nil { + t.Errorf("ast.ToProto() failed: %v", err) + } + copyCheckedPB, err := ast.ToProto(copyChecked) + if err != nil { + t.Errorf("ast.ToProto() failed: %v", err) + } + if !proto.Equal(checkedPB, copyCheckedPB) { + 
t.Errorf("Copy() produced different proto results, got %v, wanted %v", + prototext.Format(checkedPB), prototext.Format(copyCheckedPB)) + } + checkedRoundtrip, err := ast.ToAST(checkedPB) + if err != nil { + t.Errorf("ast.ToAST() failed: %v", err) + } + if !reflect.DeepEqual(checked, checkedRoundtrip) { + t.Errorf("Roundtrip got %v, wanted %v", checkedRoundtrip, checked) } } } @@ -67,10 +95,10 @@ func TestASTNilSafety(t *testing.T) { ast.NewAST(ex, info), ast.NewCheckedAST(ast.NewAST(ex, info), map[int64]*types.Type{}, map[int64]*ast.ReferenceInfo{}), } - for i, tst := range tests { - tc := tst - t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - testAST := tc + for _, tst := range tests { + a := tst + asts := []*ast.AST{a, ast.Copy(a)} + for _, testAST := range asts { if testAST.Expr().ID() != 0 { t.Errorf("Expr().ID() got %v, wanted 0", testAST.Expr().ID()) } @@ -86,7 +114,7 @@ func TestASTNilSafety(t *testing.T) { if len(testAST.GetOverloadIDs(testAST.Expr().ID())) != 0 { t.Errorf("GetOverloadIDs() got %v, wanted empty set", testAST.GetOverloadIDs(testAST.Expr().ID())) } - }) + } } } diff --git a/common/ast/conversion.go b/common/ast/conversion.go index 16c57871e..8f2c4bd1e 100644 --- a/common/ast/conversion.go +++ b/common/ast/conversion.go @@ -533,7 +533,7 @@ func SourceInfoToProto(info *SourceInfo) (*exprpb.SourceInfo, error) { return sourceInfo, nil } -// ProtoToSourceInfo deserializes +// ProtoToSourceInfo deserializes the protobuf into a native SourceInfo value. 
func ProtoToSourceInfo(info *exprpb.SourceInfo) (*SourceInfo, error) { sourceInfo := &SourceInfo{ syntax: info.GetSyntaxVersion(), diff --git a/common/ast/conversion_test.go b/common/ast/conversion_test.go index 523c5a812..32ad13c3c 100644 --- a/common/ast/conversion_test.go +++ b/common/ast/conversion_test.go @@ -219,6 +219,13 @@ func TestConvertExpr(t *testing.T) { if !reflect.DeepEqual(gotExpr, tc.wantExpr) { t.Errorf("got %v, wanted %v", gotExpr, tc.wantExpr) } + gotExprRoundtrip, err := ast.ProtoToExpr(gotPBExpr) + if err != nil { + t.Fatalf("ast.ProtoToExpr() failed: %v", err) + } + if !reflect.DeepEqual(parsed.Expr(), gotExprRoundtrip) { + t.Errorf("ast.ProtoToExpr() got %v, wanted %v", gotExprRoundtrip, parsed.Expr()) + } info := parsed.SourceInfo() for id, wantCall := range tc.macroCalls { call, found := info.GetMacroCall(id) diff --git a/common/ast/factory.go b/common/ast/factory.go index 8f4a31f34..0111c289a 100644 --- a/common/ast/factory.go +++ b/common/ast/factory.go @@ -1,3 +1,17 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package ast import "github.com/google/cel-go/common/types/ref" @@ -183,6 +197,10 @@ func (fac *baseExprFactory) NewUnspecifiedExpr(id int64) Expr { } func (fac *baseExprFactory) CopyExpr(e Expr) Expr { + // unwrap navigable expressions to avoid unnecessary allocations during copying. 
+ if nav, ok := e.(*navigableExprImpl); ok { + e = nav.Expr + } switch e.Kind() { case CallKind: c := e.AsCall() diff --git a/common/ast/navigable.go b/common/ast/navigable.go index d4fe2395a..2836b5655 100644 --- a/common/ast/navigable.go +++ b/common/ast/navigable.go @@ -1,3 +1,17 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package ast import ( @@ -87,34 +101,149 @@ func AllMatcher() ExprMatcher { } } -// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values of the -// descendants which match. +// MatchDescendants takes a NavigableExpr and ExprMatcher and produces a list of NavigableExpr values +// matching the input criteria in post-order (bottom up). func MatchDescendants(expr NavigableExpr, matcher ExprMatcher) []NavigableExpr { - return matchListInternal([]NavigableExpr{expr}, matcher, true) + matches := []NavigableExpr{} + navVisitor := &baseVisitor{ + visitExpr: func(e Expr) { + nav := e.(NavigableExpr) + if matcher(nav) { + matches = append(matches, nav) + } + }, + } + visit(expr, navVisitor, postOrder, 0, 0) + return matches } // MatchSubset applies an ExprMatcher to a list of NavigableExpr values and their descendants, producing a // subset of NavigableExpr values which match. 
func MatchSubset(exprs []NavigableExpr, matcher ExprMatcher) []NavigableExpr { - visit := make([]NavigableExpr, len(exprs)) - copy(visit, exprs) - return matchListInternal(visit, matcher, false) + matches := []NavigableExpr{} + navVisitor := &baseVisitor{ + visitExpr: func(e Expr) { + nav := e.(NavigableExpr) + if matcher(nav) { + matches = append(matches, nav) + } + }, + } + for _, expr := range exprs { + visit(expr, navVisitor, postOrder, 0, 1) + } + return matches +} + +// Visitor defines an object for visiting Expr and EntryExpr nodes within an expression graph. +type Visitor interface { + // VisitExpr visits the input expression. + VisitExpr(Expr) + + // VisitEntryExpr visits the input entry expression, i.e. a struct field or map entry. + VisitEntryExpr(EntryExpr) +} + +type baseVisitor struct { + visitExpr func(Expr) + visitEntryExpr func(EntryExpr) +} + +// VisitExpr visits the Expr if the internal expr visitor has been configured. +func (v *baseVisitor) VisitExpr(e Expr) { + if v.visitExpr != nil { + v.visitExpr(e) + } +} + +// VisitEntryExpr visits the entry if the internal expr entry visitor has been configured. +func (v *baseVisitor) VisitEntryExpr(e EntryExpr) { + if v.visitEntryExpr != nil { + v.visitEntryExpr(e) + } +} + +// NewExprVisitor creates a visitor which only visits expression nodes. +func NewExprVisitor(v func(Expr)) Visitor { + return &baseVisitor{ + visitExpr: v, + visitEntryExpr: nil, + } } -func matchListInternal(visit []NavigableExpr, matcher ExprMatcher, visitDescendants bool) []NavigableExpr { - var matched []NavigableExpr - for len(visit) != 0 { - e := visit[0] - if matcher(e) { - matched = append(matched, e) +// PostOrderVisit walks the expression graph and calls the visitor in post-order (bottom-up). +func PostOrderVisit(expr Expr, visitor Visitor) { + visit(expr, visitor, postOrder, 0, 0) +} + +// PreOrderVisit walks the expression graph and calls the visitor in pre-order (top-down). 
+func PreOrderVisit(expr Expr, visitor Visitor) { + visit(expr, visitor, preOrder, 0, 0) +} + +type visitOrder int + +const ( + preOrder = iota + 1 + postOrder +) + +// TODO: consider exposing a way to configure a limit for the max visit depth. +// It's possible that we could want to configure this on the NewExprVisitor() +// and through MatchDescendents() / MaxID(). +func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) { + if maxDepth > 0 && depth == maxDepth { + return + } + if order == preOrder { + visitor.VisitExpr(expr) + } + switch expr.Kind() { + case CallKind: + c := expr.AsCall() + if c.IsMemberFunction() { + visit(c.Target(), visitor, order, depth+1, maxDepth) } - if visitDescendants { - visit = append(visit[1:], e.Children()...) - } else { - visit = visit[1:] + for _, arg := range c.Args() { + visit(arg, visitor, order, depth+1, maxDepth) } + case ComprehensionKind: + c := expr.AsComprehension() + visit(c.IterRange(), visitor, order, depth+1, maxDepth) + visit(c.AccuInit(), visitor, order, depth+1, maxDepth) + visit(c.LoopCondition(), visitor, order, depth+1, maxDepth) + visit(c.LoopStep(), visitor, order, depth+1, maxDepth) + visit(c.Result(), visitor, order, depth+1, maxDepth) + case ListKind: + l := expr.AsList() + for _, elem := range l.Elements() { + visit(elem, visitor, order, depth+1, maxDepth) + } + case MapKind: + m := expr.AsMap() + for _, e := range m.Entries() { + if order == preOrder { + visitor.VisitEntryExpr(e) + } + entry := e.AsMapEntry() + visit(entry.Key(), visitor, order, depth+1, maxDepth) + visit(entry.Value(), visitor, order, depth+1, maxDepth) + if order == postOrder { + visitor.VisitEntryExpr(e) + } + } + case SelectKind: + visit(expr.AsSelect().Operand(), visitor, order, depth+1, maxDepth) + case StructKind: + s := expr.AsStruct() + for _, f := range s.Fields() { + visitor.VisitEntryExpr(f) + visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth) + } + } + if order == postOrder { + 
visitor.VisitExpr(expr) } - return matched } func matchIsConstantValue(e NavigableExpr) bool { diff --git a/common/ast/navigable_test.go b/common/ast/navigable_test.go index e479a8ede..9b3cc0ea1 100644 --- a/common/ast/navigable_test.go +++ b/common/ast/navigable_test.go @@ -15,6 +15,7 @@ package ast_test import ( + "reflect" "testing" "google.golang.org/protobuf/proto" @@ -37,60 +38,70 @@ func TestNavigateAST(t *testing.T) { descendantCount int callCount int maxDepth int + maxID int64 }{ { expr: `'a' == 'b'`, descendantCount: 3, callCount: 1, maxDepth: 1, + maxID: 4, }, { expr: `'a'.size()`, descendantCount: 2, callCount: 1, maxDepth: 1, + maxID: 3, }, { expr: `[1, 2, 3]`, descendantCount: 4, callCount: 0, maxDepth: 1, + maxID: 5, }, { expr: `[1, 2, 3][0]`, descendantCount: 6, callCount: 1, maxDepth: 2, + maxID: 7, }, { expr: `{1u: 'hello'}`, descendantCount: 3, callCount: 0, maxDepth: 1, + maxID: 5, }, { expr: `{'hello': 'world'}.hello`, descendantCount: 4, callCount: 0, maxDepth: 2, + maxID: 6, }, { expr: `type(1) == int`, descendantCount: 4, callCount: 2, maxDepth: 2, + maxID: 5, }, { expr: `google.expr.proto3.test.TestAllTypes{single_int32: 1}`, descendantCount: 2, callCount: 0, maxDepth: 1, + maxID: 4, }, { expr: `[true].exists(i, i)`, descendantCount: 11, // 2 for iter range, 1 for accu init, 4 for loop condition, 3 for loop step, 1 for result callCount: 3, // @not_strictly_false(!result), accu_init || i maxDepth: 3, + maxID: 14, }, } @@ -112,6 +123,10 @@ func TestNavigateAST(t *testing.T) { if maxDepth != tc.maxDepth { t.Errorf("got max NavigableExpr.Depth() of %d, wanted %d", maxDepth, tc.maxDepth) } + maxID := ast.MaxID(checked) + if maxID != tc.maxID { + t.Errorf("got max id %d, wanted %d", maxID, tc.maxID) + } calls := ast.MatchSubset(descendants, ast.KindMatcher(ast.CallKind)) if len(calls) != tc.callCount { t.Errorf("ast.MatchSubset(%v) got %d calls, wanted %d", checked.Expr(), len(calls), tc.callCount) @@ -120,6 +135,97 @@ func TestNavigateAST(t 
*testing.T) { } } +func TestExprVisitor(t *testing.T) { + tests := []struct { + expr string + preOrderIDs []int64 + postOrderIDs []int64 + }{ + { + // [2] ==, [1] 'a', [3] 'b' + expr: `'a' == 'b'`, + preOrderIDs: []int64{2, 1, 3}, + postOrderIDs: []int64{1, 3, 2}, + }, + { + // [2] size(), [1] 'a' + expr: `'a'.size()`, + preOrderIDs: []int64{2, 1}, + postOrderIDs: []int64{1, 2}, + }, + { + // [3] ==, [1] type(), [2] 1, [4] int + expr: `type(1) == int`, + preOrderIDs: []int64{3, 1, 2, 4}, + postOrderIDs: []int64{2, 1, 4, 3}, + }, + { + // [5] .hello, [1] {}, [3] 'hello', [4] 'world' + expr: `{'hello': 'world'}.hello`, + preOrderIDs: []int64{5, 1, 3, 4}, + postOrderIDs: []int64{3, 4, 1, 5}, + }, + { + // [1] TestAllTypes, [3] 1 + expr: `google.expr.proto3.test.TestAllTypes{single_int32: 1}`, + preOrderIDs: []int64{1, 3}, + postOrderIDs: []int64{3, 1}, + }, + { + // [13] comprehension + // range: [1] [], [2] true + // accuInit: [6] result=false + // loopCond: [9] @not_strictly_false, [8] !, [7] result + // loopStep: [11] ||, [10] result [5] i + // result: [12] result + expr: `[true].exists(i, i)`, + preOrderIDs: []int64{13, 1, 2, 6, 9, 8, 7, 11, 10, 5, 12}, + postOrderIDs: []int64{2, 1, 6, 7, 8, 9, 10, 5, 11, 12, 13}, + }, + } + + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked := mustTypeCheck(t, tc.expr) + root := ast.NavigateAST(checked) + // Verify pre order visit behavior + preOrderExprIDs := []int64{} + navVisitor := ast.NewExprVisitor(func(e ast.Expr) { + nav := e.(ast.NavigableExpr) + preOrderExprIDs = append(preOrderExprIDs, nav.ID()) + }) + ast.PreOrderVisit(root, navVisitor) + if !reflect.DeepEqual(tc.preOrderIDs, preOrderExprIDs) { + t.Errorf("PreOrderVisit() got %v expressions, wanted %v", tc.preOrderIDs, preOrderExprIDs) + } + + // Demonstrate preOrder visit behavior with Children() + preOrderExprIDs = []int64{} + visited := []ast.NavigableExpr{root} + for len(visited) > 0 { + e := visited[0] + preOrderExprIDs = 
append(preOrderExprIDs, e.ID()) + visited = append(e.Children()[:], visited[1:]...) + } + if !reflect.DeepEqual(tc.preOrderIDs, preOrderExprIDs) { + t.Errorf("PreOrderVisit() got %v expressions, wanted %v", tc.preOrderIDs, preOrderExprIDs) + } + + // Verify post order visit behavior. + postOrderExprIDs := []int64{} + navVisitor = ast.NewExprVisitor(func(e ast.Expr) { + nav := e.(ast.NavigableExpr) + postOrderExprIDs = append(postOrderExprIDs, nav.ID()) + }) + ast.PostOrderVisit(root, navVisitor) + if !reflect.DeepEqual(tc.postOrderIDs, postOrderExprIDs) { + t.Errorf("PostOrderVisit() got %v expressions, wanted %v", tc.postOrderIDs, postOrderExprIDs) + } + }) + } +} + func TestNavigableASTNilSafety(t *testing.T) { tests := []struct { name string From 2de99521d8c8d4f5c1aa08ac4b4dbc414578396c Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Sat, 19 Aug 2023 00:23:52 -0400 Subject: [PATCH 22/31] Ensure stable ordering of overload candidates during dynamic dispatch (#817) --- cel/cel_test.go | 3 ++- common/decls/decls.go | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cel/cel_test.go b/cel/cel_test.go index b7dd7996f..0f8e36a1d 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -2081,7 +2081,8 @@ func TestDynamicDispatch(t *testing.T) { ), ) out, err := interpret(t, env, ` - [1, 2].first() == 1 + dyn([]).first() == 0 + && [1, 2].first() == 1 && [1.0, 2.0].first() == 1.0 && ["hello", "world"].first() == "hello" && [["hello"], ["world", "!"]].first().first() == "hello" diff --git a/common/decls/decls.go b/common/decls/decls.go index 0284f8dbc..734ebe57e 100644 --- a/common/decls/decls.go +++ b/common/decls/decls.go @@ -243,7 +243,8 @@ func (f *FunctionDecl) Bindings() ([]*functions.Overload, error) { // performs dynamic dispatch to the proper overload based on the argument types. bindings := append([]*functions.Overload{}, overloads...) 
funcDispatch := func(args ...ref.Val) ref.Val { - for _, o := range f.overloads { + for _, oID := range f.overloadOrdinals { + o := f.overloads[oID] // During dynamic dispatch over multiple functions, signature agreement checks // are preserved in order to assist with the function resolution step. switch len(args) { From 78039f1fa0bb190af877ed6fd7aefe7a4074cd18 Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Tue, 22 Aug 2023 12:24:17 -0400 Subject: [PATCH 23/31] Clarify replace with/by empty string (#820) --- ext/strings.go | 4 +++- ext/strings_test.go | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/ext/strings.go b/ext/strings.go index 13e919545..813f27f3e 100644 --- a/ext/strings.go +++ b/ext/strings.go @@ -202,6 +202,8 @@ const ( // 'hello hello'.replace('he', 'we', -1) // returns 'wello wello' // 'hello hello'.replace('he', 'we', 1) // returns 'wello hello' // 'hello hello'.replace('he', 'we', 0) // returns 'hello hello' +// 'hello hello'.replace('', '_') // returns '_h_e_l_l_o_ _h_e_l_l_o_' +// 'hello hello'.replace('h', '') // returns 'ello ello' // // # Split // @@ -514,7 +516,7 @@ func (lib *stringLib) CompileOptions() []cel.EnvOption { ) } if lib.version >= 3 { - opts = append( opts, + opts = append(opts, cel.Function("reverse", cel.MemberOverload("reverse", []*cel.Type{cel.StringType}, cel.StringType, cel.UnaryBinding(func(str ref.Val) ref.Val { diff --git a/ext/strings_test.go b/ext/strings_test.go index 6c1583ad3..138dd554c 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -82,6 +82,8 @@ var stringTests = []struct { {expr: `"{0} days {0} hours".replace("{0}", "2") == "2 days 2 hours"`}, {expr: `"{0} days {0} hours".replace("{0}", "2", 1).replace("{0}", "23") == "2 days 23 hours"`}, {expr: `"1 ©αT taco".replace("αT", "o©α") == "1 ©o©α taco"`}, + {expr: `"hello hello".replace("", "_") == "_h_e_l_l_o_ _h_e_l_l_o_"`}, + {expr: `"hello hello".replace("h", "") == "ello ello"`}, // Split tests. 
{expr: `"hello world".split(" ") == ["hello", "world"]`}, {expr: `"hello world events!".split(" ", 0) == []`}, From 8a4595524249653b05082a3ce5b1acb75c448844 Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Wed, 23 Aug 2023 12:32:17 -0700 Subject: [PATCH 24/31] Add Libraries() function to Env (#822) --- cel/env.go | 9 +++++++++ cel/env_test.go | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/cel/env.go b/cel/env.go index 473604bc6..0d9553ffd 100644 --- a/cel/env.go +++ b/cel/env.go @@ -389,6 +389,15 @@ func (e *Env) HasLibrary(libName string) bool { return exists && configured } +// Libraries returns a list of SingletonLibrary that have been configured in the environment. +func (e *Env) Libraries() []string { + libraries := make([]string, len(e.libraries)) + for libName := range e.libraries { + libraries = append(libraries, libName) + } + return libraries +} + // HasValidator returns whether a specific ASTValidator has been configured in the environment. func (e *Env) HasValidator(name string) bool { for _, v := range e.validators { diff --git a/cel/env_test.go b/cel/env_test.go index ed1f836e8..54be88ce6 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -247,6 +247,26 @@ func TestTypeProviderInterop(t *testing.T) { } } +func TestLibraries(t *testing.T) { + e, err := NewEnv(OptionalTypes()) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, expected := range []string{"cel.lib.std", "cel.lib.optional"} { + if !e.HasLibrary(expected) { + t.Errorf("Expected HasLibrary() to return true for '%s'", expected) + } + libMap := map[string]struct{}{} + for _, lib := range e.Libraries() { + libMap[lib] = struct{}{} + } + + if _, ok := libMap[expected]; !ok { + t.Errorf("Expected Libraries() to include '%s'", expected) + } + } +} + func BenchmarkNewCustomEnvLazy(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { From dd6d31dd8b4d3f220f07d0dd906476f9687a7d79 Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Wed, 23 Aug 2023 13:48:26 
-0700 Subject: [PATCH 25/31] Fix slice handling in Env.Libraries() (#825) --- cel/env.go | 2 +- cel/env_test.go | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cel/env.go b/cel/env.go index 0d9553ffd..113b89b56 100644 --- a/cel/env.go +++ b/cel/env.go @@ -391,7 +391,7 @@ func (e *Env) HasLibrary(libName string) bool { // Libraries returns a list of SingletonLibrary that have been configured in the environment. func (e *Env) Libraries() []string { - libraries := make([]string, len(e.libraries)) + libraries := make([]string, 0, len(e.libraries)) for libName := range e.libraries { libraries = append(libraries, libName) } diff --git a/cel/env_test.go b/cel/env_test.go index 54be88ce6..7a69052b4 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -257,9 +257,13 @@ func TestLibraries(t *testing.T) { t.Errorf("Expected HasLibrary() to return true for '%s'", expected) } libMap := map[string]struct{}{} - for _, lib := range e.Libraries() { + libraries := e.Libraries() + for _, lib := range libraries { libMap[lib] = struct{}{} } + if len(libraries) != 2 { + t.Errorf("Expected HasLibrary() to contain exactly 2 libraries but got: %v", libraries) + } if _, ok := libMap[expected]; !ok { t.Errorf("Expected Libraries() to include '%s'", expected) From 509c1d6161f0ae6bbb91cab1e65daeefb0109316 Mon Sep 17 00:00:00 2001 From: Joe Betz Date: Wed, 23 Aug 2023 15:47:17 -0700 Subject: [PATCH 26/31] Document that string reverse was introduced in version 3 of ext.strings (#824) --- ext/strings.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ext/strings.go b/ext/strings.go index 813f27f3e..1faa6ed7d 100644 --- a/ext/strings.go +++ b/ext/strings.go @@ -272,6 +272,8 @@ const ( // // # Reverse // +// Introduced at version: 3 +// // Returns a new string whose characters are the same as the target string, only formatted in // reverse order. 
// This function relies on converting strings to rune arrays in order to reverse From 705546a3fd4975d917e972b33596796ef2e739c0 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Wed, 30 Aug 2023 16:16:47 -0700 Subject: [PATCH 27/31] Static optimizer for constant folding (#804) * Optimizer API with Constant Folding implementatiton * Better logical folds and additional tests * Add a configurable limit to constant folding --- cel/BUILD.bazel | 3 + cel/env.go | 8 + cel/folding.go | 553 ++++++++++++++++++++++++++++++++++ cel/folding_test.go | 652 ++++++++++++++++++++++++++++++++++++++++ cel/optimizer.go | 317 +++++++++++++++++++ common/ast/ast.go | 13 +- common/ast/expr.go | 22 +- common/ast/factory.go | 11 +- common/ast/navigable.go | 4 + 9 files changed, 1575 insertions(+), 8 deletions(-) create mode 100644 cel/folding.go create mode 100644 cel/folding_test.go create mode 100644 cel/optimizer.go diff --git a/cel/BUILD.bazel b/cel/BUILD.bazel index aa978e064..62b903c81 100644 --- a/cel/BUILD.bazel +++ b/cel/BUILD.bazel @@ -10,9 +10,11 @@ go_library( "cel.go", "decls.go", "env.go", + "folding.go", "io.go", "library.go", "macro.go", + "optimizer.go", "options.go", "program.go", "validator.go", @@ -56,6 +58,7 @@ go_test( "cel_test.go", "decls_test.go", "env_test.go", + "folding_test.go", "io_test.go", "validator_test.go", ], diff --git a/cel/env.go b/cel/env.go index 113b89b56..786a13c4d 100644 --- a/cel/env.go +++ b/cel/env.go @@ -43,6 +43,9 @@ type Ast struct { } // Expr returns the proto serializable instance of the parsed/checked expression. +// +// Deprecated: prefer cel.AstToCheckedExpr() or cel.AstToParsedExpr() and call GetExpr() +// the result instead. func (ast *Ast) Expr() *exprpb.Expr { if ast == nil { return nil @@ -221,6 +224,11 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { source: ast.Source(), impl: checked} + // Avoid creating a validator config if it's not needed. 
+ if len(e.validators) == 0 { + return ast, nil + } + // Generate a validator configuration from the set of configured validators. vConfig := newValidatorConfig() for _, v := range e.validators { diff --git a/cel/folding.go b/cel/folding.go new file mode 100644 index 000000000..5903d732c --- /dev/null +++ b/cel/folding.go @@ -0,0 +1,553 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "fmt" + + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + "github.com/google/cel-go/common/types/traits" +) + +// ConstantFoldingOption defines a functional option for configuring constant folding. +type ConstantFoldingOption func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) + +// MaxConstantFoldIterations limits the number of times literals may be folding during optimization. +// +// Defaults to 100 if not set. +func MaxConstantFoldIterations(limit int) ConstantFoldingOption { + return func(opt *constantFoldingOptimizer) (*constantFoldingOptimizer, error) { + opt.maxFoldIterations = limit + return opt, nil + } +} + +// NewConstantFoldingOptimizer creates an optimizer which inlines constant scalar an aggregate +// literal values within function calls and select statements with their evaluated result. 
+func NewConstantFoldingOptimizer(opts ...ConstantFoldingOption) (ASTOptimizer, error) { + folder := &constantFoldingOptimizer{ + maxFoldIterations: defaultMaxConstantFoldIterations, + } + var err error + for _, o := range opts { + folder, err = o(folder) + if err != nil { + return nil, err + } + } + return folder, nil +} + +type constantFoldingOptimizer struct { + maxFoldIterations int +} + +// Optimize queries the expression graph for scalar and aggregate literal expressions within call and +// select statements and then evaluates them and replaces the call site with the literal result. +// +// Note: only values which can be represented as literals in CEL syntax are supported. +func (opt *constantFoldingOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + + // Walk the list of foldable expression and continue to fold until there are no more folds left. + // All of the fold candidates returned by the constantExprMatcher should succeed unless there's + // a logic bug with the selection of expressions. + foldableExprs := ast.MatchDescendants(root, constantExprMatcher) + foldCount := 0 + for len(foldableExprs) != 0 && foldCount < opt.maxFoldIterations { + for _, fold := range foldableExprs { + // If the expression could be folded because it's a non-strict call, and the + // branches are pruned, continue to the next fold. + if fold.Kind() == ast.CallKind && maybePruneBranches(ctx, fold) { + continue + } + // Otherwise, assume all context is needed to evaluate the expression. + err := tryFold(ctx, a, fold) + if err != nil { + ctx.ReportErrorAtID(fold.ID(), "constant-folding evaluation failed: %v", err.Error()) + return a + } + } + foldCount++ + foldableExprs = ast.MatchDescendants(root, constantExprMatcher) + } + // Once all of the constants have been folded, try to run through the remaining comprehensions + // one last time. 
In this case, there's no guarantee they'll run, so we only update the + // target comprehension node with the literal value if the evaluation succeeds. + for _, compre := range ast.MatchDescendants(root, ast.KindMatcher(ast.ComprehensionKind)) { + tryFold(ctx, a, compre) + } + + // If the output is a list, map, or struct which contains optional entries, then prune it + // to make sure that the optionals, if resolved, do not surface in the output literal. + pruneOptionalElements(ctx, root) + + // Ensure that all intermediate values in the folded expression can be represented as valid + // CEL literals within the AST structure. Use `PostOrderVisit` rather than `MatchDescendents` + // to avoid extra allocations during this final pass through the AST. + ast.PostOrderVisit(root, ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() != ast.LiteralKind { + return + } + val := e.AsLiteral() + adapted, err := adaptLiteral(ctx, val) + if err != nil { + ctx.ReportErrorAtID(root.ID(), "constant-folding evaluation failed: %v", err.Error()) + return + } + e.SetKindCase(adapted) + })) + + return a +} + +// tryFold attempts to evaluate a sub-expression to a literal. +// +// If the evaluation succeeds, the input expr value will be modified to become a literal, otherwise +// the method will return an error. +func tryFold(ctx *OptimizerContext, a *ast.AST, expr ast.Expr) error { + // Assume all context is needed to evaluate the expression. + subAST := &Ast{ + impl: ast.NewCheckedAST(ast.NewAST(expr, a.SourceInfo()), a.TypeMap(), a.ReferenceMap()), + } + prg, err := ctx.Program(subAST) + if err != nil { + return err + } + out, _, err := prg.Eval(NoVars()) + if err != nil { + return err + } + // Clear any macro metadata associated with the fold. + a.SourceInfo().ClearMacroCall(expr.ID()) + // Update the fold expression to be a literal. 
+ expr.SetKindCase(ctx.NewLiteral(out)) + return nil +} + +// maybePruneBranches inspects the non-strict call expression to determine whether +// a branch can be removed. Evaluation will naturally prune logical and / or calls, +// but conditional will not be pruned cleanly, so this is one small area where the +// constant folding step reimplements a portion of the evaluator. +func maybePruneBranches(ctx *OptimizerContext, expr ast.NavigableExpr) bool { + call := expr.AsCall() + args := call.Args() + switch call.FunctionName() { + case operators.LogicalAnd, operators.LogicalOr: + return maybeShortcircuitLogic(ctx, call.FunctionName(), args, expr) + case operators.Conditional: + cond := args[0] + truthy := args[1] + falsy := args[2] + if cond.Kind() != ast.LiteralKind { + return false + } + if cond.AsLiteral() == types.True { + expr.SetKindCase(truthy) + } else { + expr.SetKindCase(falsy) + } + return true + case operators.In: + haystack := args[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + expr.SetKindCase(ctx.NewLiteral(types.False)) + return true + } + needle := args[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + expr.SetKindCase(ctx.NewLiteral(types.True)) + return true + } + } + } + } + return false +} + +func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.Expr, expr ast.NavigableExpr) bool { + shortcircuit := types.False + skip := types.True + if function == operators.LogicalOr { + shortcircuit = types.True + skip = types.False + } + newArgs := []ast.Expr{} + for _, arg := range args { + if arg.Kind() != ast.LiteralKind { + newArgs = append(newArgs, arg) + continue + } + if arg.AsLiteral() == skip { + continue + } + if arg.AsLiteral() == shortcircuit { + expr.SetKindCase(arg) + return 
true + } + } + if len(newArgs) == 1 { + expr.SetKindCase(newArgs[0]) + return true + } + expr.SetKindCase(ctx.NewCall(function, newArgs...)) + return true +} + +// pruneOptionalElements works from the bottom up to resolve optional elements within +// aggregate literals. +// +// Note, many aggregate literals will be resolved as arguments to functions or select +// statements, so this method exists to handle the case where the literal could not be +// fully resolved or exists outside of a call, select, or comprehension context. +func pruneOptionalElements(ctx *OptimizerContext, root ast.NavigableExpr) { + aggregateLiterals := ast.MatchDescendants(root, aggregateLiteralMatcher) + for _, lit := range aggregateLiterals { + switch lit.Kind() { + case ast.ListKind: + pruneOptionalListElements(ctx, lit) + case ast.MapKind: + pruneOptionalMapEntries(ctx, lit) + case ast.StructKind: + pruneOptionalStructFields(ctx, lit) + } + } +} + +func pruneOptionalListElements(ctx *OptimizerContext, e ast.Expr) { + l := e.AsList() + elems := l.Elements() + optIndices := l.OptionalIndices() + if len(optIndices) == 0 { + return + } + updatedElems := []ast.Expr{} + updatedIndices := []int32{} + for i, e := range elems { + if !l.IsOptional(int32(i)) { + updatedElems = append(updatedElems, e) + continue + } + if e.Kind() != ast.LiteralKind { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + optElemVal, ok := e.AsLiteral().(*types.Optional) + if !ok { + updatedElems = append(updatedElems, e) + updatedIndices = append(updatedIndices, int32(i)) + continue + } + if !optElemVal.HasValue() { + continue + } + e.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedElems = append(updatedElems, e) + } + e.SetKindCase(ctx.NewList(updatedElems, updatedIndices)) +} + +func pruneOptionalMapEntries(ctx *OptimizerContext, e ast.Expr) { + m := e.AsMap() + entries := m.Entries() + updatedEntries := []ast.EntryExpr{} + modified := false + for _, 
e := range entries { + entry := e.AsMapEntry() + key := entry.Key() + val := entry.Value() + // If the entry is not optional, or the value-side of the optional hasn't + // been resolved to a literal, then preserve the entry as-is. + if !entry.IsOptional() || val.Kind() != ast.LiteralKind { + updatedEntries = append(updatedEntries, e) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedEntries = append(updatedEntries, e) + continue + } + // When the key is not a literal, but the value is, then it needs to be + // restored to an optional value. + if key.Kind() != ast.LiteralKind { + undoOptVal, err := adaptLiteral(ctx, optElemVal) + if err != nil { + ctx.ReportErrorAtID(val.ID(), "invalid map value literal %v: %v", optElemVal, err) + } + val.SetKindCase(undoOptVal) + updatedEntries = append(updatedEntries, e) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedEntry := ctx.NewMapEntry(key, val, false) + updatedEntries = append(updatedEntries, updatedEntry) + } + if modified { + e.SetKindCase(ctx.NewMap(updatedEntries)) + } +} + +func pruneOptionalStructFields(ctx *OptimizerContext, e ast.Expr) { + s := e.AsStruct() + fields := s.Fields() + updatedFields := []ast.EntryExpr{} + modified := false + for _, f := range fields { + field := f.AsStructField() + val := field.Value() + if !field.IsOptional() || val.Kind() != ast.LiteralKind { + updatedFields = append(updatedFields, f) + continue + } + optElemVal, ok := val.AsLiteral().(*types.Optional) + if !ok { + updatedFields = append(updatedFields, f) + continue + } + modified = true + if !optElemVal.HasValue() { + continue + } + val.SetKindCase(ctx.NewLiteral(optElemVal.GetValue())) + updatedField := ctx.NewStructField(field.Name(), val, false) + updatedFields = append(updatedFields, updatedField) + } + if modified { + e.SetKindCase(ctx.NewStruct(s.TypeName(), updatedFields)) + } +} + +// 
adaptLiteral converts a runtime CEL value to its equivalent literal expression. +// +// For strongly typed values, the type-provider will be used to reconstruct the fields +// which are present in the literal and their equivalent initialization values. +func adaptLiteral(ctx *OptimizerContext, val ref.Val) (ast.Expr, error) { + switch t := val.Type().(type) { + case *types.Type: + switch t { + case types.BoolType, types.BytesType, types.DoubleType, types.IntType, + types.NullType, types.StringType, types.UintType: + return ctx.NewLiteral(val), nil + case types.DurationType: + return ctx.NewCall( + overloads.TypeConvertDuration, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.TimestampType: + return ctx.NewCall( + overloads.TypeConvertTimestamp, + ctx.NewLiteral(val.ConvertToType(types.StringType)), + ), nil + case types.OptionalType: + opt := val.(*types.Optional) + if !opt.HasValue() { + return ctx.NewCall("optional.none"), nil + } + target, err := adaptLiteral(ctx, opt.GetValue()) + if err != nil { + return nil, err + } + return ctx.NewCall("optional.of", target), nil + case types.TypeType: + return ctx.NewIdent(val.(*types.Type).TypeName()), nil + case types.ListType: + l, ok := val.(traits.Lister) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + elems := make([]ast.Expr, l.Size().(types.Int)) + idx := 0 + it := l.Iterator() + for it.HasNext() == types.True { + elemVal := it.Next() + elemExpr, err := adaptLiteral(ctx, elemVal) + if err != nil { + return nil, err + } + elems[idx] = elemExpr + idx++ + } + return ctx.NewList(elems, []int32{}), nil + case types.MapType: + m, ok := val.(traits.Mapper) + if !ok { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + entries := make([]ast.EntryExpr, m.Size().(types.Int)) + idx := 0 + it := m.Iterator() + for it.HasNext() == types.True { + keyVal := it.Next() + keyExpr, err := adaptLiteral(ctx, keyVal) + if err != nil { + return nil, err + } + 
valVal := m.Get(keyVal) + valExpr, err := adaptLiteral(ctx, valVal) + if err != nil { + return nil, err + } + entries[idx] = ctx.NewMapEntry(keyExpr, valExpr, false) + idx++ + } + return ctx.NewMap(entries), nil + default: + provider := ctx.CELTypeProvider() + fields, found := provider.FindStructFieldNames(t.TypeName()) + if !found { + return nil, fmt.Errorf("failed to adapt %v to literal", val) + } + tester := val.(traits.FieldTester) + indexer := val.(traits.Indexer) + fieldInits := []ast.EntryExpr{} + for _, f := range fields { + field := types.String(f) + if tester.IsSet(field) != types.True { + continue + } + fieldVal := indexer.Get(field) + fieldExpr, err := adaptLiteral(ctx, fieldVal) + if err != nil { + return nil, err + } + fieldInits = append(fieldInits, ctx.NewStructField(f, fieldExpr, false)) + } + return ctx.NewStruct(t.TypeName(), fieldInits), nil + } + } + return nil, fmt.Errorf("failed to adapt %v to literal", val) +} + +// constantExprMatcher matches calls, select statements, and comprehensions whose arguments +// are all constant scalar or aggregate literal values. +// +// Only comprehensions which are not nested are included as possible constant folds, and only +// if all variables referenced in the comprehension stack exist are only iteration or +// accumulation variables. 
+func constantExprMatcher(e ast.NavigableExpr) bool { + switch e.Kind() { + case ast.CallKind: + return constantCallMatcher(e) + case ast.SelectKind: + sel := e.AsSelect() // guaranteed to be a navigable value + return constantMatcher(sel.Operand().(ast.NavigableExpr)) + case ast.ComprehensionKind: + if isNestedComprehension(e) { + return false + } + vars := map[string]bool{} + constantExprs := true + visitor := ast.NewExprVisitor(func(e ast.Expr) { + if e.Kind() == ast.ComprehensionKind { + nested := e.AsComprehension() + vars[nested.AccuVar()] = true + vars[nested.IterVar()] = true + } + if e.Kind() == ast.IdentKind && !vars[e.AsIdent()] { + constantExprs = false + } + }) + ast.PreOrderVisit(e, visitor) + return constantExprs + default: + return false + } +} + +// constantCallMatcher identifies strict and non-strict calls which can be folded. +func constantCallMatcher(e ast.NavigableExpr) bool { + call := e.AsCall() + children := e.Children() + fnName := call.FunctionName() + if fnName == operators.LogicalAnd { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.LogicalOr { + for _, child := range children { + if child.Kind() == ast.LiteralKind { + return true + } + } + } + if fnName == operators.Conditional { + cond := children[0] + if cond.Kind() == ast.LiteralKind && cond.AsLiteral().Type() == types.BoolType { + return true + } + } + if fnName == operators.In { + haystack := children[1] + if haystack.Kind() == ast.ListKind && haystack.AsList().Size() == 0 { + return true + } + needle := children[0] + if needle.Kind() == ast.LiteralKind && haystack.Kind() == ast.ListKind { + needleValue := needle.AsLiteral() + list := haystack.AsList() + for _, e := range list.Elements() { + if e.Kind() == ast.LiteralKind && e.AsLiteral().Equal(needleValue) == types.True { + return true + } + } + } + } + // convert all other calls with constant arguments + for _, child := range children { + if 
!constantMatcher(child) { + return false + } + } + return true +} + +func isNestedComprehension(e ast.NavigableExpr) bool { + parent, found := e.Parent() + for found { + if parent.Kind() == ast.ComprehensionKind { + return true + } + parent, found = parent.Parent() + } + return false +} + +func aggregateLiteralMatcher(e ast.NavigableExpr) bool { + return e.Kind() == ast.ListKind || e.Kind() == ast.MapKind || e.Kind() == ast.StructKind +} + +var ( + constantMatcher = ast.ConstantValueMatcher() +) + +const ( + defaultMaxConstantFoldIterations = 100 +) diff --git a/cel/folding_test.go b/cel/folding_test.go new file mode 100644 index 000000000..18fe4d58f --- /dev/null +++ b/cel/folding_test.go @@ -0,0 +1,652 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package cel + +import ( + "reflect" + "sort" + "testing" + + "google.golang.org/protobuf/encoding/prototext" + "google.golang.org/protobuf/proto" + + "github.com/google/cel-go/common/ast" + + proto3pb "github.com/google/cel-go/test/proto3pb" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" +) + +func TestConstantFoldingOptimizer(t *testing.T) { + tests := []struct { + expr string + folded string + }{ + { + expr: `[1, 1 + 2, 1 + (2 + 3)]`, + folded: `[1, 3, 6]`, + }, + { + expr: `6 in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `true`, + }, + { + expr: `5 in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `false`, + }, + { + expr: `x in [1, 1 + 2, 1 + (2 + 3)]`, + folded: `x in [1, 3, 6]`, + }, + { + expr: `1 in [1, x + 2, 1 + (2 + 3)]`, + folded: `true`, + }, + { + expr: `1 in [x, x + 2, 1 + (2 + 3)]`, + folded: `1 in [x, x + 2, 6]`, + }, + { + expr: `x in []`, + folded: `false`, + }, + { + expr: `{'hello': 'world'}.hello == x`, + folded: `"world" == x`, + }, + { + expr: `{'hello': 'world'}.?hello.orValue('default') == x`, + folded: `"world" == x`, + }, + { + expr: `{'hello': 'world'}['hello'] == x`, + folded: `"world" == x`, + }, + { + expr: `optional.of("hello")`, + folded: `optional.of("hello")`, + }, + { + expr: `optional.ofNonZeroValue("")`, + folded: `optional.none()`, + }, + { + expr: `{?'hello': optional.of('world')}['hello'] == x`, + folded: `"world" == x`, + }, + { + expr: `duration(string(7 * 24) + 'h')`, + folded: `duration("604800s")`, + }, + { + expr: `timestamp("1970-01-01T00:00:00Z")`, + folded: `timestamp("1970-01-01T00:00:00Z")`, + }, + { + expr: `[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 10)`, + folded: `true`, + }, + { + expr: `[1, 1 + 1, 1 + 2, 2 + 3].exists(i, i < 1 % 2)`, + folded: `false`, + }, + { + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j))`, + folded: `[[1, 2, 3], [2, 4, 6], [3, 6, 9]]`, + }, + { + expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == 0))`, + folded: `[[2], [2, 4, 6], [6]]`, + }, + { + expr: `[1, 2, 
3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == x))`, + folded: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j).filter(k, k % 2 == x))`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].filter(m, has(m.a))`, + folded: `[{"a": 1}]`, + }, + { + expr: `[{}, {"a": 1}, {"b": 2}].filter(m, has({'a': true}.a))`, + folded: `[{}, {"a": 1}, {"b": 2}]`, + }, + { + expr: `type(1)`, + folded: `int`, + }, + { + expr: `[google.expr.proto3.test.TestAllTypes{single_int32: 2 + 3}].map(i, i)[0]`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: 5}`, + }, + { + expr: `[1, ?optional.ofNonZeroValue(0)]`, + folded: `[1]`, + }, + { + expr: `[1, x, ?optional.ofNonZeroValue(3), ?x.?y]`, + folded: `[1, x, 3, ?x.?y]`, + }, + { + expr: `[1, x, ?optional.ofNonZeroValue(3), ?x.?y].size() > 3`, + folded: `[1, x, 3, ?x.?y].size() > 3`, + }, + { + expr: `{?'a': optional.of('hello'), ?x : optional.of(1), ?'b': optional.none()}`, + folded: `{"a": "hello", ?x: optional.of(1)}`, + }, + { + expr: `true ? x + 1 : x + 2`, + folded: `x + 1`, + }, + { + expr: `false ? x + 1 : x + 2`, + folded: `x + 2`, + }, + { + expr: `false ? 
x + 'world' : 'hello' + 'world'`, + folded: `"helloworld"`, + }, + { + expr: `true && x`, + folded: `x`, + }, + { + expr: `x && true`, + folded: `x`, + }, + { + expr: `false && x`, + folded: `false`, + }, + { + expr: `x && false`, + folded: `false`, + }, + { + expr: `true || x`, + folded: `true`, + }, + { + expr: `x || true`, + folded: `true`, + }, + { + expr: `false || x`, + folded: `x`, + }, + { + expr: `x || false`, + folded: `x`, + }, + { + expr: `true && x && true && x`, + folded: `x && x`, + }, + { + expr: `false || x || false || x`, + folded: `x || x`, + }, + { + expr: `null`, + folded: `null`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{?single_int32: optional.ofNonZeroValue(1)}`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: 1}`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{?single_int32: optional.ofNonZeroValue(0)}`, + folded: `google.expr.proto3.test.TestAllTypes{}`, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}`, + folded: `google.expr.proto3.test.TestAllTypes{single_int32: x, repeated_int32: [1, 2, 3]}`, + }, + { + expr: `x + dyn([1, 2] + [3, 4])`, + folded: `x + [1, 2, 3, 4]`, + }, + { + expr: `dyn([1, 2]) + [3.0, 4.0]`, + folded: `[1, 2, 3.0, 4.0]`, + }, + { + expr: `{'a': dyn([1, 2]), 'b': x}`, + folded: `{"a": [1, 2], "b": x}`, + }, + { + expr: `1 + x + 2 == 2 + x + 1`, + folded: `1 + x + 2 == 2 + x + 1`, + }, + { + // The order of operations makes it such that the appearance of x in the first means that + // none of the values provided into the addition call will be folded with the current + // implementation. 
Ideally, the result would be 3 + x == x + 3 (which could be trivially true + // and more easily observed as a result of common subexpression eliminiation) + expr: `1 + 2 + x == x + 2 + 1`, + folded: `3 + x == x + 2 + 1`, + }, + } + e, err := NewEnv( + OptionalTypes(), + EnableMacroCallTracking(), + Types(&proto3pb.TestAllTypes{}), + Variable("x", DynType)) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + folder, err := NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := AstToString(optimized) + if err != nil { + t.Fatalf("AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} + +func TestConstantFoldingOptimizerWithLimit(t *testing.T) { + tests := []struct { + expr string + limit int + folded string + }{ + { + expr: `[1, 1 + 2, 1 + (2 + 3)]`, + limit: 1, + folded: `[1, 3, 1 + 5]`, + }, + { + expr: `5 in [1, 1 + 2, 1 + (2 + 3)]`, + limit: 2, + folded: `5 in [1, 3, 6]`, + }, + { + // though more complex, the final tryFold() at the end of the optimization pass + // results in this computed output. 
+ expr: `[1, 2, 3].map(i, [1, 2, 3].map(j, i * j))`, + limit: 1, + folded: `[[1, 2, 3], [2, 4, 6], [3, 6, 9]]`, + }, + } + e, err := NewEnv( + OptionalTypes(), + EnableMacroCallTracking(), + Types(&proto3pb.TestAllTypes{}), + Variable("x", DynType)) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + folder, err := NewConstantFoldingOptimizer(MaxConstantFoldIterations(tc.limit)) + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := AstToString(optimized) + if err != nil { + t.Fatalf("AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} + +func TestConstantFoldingNormalizeIDs(t *testing.T) { + tests := []struct { + expr string + ids []int64 + macros map[int64]string + normalizedIDs []int64 + normalizedMacros map[int64]string + }{ + { + expr: `[1, 2, 3]`, + ids: []int64{1, 2, 3, 4}, + normalizedIDs: []int64{1, 2, 3, 4}, + }, + { + expr: `google.expr.proto3.test.TestAllTypes{single_int32: 0}`, + ids: []int64{1, 2, 3}, + normalizedIDs: []int64{1, 2, 3}, + }, + { + expr: `has({x: 'value'}.single_int32)`, + ids: []int64{2, 3, 4, 5, 7}, + macros: map[int64]string{7: ` + call_expr: { + function: "has" + args: { + id: 6 + select_expr: { + operand: { + id: 2 + struct_expr: { + entries: { + id: 3 + map_key: { + id: 4 + ident_expr: { + name: "x" + } + } + value: { + id: 5 + const_expr: { + string_value: "value" + } + } + } + } + } + field: "single_int32" + } + } + }`}, + normalizedIDs: []int64{1, 2, 3, 4, 5}, + normalizedMacros: map[int64]string{1: ` + call_expr: { + function: "has" + 
args: { + id: 6 + select_expr: { + operand: { + id: 2 + struct_expr: { + entries: { + id: 3 + map_key: { + id: 4 + ident_expr: { + name: "x" + } + } + value: { + id: 5 + const_expr: { + string_value: "value" + } + } + } + } + } + field: "single_int32" + } + } + }`, + }, + }, + { + expr: `has(google.expr.proto3.test.TestAllTypes{}.single_int32)`, + ids: []int64{2, 4}, + macros: map[int64]string{ + 4: `call_expr: { + function: "has" + args: { + id: 3 + select_expr: { + operand: { + id: 2 + struct_expr: { + message_name: "google.expr.proto3.test.TestAllTypes" + } + } + field: "single_int32" + } + } + }`, + }, + normalizedIDs: []int64{1}, + }, + { + expr: `[true].exists(i, i)`, + ids: []int64{1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + macros: map[int64]string{ + 13: `call_expr: { + target: { + id: 1 + list_expr: { + elements: { + id: 2 + const_expr: { + bool_value: true + } + } + } + } + function: "exists" + args: { + id: 4 + ident_expr: { + name: "i" + } + } + args: { + id: 5 + ident_expr: { + name: "i" + } + } + }`, + }, + normalizedIDs: []int64{1}, + }, + { + expr: `[x].exists(i, i)`, + ids: []int64{1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13}, + macros: map[int64]string{ + 13: `call_expr: { + target: { + id: 1 + list_expr: { + elements: { + id: 2 + ident_expr: { + name: "x" + } + } + } + } + function: "exists" + args: { + id: 4 + ident_expr: { + name: "i" + } + } + args: { + id: 5 + ident_expr: { + name: "i" + } + } + }`, + }, + normalizedIDs: []int64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + normalizedMacros: map[int64]string{ + 1: `call_expr: { + target: { + id: 2 + list_expr: { + elements: { + id: 3 + ident_expr: { + name: "x" + } + } + } + } + function: "exists" + args: { + id: 12 + ident_expr: { + name: "i" + } + } + args: { + id: 10 + ident_expr: { + name: "i" + } + } + }`, + }, + }, + } + e, err := NewEnv( + EnableMacroCallTracking(), + Types(&proto3pb.TestAllTypes{}), + Variable("x", DynType)) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + for _, tst := 
range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + preOpt := newIDCollector() + ast.PostOrderVisit(checked.impl.Expr(), preOpt) + if !reflect.DeepEqual(preOpt.IDs(), tc.ids) { + t.Errorf("Compile() got ids %v, expected %v", preOpt.IDs(), tc.ids) + } + for id, call := range checked.impl.SourceInfo().MacroCalls() { + macroText, found := tc.macros[id] + if !found { + t.Fatalf("Compile() did not find macro %d", id) + } + pbCall, err := ast.ExprToProto(call) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + pbMacro := &exprpb.Expr{} + err = prototext.Unmarshal([]byte(macroText), pbMacro) + if err != nil { + t.Fatalf("prototext.Unmarshal() failed: %v", err) + } + if !proto.Equal(pbCall, pbMacro) { + t.Errorf("Compile() for macro %d got %s, expected %s", id, prototext.Format(pbCall), macroText) + } + } + folder, err := NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt := NewStaticOptimizer(folder) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + postOpt := newIDCollector() + ast.PostOrderVisit(optimized.impl.Expr(), postOpt) + if !reflect.DeepEqual(postOpt.IDs(), tc.normalizedIDs) { + t.Errorf("Optimize() got ids %v, expected %v", postOpt.IDs(), tc.normalizedIDs) + } + for id, call := range optimized.impl.SourceInfo().MacroCalls() { + macroText, found := tc.normalizedMacros[id] + if !found { + t.Fatalf("Optimize() did not find macro %d", id) + } + pbCall, err := ast.ExprToProto(call) + if err != nil { + t.Fatalf("ast.ExprToProto() failed: %v", err) + } + pbMacro := &exprpb.Expr{} + err = prototext.Unmarshal([]byte(macroText), pbMacro) + if err != nil { + t.Fatalf("prototext.Unmarshal() failed: %v", err) + } + if !proto.Equal(pbCall, pbMacro) { + t.Errorf("Optimize() for 
macro %d got %s, expected %s", id, prototext.Format(pbCall), macroText) + } + } + }) + } +} + +func newIDCollector() *idCollector { + return &idCollector{ + ids: int64Slice{}, + } +} + +type idCollector struct { + ids int64Slice +} + +func (c *idCollector) VisitExpr(e ast.Expr) { + if e.ID() == 0 { + return + } + c.ids = append(c.ids, e.ID()) +} + +// VisitEntryExpr records the id of the incoming entry, skipping entries whose id is zero. +func (c *idCollector) VisitEntryExpr(e ast.EntryExpr) { + if e.ID() == 0 { + return + } + c.ids = append(c.ids, e.ID()) +} + +func (c *idCollector) IDs() []int64 { + sort.Sort(c.ids) + return c.ids +} + +// int64Slice is an implementation of the sort.Interface +type int64Slice []int64 + +// Len returns the number of elements in the slice. +func (x int64Slice) Len() int { return len(x) } + +// Less indicates whether the value at index i is less than the value at index j. +func (x int64Slice) Less(i, j int) bool { return x[i] < x[j] } + +// Swap swaps the values at indices i and j in place. +func (x int64Slice) Swap(i, j int) { x[i], x[j] = x[j], x[i] } + +// Sort is a convenience method: x.Sort() calls Sort(x). +func (x int64Slice) Sort() { sort.Sort(x) } diff --git a/cel/optimizer.go b/cel/optimizer.go new file mode 100644 index 000000000..a2023ed7f --- /dev/null +++ b/cel/optimizer.go @@ -0,0 +1,317 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package cel + +import ( + "github.com/google/cel-go/common" + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types/ref" +) + +// StaticOptimizer contains a sequence of ASTOptimizer instances which will be applied in order. +// +// The static optimizer normalizes expression ids and re-runs type-checking between optimization +// passes to ensure that the final optimized output is a valid expression with metadata consistent +// with what would have been generated from a parsed and checked expression. +// +// Note: source position information is best-effort and likely wrong, but optimized expressions +// should be suitable for calls to parser.Unparse. +type StaticOptimizer struct { + optimizers []ASTOptimizer +} + +// NewStaticOptimizer creates a StaticOptimizer with a sequence of ASTOptimizer instances to be applied +// to a checked expression. +func NewStaticOptimizer(optimizers ...ASTOptimizer) *StaticOptimizer { + return &StaticOptimizer{ + optimizers: optimizers, + } +} + +// Optimize applies a sequence of optimizations to an Ast within a given environment. +// +// If issues are encountered, the Issues.Err() return value will be non-nil. +func (opt *StaticOptimizer) Optimize(env *Env, a *Ast) (*Ast, *Issues) { + // Make a copy of the AST to be optimized. + optimized := ast.Copy(a.impl) + + // Create the optimizer context, could be pooled in the future. + issues := NewIssues(common.NewErrors(a.Source())) + ids := newMonotonicIDGen(ast.MaxID(a.impl)) + fac := &optimizerExprFactory{ + nextID: ids.nextID, + renumberID: ids.renumberID, + fac: ast.NewExprFactory(), + sourceInfo: optimized.SourceInfo(), + } + ctx := &OptimizerContext{ + optimizerExprFactory: fac, + Env: env, + Issues: issues, + } + + // Apply the optimizations sequentially. + for _, o := range opt.optimizers { + optimized = o.Optimize(ctx, optimized) + if issues.Err() != nil { + return nil, issues + } + // Normalize expression id metadata including coordination with macro call metadata. 
+ normalizeIDs(env, optimized) + + // Recheck the updated expression for any possible type-agreement or validation errors. + parsed := &Ast{ + source: a.Source(), + impl: ast.NewAST(optimized.Expr(), optimized.SourceInfo())} + checked, iss := ctx.Check(parsed) + if iss.Err() != nil { + return nil, iss + } + optimized = checked.impl + } + + // Return the optimized result. + return &Ast{ + source: a.Source(), + impl: optimized, + }, nil +} + +// normalizeIDs ensures that the metadata present with an AST is reset in a manner such +// that the ids within the expression correspond to the ids within macros. +func normalizeIDs(e *Env, optimized *ast.AST) { + ids := newStableIDGen() + optimized.Expr().RenumberIDs(ids.renumberID) + allExprMap := make(map[int64]ast.Expr) + ast.PostOrderVisit(optimized.Expr(), ast.NewExprVisitor(func(e ast.Expr) { + allExprMap[e.ID()] = e + })) + info := optimized.SourceInfo() + + // First, update the macro call ids themselves. + for id, call := range info.MacroCalls() { + info.ClearMacroCall(id) + callID := ids.renumberID(id) + if e, found := allExprMap[callID]; found && e.Kind() == ast.LiteralKind { + continue + } + info.SetMacroCall(callID, call) + } + + // Second, update the macro call id references to ensure that macro pointers are + // updated consistently across macros. + for id, call := range info.MacroCalls() { + call.RenumberIDs(ids.renumberID) + resetMacroCall(optimized, call, allExprMap) + info.SetMacroCall(id, call) + } +} + +func resetMacroCall(optimized *ast.AST, call ast.Expr, allExprMap map[int64]ast.Expr) { + modified := []ast.Expr{} + ast.PostOrderVisit(call, ast.NewExprVisitor(func(e ast.Expr) { + if _, found := allExprMap[e.ID()]; found { + modified = append(modified, e) + } + })) + for _, m := range modified { + updated := allExprMap[m.ID()] + m.SetKindCase(updated) + } +} + +// newMonotonicIDGen increments numbers from an initial seed value. 
+func newMonotonicIDGen(seed int64) *monotonicIDGenerator { + return &monotonicIDGenerator{seed: seed} +} + +type monotonicIDGenerator struct { + seed int64 +} + +func (gen *monotonicIDGenerator) nextID() int64 { + gen.seed++ + return gen.seed +} + +func (gen *monotonicIDGenerator) renumberID(int64) int64 { + return gen.nextID() +} + +// newStableIDGen ensures that new ids are only created the first time they are encountered. +func newStableIDGen() *stableIDGenerator { + return &stableIDGenerator{ + idMap: make(map[int64]int64), + } +} + +type stableIDGenerator struct { + idMap map[int64]int64 + nextID int64 +} + +func (gen *stableIDGenerator) renumberID(id int64) int64 { + if id == 0 { + return 0 + } + if newID, found := gen.idMap[id]; found { + return newID + } + gen.nextID++ + gen.idMap[id] = gen.nextID + return gen.nextID +} + +// OptimizerContext embeds Env and Issues instances to make it easy to type-check and evaluate +// subexpressions and report any errors encountered along the way. The context also embeds the +// optimizerExprFactory which can be used to generate new sub-expressions with expression ids +// consistent with the expectations of a parsed expression. +type OptimizerContext struct { + *Env + *optimizerExprFactory + *Issues +} + +// ASTOptimizer applies an optimization over an AST and returns the optimized result. +type ASTOptimizer interface { + // Optimize optimizes a type-checked AST within an Environment and accumulates any issues. + Optimize(*OptimizerContext, *ast.AST) *ast.AST +} + +type optimizerExprFactory struct { + nextID func() int64 + renumberID ast.IDGenerator + fac ast.ExprFactory + sourceInfo *ast.SourceInfo +} + +// NewCall creates a global function call invocation expression. 
+// +// Example: +// +// countByField(list, fieldName) +// - function: countByField +// - args: [list, fieldName] +func (opt *optimizerExprFactory) NewCall(function string, args ...ast.Expr) ast.Expr { + return opt.fac.NewCall(opt.nextID(), function, args...) +} + +// NewMemberCall creates a member function call invocation expression where 'target' is the receiver of the call. +// +// Example: +// +// list.countByField(fieldName) +// - function: countByField +// - target: list +// - args: [fieldName] +func (opt *optimizerExprFactory) NewMemberCall(function string, target ast.Expr, args ...ast.Expr) ast.Expr { + return opt.fac.NewMemberCall(opt.nextID(), function, target, args...) +} + +// NewIdent creates a new identifier expression. +// +// Examples: +// +// - simple_var_name +// - qualified.subpackage.var_name +func (opt *optimizerExprFactory) NewIdent(name string) ast.Expr { + return opt.fac.NewIdent(opt.nextID(), name) +} + +// NewLiteral creates a new literal expression value. +// +// The range of valid values for a literal generated during optimization is different than for expressions +// generated via parsing / type-checking, as the ref.Val may be _any_ CEL value so long as the value can +// be converted back to a literal-like form. +func (opt *optimizerExprFactory) NewLiteral(value ref.Val) ast.Expr { + return opt.fac.NewLiteral(opt.nextID(), value) +} + +// NewList creates a list expression with a set of optional indices. +// +// Examples: +// +// [a, b] +// - elems: [a, b] +// - optIndices: [] +// +// [a, ?b, ?c] +// - elems: [a, b, c] +// - optIndices: [1, 2] +func (opt *optimizerExprFactory) NewList(elems []ast.Expr, optIndices []int32) ast.Expr { + return opt.fac.NewList(opt.nextID(), elems, optIndices) +} + +// NewMap creates a map from a set of entry expressions which contain a key and value expression. 
+func (opt *optimizerExprFactory) NewMap(entries []ast.EntryExpr) ast.Expr { + return opt.fac.NewMap(opt.nextID(), entries) +} + +// NewMapEntry creates a map entry with a key and value expression and a flag to indicate whether the +// entry is optional. +// +// Examples: +// +// {a: b} +// - key: a +// - value: b +// - optional: false +// +// {?a: ?b} +// - key: a +// - value: b +// - optional: true +func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional) +} + +// NewSelect creates a select expression where a field value is selected from an operand. +// +// Example: +// +// msg.field_name +// - operand: msg +// - field: field_name +func (opt *optimizerExprFactory) NewSelect(operand ast.Expr, field string) ast.Expr { + return opt.fac.NewSelect(opt.nextID(), operand, field) +} + +// NewStruct creates a new typed struct value with a set of field initializations. +// +// Example: +// +// pkg.TypeName{field: value} +// - typeName: pkg.TypeName +// - fields: [{field: value}] +func (opt *optimizerExprFactory) NewStruct(typeName string, fields []ast.EntryExpr) ast.Expr { + return opt.fac.NewStruct(opt.nextID(), typeName, fields) +} + +// NewStructField creates a struct field initialization. +// +// Examples: +// +// {count: 3u} +// - field: count +// - value: 3u +// - optional: false +// +// {?count: x} +// - field: count +// - value: x +// - optional: true +func (opt *optimizerExprFactory) NewStructField(field string, value ast.Expr, isOptional bool) ast.EntryExpr { + return opt.fac.NewStructField(opt.nextID(), field, value, isOptional) +} diff --git a/common/ast/ast.go b/common/ast/ast.go index 7610b4676..c3620eb95 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -249,10 +249,16 @@ func (s *SourceInfo) GetMacroCall(id int64) (Expr, bool) { // SetMacroCall records a macro call at a specific location. 
func (s *SourceInfo) SetMacroCall(id int64, e Expr) { - if s == nil { - return + if s != nil { + s.macroCalls[id] = e + } +} + +// ClearMacroCall removes the macro call at the given expression id. +func (s *SourceInfo) ClearMacroCall(id int64) { + if s != nil { + delete(s.macroCalls, id) } - s.macroCalls[id] = e } // OffsetRanges returns a map of expression id to OffsetRange values where the range indicates either: @@ -407,6 +413,7 @@ func (r *ReferenceInfo) Equals(other *ReferenceInfo) bool { type maxIDVisitor struct { maxID int64 + *baseVisitor } // VisitExpr updates the max identifier if the incoming expression id is greater than previously observed. diff --git a/common/ast/expr.go b/common/ast/expr.go index 5811e3957..c9d88bbaa 100644 --- a/common/ast/expr.go +++ b/common/ast/expr.go @@ -184,6 +184,9 @@ type ListExpr interface { // OptionalIndicies returns the list of optional indices in the list literal. OptionalIndices() []int32 + // IsOptional indicates whether the given element index is optional. + IsOptional(int32) bool + // Size returns the number of elements in the list. 
Size() int @@ -404,9 +407,14 @@ func (e *expr) SetKindCase(other Expr) { e.exprKindCase = baseIdentExpr(other.AsIdent()) case ListKind: l := other.AsList() + optIndexMap := make(map[int32]struct{}, len(l.OptionalIndices())) + for _, idx := range l.OptionalIndices() { + optIndexMap[idx] = struct{}{} + } e.exprKindCase = &baseListExpr{ - elements: l.Elements(), - optIndices: l.OptionalIndices(), + elements: l.Elements(), + optIndices: l.OptionalIndices(), + optIndexMap: optIndexMap, } case LiteralKind: e.exprKindCase = &baseLiteral{Val: other.AsLiteral()} @@ -591,8 +599,9 @@ func (*baseLiteral) isExpr() {} var _ ListExpr = &baseListExpr{} type baseListExpr struct { - elements []Expr - optIndices []int32 + elements []Expr + optIndices []int32 + optIndexMap map[int32]struct{} } func (*baseListExpr) Kind() ExprKind { @@ -606,6 +615,11 @@ func (e *baseListExpr) Elements() []Expr { return e.elements } +func (e *baseListExpr) IsOptional(index int32) bool { + _, found := e.optIndexMap[index] + return found +} + func (e *baseListExpr) OptionalIndices() []int32 { if e == nil { return []int32{} diff --git a/common/ast/factory.go b/common/ast/factory.go index 0111c289a..b7f36e72a 100644 --- a/common/ast/factory.go +++ b/common/ast/factory.go @@ -137,7 +137,16 @@ func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr { } func (fac *baseExprFactory) NewList(id int64, elems []Expr, optIndices []int32) Expr { - return fac.newExpr(id, &baseListExpr{elements: elems, optIndices: optIndices}) + optIndexMap := make(map[int32]struct{}, len(optIndices)) + for _, idx := range optIndices { + optIndexMap[idx] = struct{}{} + } + return fac.newExpr(id, + &baseListExpr{ + elements: elems, + optIndices: optIndices, + optIndexMap: optIndexMap, + }) } func (fac *baseExprFactory) NewMap(id int64, entries []EntryExpr) Expr { diff --git a/common/ast/navigable.go b/common/ast/navigable.go index 2836b5655..f5ddf6aac 100644 --- a/common/ast/navigable.go +++ b/common/ast/navigable.go @@ 
-423,6 +423,10 @@ func (l navigableListImpl) Elements() []Expr { return elems } +func (l navigableListImpl) IsOptional(index int32) bool { + return l.Expr.AsList().IsOptional(index) +} + func (l navigableListImpl) OptionalIndices() []int32 { return l.Expr.AsList().OptionalIndices() } From bfccebd9086672bb5cbd18e7ef9c9973cca6eb50 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Thu, 31 Aug 2023 16:11:16 -0700 Subject: [PATCH 28/31] Fix functional exemptions for homogeneous literal checks (#832) --- cel/validator.go | 7 +++---- ext/strings_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/cel/validator.go b/cel/validator.go index a8a86a746..b50c67452 100644 --- a/cel/validator.go +++ b/cel/validator.go @@ -304,7 +304,8 @@ func (v homogeneousAggregateLiteralValidator) Validate(_ *Env, c ValidatorConfig } func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool { - if parent, found := e.Parent(); found { + parent, found := e.Parent() + for found { if parent.Kind() == ast.CallKind { fnName := parent.AsCall().FunctionName() for _, exempt := range exemptFunctions { @@ -313,9 +314,7 @@ func inExemptFunction(e ast.NavigableExpr, exemptFunctions []string) bool { } } } - if parent.Kind() == ast.ListKind || parent.Kind() == ast.MapKind { - return inExemptFunction(parent, exemptFunctions) - } + parent, found = parent.Parent() } return false } diff --git a/ext/strings_test.go b/ext/strings_test.go index 138dd554c..fc0ade115 100644 --- a/ext/strings_test.go +++ b/ext/strings_test.go @@ -1318,6 +1318,50 @@ func TestStringFormat(t *testing.T) { } } +func TestStringFormatHeterogeneousLiterals(t *testing.T) { + tests := []struct { + expr string + out string + }{ + { + expr: `"list: %s".format([[[1, 2, [3.0, 4]]]])`, + out: `list: [[1, 2, [3.000000, 4]]]`, + }, + { + expr: `"list size: %d".format([[[1, 2, [3.0, 4]]].size()])`, + out: `list size: 1`, + }, + { + expr: `"list element: 
%s".format([[[1, 2, [3.0, 4]]][0]])`, + out: `list element: [1, 2, [3.000000, 4]]`, + }, + } + env, err := cel.NewEnv(Strings(), cel.ASTValidators(cel.ValidateHomogeneousAggregateLiterals())) + if err != nil { + t.Fatalf("cel.NewEnv() failed: %v", err) + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile(%q) failed: %v", tc.expr, iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(cel.NoVars()) + if err != nil { + t.Fatalf("Eval() failed: %v", err) + } + if out.Value() != tc.out { + t.Errorf("Eval() got %v, wanted %v", out, tc.out) + } + }) + } +} + func TestBadLocale(t *testing.T) { _, err := cel.NewEnv(Strings(StringsLocale("bad-locale"))) if err != nil { From 4eebcf38782990c8636f14111f75c6b241c505b0 Mon Sep 17 00:00:00 2001 From: swh Date: Fri, 1 Sep 2023 10:47:30 -0700 Subject: [PATCH 29/31] Fix logical operator folding that only involve literals (#833) --- cel/folding.go | 5 +++++ cel/folding_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/cel/folding.go b/cel/folding.go index 5903d732c..025b2b047 100644 --- a/cel/folding.go +++ b/cel/folding.go @@ -206,6 +206,11 @@ func maybeShortcircuitLogic(ctx *OptimizerContext, function string, args []ast.E return true } } + if len(newArgs) == 0 { + newArgs = append(newArgs, args[0]) + expr.SetKindCase(newArgs[0]) + return true + } if len(newArgs) == 1 { expr.SetKindCase(newArgs[0]) return true diff --git a/cel/folding_test.go b/cel/folding_test.go index 18fe4d58f..18451ef6d 100644 --- a/cel/folding_test.go +++ b/cel/folding_test.go @@ -197,6 +197,30 @@ func TestConstantFoldingOptimizer(t *testing.T) { expr: `false || x || false || x`, folded: `x || x`, }, + { + expr: `true && true`, + folded: `true`, + }, + { + expr: `true && false`, + folded: `false`, + }, + { + expr: `true || false`, + 
folded: `true`, + }, + { + expr: `false || false`, + folded: `false`, + }, + { + expr: `true && false || true`, + folded: `true`, + }, + { + expr: `false && true || false`, + folded: `false`, + }, { expr: `null`, folded: `null`, From 8943046ee066d9fefea77ce856f175a987271925 Mon Sep 17 00:00:00 2001 From: swh Date: Fri, 1 Sep 2023 10:47:46 -0700 Subject: [PATCH 30/31] Upgrade go-genproto to latest (#831) * Upgrade go-genproto to latest See https://github.com/googleapis/go-genproto/issues/1015. * Update WORKSPACE --- WORKSPACE | 26 +- go.mod | 8 +- go.sum | 16 +- .../protobuf/encoding/protojson/encode.go | 14 +- .../protobuf/encoding/prototext/encode.go | 14 +- .../protobuf/internal/encoding/json/encode.go | 10 +- .../protobuf/internal/encoding/text/encode.go | 10 +- .../protobuf/internal/genid/descriptor_gen.go | 48 + .../protobuf/internal/genid/type_gen.go | 6 + .../protobuf/internal/order/order.go | 2 +- .../protobuf/internal/version/version.go | 2 +- .../google.golang.org/protobuf/proto/size.go | 10 +- .../reflect/protoreflect/source_gen.go | 27 + .../types/descriptorpb/descriptor.pb.go | 1011 +++++++++++------ .../protobuf/types/dynamicpb/types.go | 177 +++ .../protobuf/types/known/anypb/any.pb.go | 70 +- .../types/known/structpb/struct.pb.go | 2 +- .../types/known/timestamppb/timestamp.pb.go | 2 +- vendor/modules.txt | 8 +- 19 files changed, 998 insertions(+), 465 deletions(-) create mode 100644 vendor/google.golang.org/protobuf/types/dynamicpb/types.go diff --git a/WORKSPACE b/WORKSPACE index 8a993c368..ad3a996b6 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -30,13 +30,13 @@ http_archive( ], ) -# googleapis as of 05/26/2023 +# googleapis as of 08/31/2023 http_archive( name = "com_google_googleapis", - strip_prefix = "googleapis-07c27163ac591955d736f3057b1619ece66f5b99", - sha256 = "bd8e735d881fb829751ecb1a77038dda4a8d274c45490cb9fcf004583ee10571", + sha256 = "5c56500adf7b1b7a3a2ee5ca5b77500617ad80afb808e3d3979f582e64c0523d", + strip_prefix = 
"googleapis-25f99371444ea7fd0dc1523ca6925e91cc48a664", urls = [ - "https://github.com/googleapis/googleapis/archive/07c27163ac591955d736f3057b1619ece66f5b99.tar.gz", + "https://github.com/googleapis/googleapis/archive/25f99371444ea7fd0dc1523ca6925e91cc48a664.tar.gz", ], ) @@ -68,22 +68,22 @@ go_repository( tag = "v1.28.1", ) -# Generated Google APIs protos for Golang 05/25/2023 +# Generated Google APIs protos for Golang 08/03/2023 go_repository( name = "org_golang_google_genproto_googleapis_api", build_file_proto_mode = "disable_global", importpath = "google.golang.org/genproto/googleapis/api", - sum = "h1:m8v1xLLLzMe1m5P+gCTF8nJB9epwZQUBERm20Oy1poQ=", - version = "v0.0.0-20230525234035-dd9d682886f9", + sum = "h1:nIgk/EEq3/YlnmVVXVnm14rC2oxgs1o0ong4sD/rd44=", + version = "v0.0.0-20230803162519-f966b187b2e5", ) -# Generated Google APIs protos for Golang 05/25/2023 +# Generated Google APIs protos for Golang 08/03/2023 go_repository( name = "org_golang_google_genproto_googleapis_rpc", build_file_proto_mode = "disable_global", importpath = "google.golang.org/genproto/googleapis/rpc", - sum = "h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM=", - version = "v0.0.0-20230525234030-28d5490b6b19", + sum = "h1:eSaPbMR4T7WfH9FvABk36NBMacoTUKdWCvV0dx+KfOg=", + version = "v0.0.0-20230803162519-f966b187b2e5", ) # gRPC deps for v1.49.0 (including x/text and x/net) @@ -119,8 +119,8 @@ go_repository( # CEL Spec deps v0.9.0 go_repository( name = "com_google_cel_spec", - importpath = "github.com/google/cel-spec", commit = "51af45e2b75a8aa2b3108b00f0e91cd172cfbea1", + importpath = "github.com/google/cel-spec", ) # strcase deps @@ -134,8 +134,8 @@ go_repository( # Readline for repl go_repository( name = "com_github_chzyer_readline", + commit = "62c6fe6193755f722b8b8788aa7357be55a50ff1", # v1.4 importpath = "github.com/chzyer/readline", - commit = "62c6fe6193755f722b8b8788aa7357be55a50ff1" # v1.4 ) # golang.org/x/exp deps @@ -143,7 +143,7 @@ go_repository( name = "org_golang_x_exp", 
importpath = "golang.org/x/exp", sum = "h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA=", - version = "v0.0.0-20220722155223-a9213eeb770e" + version = "v0.0.0-20220722155223-a9213eeb770e", ) # Run the dependencies at the end. These will silently try to import some diff --git a/go.mod b/go.mod index c62be054a..8145ea8e0 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,12 @@ go 1.18 require ( github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230305170008-8188dc5388df github.com/stoewer/go-strcase v1.2.0 - golang.org/x/text v0.8.0 - google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9 - google.golang.org/protobuf v1.30.0 + golang.org/x/text v0.9.0 + google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 + google.golang.org/protobuf v1.31.0 ) require ( golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 // indirect ) diff --git a/go.sum b/go.sum index 8858cc570..f2ccff699 100644 --- a/go.sum +++ b/go.sum @@ -14,16 +14,16 @@ github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA= golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 
-google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9 h1:m8v1xLLLzMe1m5P+gCTF8nJB9epwZQUBERm20Oy1poQ= -google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9/go.mod h1:vHYtlOoi6TsQ3Uk2yxR7NI5z8uoV+3pZtR4jmHIkRig= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 h1:0nDDozoAU19Qb2HwhXadU8OcsiO/09cnTqhUtq2MEOM= -google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19/go.mod h1:66JfowdXAEgad5O9NnYcsNPLCPZJD++2L9X0PCMODrA= +google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 h1:nIgk/EEq3/YlnmVVXVnm14rC2oxgs1o0ong4sD/rd44= +google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5/go.mod h1:5DZzOUPCLYL3mNkQ0ms0F3EuUNZ7py1Bqeq6sxzI7/Q= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 h1:eSaPbMR4T7WfH9FvABk36NBMacoTUKdWCvV0dx+KfOg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5/go.mod h1:zBEcrKX2ZOcEkHWxBPAIvYUWOKKMIhYcmNiUIu2ji3I= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= -google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= +google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/vendor/google.golang.org/protobuf/encoding/protojson/encode.go b/vendor/google.golang.org/protobuf/encoding/protojson/encode.go index d09d22e13..66b95870e 100644 --- a/vendor/google.golang.org/protobuf/encoding/protojson/encode.go +++ 
b/vendor/google.golang.org/protobuf/encoding/protojson/encode.go @@ -106,13 +106,19 @@ func (o MarshalOptions) Format(m proto.Message) string { // MarshalOptions. Do not depend on the output being stable. It may change over // time across different versions of the program. func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) { - return o.marshal(m) + return o.marshal(nil, m) +} + +// MarshalAppend appends the JSON format encoding of m to b, +// returning the result. +func (o MarshalOptions) MarshalAppend(b []byte, m proto.Message) ([]byte, error) { + return o.marshal(b, m) } // marshal is a centralized function that all marshal operations go through. // For profiling purposes, avoid changing the name of this function or // introducing other code paths for marshal that do not go through this. -func (o MarshalOptions) marshal(m proto.Message) ([]byte, error) { +func (o MarshalOptions) marshal(b []byte, m proto.Message) ([]byte, error) { if o.Multiline && o.Indent == "" { o.Indent = defaultIndent } @@ -120,7 +126,7 @@ func (o MarshalOptions) marshal(m proto.Message) ([]byte, error) { o.Resolver = protoregistry.GlobalTypes } - internalEnc, err := json.NewEncoder(o.Indent) + internalEnc, err := json.NewEncoder(b, o.Indent) if err != nil { return nil, err } @@ -128,7 +134,7 @@ func (o MarshalOptions) marshal(m proto.Message) ([]byte, error) { // Treat nil message interface as an empty message, // in which case the output in an empty JSON object. 
if m == nil { - return []byte("{}"), nil + return append(b, '{', '}'), nil } enc := encoder{internalEnc, o} diff --git a/vendor/google.golang.org/protobuf/encoding/prototext/encode.go b/vendor/google.golang.org/protobuf/encoding/prototext/encode.go index ebf6c6528..722a7b41d 100644 --- a/vendor/google.golang.org/protobuf/encoding/prototext/encode.go +++ b/vendor/google.golang.org/protobuf/encoding/prototext/encode.go @@ -101,13 +101,19 @@ func (o MarshalOptions) Format(m proto.Message) string { // MarshalOptions object. Do not depend on the output being stable. It may // change over time across different versions of the program. func (o MarshalOptions) Marshal(m proto.Message) ([]byte, error) { - return o.marshal(m) + return o.marshal(nil, m) +} + +// MarshalAppend appends the textproto format encoding of m to b, +// returning the result. +func (o MarshalOptions) MarshalAppend(b []byte, m proto.Message) ([]byte, error) { + return o.marshal(b, m) } // marshal is a centralized function that all marshal operations go through. // For profiling purposes, avoid changing the name of this function or // introducing other code paths for marshal that do not go through this. -func (o MarshalOptions) marshal(m proto.Message) ([]byte, error) { +func (o MarshalOptions) marshal(b []byte, m proto.Message) ([]byte, error) { var delims = [2]byte{'{', '}'} if o.Multiline && o.Indent == "" { @@ -117,7 +123,7 @@ func (o MarshalOptions) marshal(m proto.Message) ([]byte, error) { o.Resolver = protoregistry.GlobalTypes } - internalEnc, err := text.NewEncoder(o.Indent, delims, o.EmitASCII) + internalEnc, err := text.NewEncoder(b, o.Indent, delims, o.EmitASCII) if err != nil { return nil, err } @@ -125,7 +131,7 @@ func (o MarshalOptions) marshal(m proto.Message) ([]byte, error) { // Treat nil message interface as an empty message, // in which case there is nothing to output. 
if m == nil { - return []byte{}, nil + return b, nil } enc := encoder{internalEnc, o} diff --git a/vendor/google.golang.org/protobuf/internal/encoding/json/encode.go b/vendor/google.golang.org/protobuf/internal/encoding/json/encode.go index fbdf34873..934f2dcb3 100644 --- a/vendor/google.golang.org/protobuf/internal/encoding/json/encode.go +++ b/vendor/google.golang.org/protobuf/internal/encoding/json/encode.go @@ -41,8 +41,10 @@ type Encoder struct { // // If indent is a non-empty string, it causes every entry for an Array or Object // to be preceded by the indent and trailed by a newline. -func NewEncoder(indent string) (*Encoder, error) { - e := &Encoder{} +func NewEncoder(buf []byte, indent string) (*Encoder, error) { + e := &Encoder{ + out: buf, + } if len(indent) > 0 { if strings.Trim(indent, " \t") != "" { return nil, errors.New("indent may only be composed of space or tab characters") @@ -176,13 +178,13 @@ func appendFloat(out []byte, n float64, bitSize int) []byte { // WriteInt writes out the given signed integer in JSON number value. func (e *Encoder) WriteInt(n int64) { e.prepareNext(scalar) - e.out = append(e.out, strconv.FormatInt(n, 10)...) + e.out = strconv.AppendInt(e.out, n, 10) } // WriteUint writes out the given unsigned integer in JSON number value. func (e *Encoder) WriteUint(n uint64) { e.prepareNext(scalar) - e.out = append(e.out, strconv.FormatUint(n, 10)...) + e.out = strconv.AppendUint(e.out, n, 10) } // StartObject writes out the '{' symbol. diff --git a/vendor/google.golang.org/protobuf/internal/encoding/text/encode.go b/vendor/google.golang.org/protobuf/internal/encoding/text/encode.go index da289ccce..cf7aed77b 100644 --- a/vendor/google.golang.org/protobuf/internal/encoding/text/encode.go +++ b/vendor/google.golang.org/protobuf/internal/encoding/text/encode.go @@ -53,8 +53,10 @@ type encoderState struct { // If outputASCII is true, strings will be serialized in such a way that // multi-byte UTF-8 sequences are escaped. 
This property ensures that the // overall output is ASCII (as opposed to UTF-8). -func NewEncoder(indent string, delims [2]byte, outputASCII bool) (*Encoder, error) { - e := &Encoder{} +func NewEncoder(buf []byte, indent string, delims [2]byte, outputASCII bool) (*Encoder, error) { + e := &Encoder{ + encoderState: encoderState{out: buf}, + } if len(indent) > 0 { if strings.Trim(indent, " \t") != "" { return nil, errors.New("indent may only be composed of space and tab characters") @@ -195,13 +197,13 @@ func appendFloat(out []byte, n float64, bitSize int) []byte { // WriteInt writes out the given signed integer value. func (e *Encoder) WriteInt(n int64) { e.prepareNext(scalar) - e.out = append(e.out, strconv.FormatInt(n, 10)...) + e.out = strconv.AppendInt(e.out, n, 10) } // WriteUint writes out the given unsigned integer value. func (e *Encoder) WriteUint(n uint64) { e.prepareNext(scalar) - e.out = append(e.out, strconv.FormatUint(n, 10)...) + e.out = strconv.AppendUint(e.out, n, 10) } // WriteLiteral writes out the given string as a literal value without quotes. diff --git a/vendor/google.golang.org/protobuf/internal/genid/descriptor_gen.go b/vendor/google.golang.org/protobuf/internal/genid/descriptor_gen.go index 5c0e8f73f..136f1b215 100644 --- a/vendor/google.golang.org/protobuf/internal/genid/descriptor_gen.go +++ b/vendor/google.golang.org/protobuf/internal/genid/descriptor_gen.go @@ -183,13 +183,58 @@ const ( // Field names for google.protobuf.ExtensionRangeOptions. 
const ( ExtensionRangeOptions_UninterpretedOption_field_name protoreflect.Name = "uninterpreted_option" + ExtensionRangeOptions_Declaration_field_name protoreflect.Name = "declaration" + ExtensionRangeOptions_Verification_field_name protoreflect.Name = "verification" ExtensionRangeOptions_UninterpretedOption_field_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.uninterpreted_option" + ExtensionRangeOptions_Declaration_field_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.declaration" + ExtensionRangeOptions_Verification_field_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.verification" ) // Field numbers for google.protobuf.ExtensionRangeOptions. const ( ExtensionRangeOptions_UninterpretedOption_field_number protoreflect.FieldNumber = 999 + ExtensionRangeOptions_Declaration_field_number protoreflect.FieldNumber = 2 + ExtensionRangeOptions_Verification_field_number protoreflect.FieldNumber = 3 +) + +// Full and short names for google.protobuf.ExtensionRangeOptions.VerificationState. +const ( + ExtensionRangeOptions_VerificationState_enum_fullname = "google.protobuf.ExtensionRangeOptions.VerificationState" + ExtensionRangeOptions_VerificationState_enum_name = "VerificationState" +) + +// Names for google.protobuf.ExtensionRangeOptions.Declaration. +const ( + ExtensionRangeOptions_Declaration_message_name protoreflect.Name = "Declaration" + ExtensionRangeOptions_Declaration_message_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.Declaration" +) + +// Field names for google.protobuf.ExtensionRangeOptions.Declaration. 
+const ( + ExtensionRangeOptions_Declaration_Number_field_name protoreflect.Name = "number" + ExtensionRangeOptions_Declaration_FullName_field_name protoreflect.Name = "full_name" + ExtensionRangeOptions_Declaration_Type_field_name protoreflect.Name = "type" + ExtensionRangeOptions_Declaration_IsRepeated_field_name protoreflect.Name = "is_repeated" + ExtensionRangeOptions_Declaration_Reserved_field_name protoreflect.Name = "reserved" + ExtensionRangeOptions_Declaration_Repeated_field_name protoreflect.Name = "repeated" + + ExtensionRangeOptions_Declaration_Number_field_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.Declaration.number" + ExtensionRangeOptions_Declaration_FullName_field_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.Declaration.full_name" + ExtensionRangeOptions_Declaration_Type_field_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.Declaration.type" + ExtensionRangeOptions_Declaration_IsRepeated_field_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.Declaration.is_repeated" + ExtensionRangeOptions_Declaration_Reserved_field_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.Declaration.reserved" + ExtensionRangeOptions_Declaration_Repeated_field_fullname protoreflect.FullName = "google.protobuf.ExtensionRangeOptions.Declaration.repeated" +) + +// Field numbers for google.protobuf.ExtensionRangeOptions.Declaration. 
+const ( + ExtensionRangeOptions_Declaration_Number_field_number protoreflect.FieldNumber = 1 + ExtensionRangeOptions_Declaration_FullName_field_number protoreflect.FieldNumber = 2 + ExtensionRangeOptions_Declaration_Type_field_number protoreflect.FieldNumber = 3 + ExtensionRangeOptions_Declaration_IsRepeated_field_number protoreflect.FieldNumber = 4 + ExtensionRangeOptions_Declaration_Reserved_field_number protoreflect.FieldNumber = 5 + ExtensionRangeOptions_Declaration_Repeated_field_number protoreflect.FieldNumber = 6 ) // Names for google.protobuf.FieldDescriptorProto. @@ -540,6 +585,7 @@ const ( FieldOptions_DebugRedact_field_name protoreflect.Name = "debug_redact" FieldOptions_Retention_field_name protoreflect.Name = "retention" FieldOptions_Target_field_name protoreflect.Name = "target" + FieldOptions_Targets_field_name protoreflect.Name = "targets" FieldOptions_UninterpretedOption_field_name protoreflect.Name = "uninterpreted_option" FieldOptions_Ctype_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.ctype" @@ -552,6 +598,7 @@ const ( FieldOptions_DebugRedact_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.debug_redact" FieldOptions_Retention_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.retention" FieldOptions_Target_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.target" + FieldOptions_Targets_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.targets" FieldOptions_UninterpretedOption_field_fullname protoreflect.FullName = "google.protobuf.FieldOptions.uninterpreted_option" ) @@ -567,6 +614,7 @@ const ( FieldOptions_DebugRedact_field_number protoreflect.FieldNumber = 16 FieldOptions_Retention_field_number protoreflect.FieldNumber = 17 FieldOptions_Target_field_number protoreflect.FieldNumber = 18 + FieldOptions_Targets_field_number protoreflect.FieldNumber = 19 FieldOptions_UninterpretedOption_field_number protoreflect.FieldNumber = 999 ) diff --git 
a/vendor/google.golang.org/protobuf/internal/genid/type_gen.go b/vendor/google.golang.org/protobuf/internal/genid/type_gen.go index 3bc710138..e0f75fea0 100644 --- a/vendor/google.golang.org/protobuf/internal/genid/type_gen.go +++ b/vendor/google.golang.org/protobuf/internal/genid/type_gen.go @@ -32,6 +32,7 @@ const ( Type_Options_field_name protoreflect.Name = "options" Type_SourceContext_field_name protoreflect.Name = "source_context" Type_Syntax_field_name protoreflect.Name = "syntax" + Type_Edition_field_name protoreflect.Name = "edition" Type_Name_field_fullname protoreflect.FullName = "google.protobuf.Type.name" Type_Fields_field_fullname protoreflect.FullName = "google.protobuf.Type.fields" @@ -39,6 +40,7 @@ const ( Type_Options_field_fullname protoreflect.FullName = "google.protobuf.Type.options" Type_SourceContext_field_fullname protoreflect.FullName = "google.protobuf.Type.source_context" Type_Syntax_field_fullname protoreflect.FullName = "google.protobuf.Type.syntax" + Type_Edition_field_fullname protoreflect.FullName = "google.protobuf.Type.edition" ) // Field numbers for google.protobuf.Type. @@ -49,6 +51,7 @@ const ( Type_Options_field_number protoreflect.FieldNumber = 4 Type_SourceContext_field_number protoreflect.FieldNumber = 5 Type_Syntax_field_number protoreflect.FieldNumber = 6 + Type_Edition_field_number protoreflect.FieldNumber = 7 ) // Names for google.protobuf.Field. 
@@ -121,12 +124,14 @@ const ( Enum_Options_field_name protoreflect.Name = "options" Enum_SourceContext_field_name protoreflect.Name = "source_context" Enum_Syntax_field_name protoreflect.Name = "syntax" + Enum_Edition_field_name protoreflect.Name = "edition" Enum_Name_field_fullname protoreflect.FullName = "google.protobuf.Enum.name" Enum_Enumvalue_field_fullname protoreflect.FullName = "google.protobuf.Enum.enumvalue" Enum_Options_field_fullname protoreflect.FullName = "google.protobuf.Enum.options" Enum_SourceContext_field_fullname protoreflect.FullName = "google.protobuf.Enum.source_context" Enum_Syntax_field_fullname protoreflect.FullName = "google.protobuf.Enum.syntax" + Enum_Edition_field_fullname protoreflect.FullName = "google.protobuf.Enum.edition" ) // Field numbers for google.protobuf.Enum. @@ -136,6 +141,7 @@ const ( Enum_Options_field_number protoreflect.FieldNumber = 3 Enum_SourceContext_field_number protoreflect.FieldNumber = 4 Enum_Syntax_field_number protoreflect.FieldNumber = 5 + Enum_Edition_field_number protoreflect.FieldNumber = 6 ) // Names for google.protobuf.EnumValue. diff --git a/vendor/google.golang.org/protobuf/internal/order/order.go b/vendor/google.golang.org/protobuf/internal/order/order.go index 33745ed06..dea522e12 100644 --- a/vendor/google.golang.org/protobuf/internal/order/order.go +++ b/vendor/google.golang.org/protobuf/internal/order/order.go @@ -33,7 +33,7 @@ var ( return !inOneof(ox) && inOneof(oy) } // Fields in disjoint oneof sets are sorted by declaration index. - if ox != nil && oy != nil && ox != oy { + if inOneof(ox) && inOneof(oy) && ox != oy { return ox.Index() < oy.Index() } // Fields sorted by field number. 
diff --git a/vendor/google.golang.org/protobuf/internal/version/version.go b/vendor/google.golang.org/protobuf/internal/version/version.go index f7014cd51..0999f29d5 100644 --- a/vendor/google.golang.org/protobuf/internal/version/version.go +++ b/vendor/google.golang.org/protobuf/internal/version/version.go @@ -51,7 +51,7 @@ import ( // 10. Send out the CL for review and submit it. const ( Major = 1 - Minor = 30 + Minor = 31 Patch = 0 PreRelease = "" ) diff --git a/vendor/google.golang.org/protobuf/proto/size.go b/vendor/google.golang.org/protobuf/proto/size.go index 554b9c6c0..f1692b49b 100644 --- a/vendor/google.golang.org/protobuf/proto/size.go +++ b/vendor/google.golang.org/protobuf/proto/size.go @@ -73,23 +73,27 @@ func (o MarshalOptions) sizeField(fd protoreflect.FieldDescriptor, value protore } func (o MarshalOptions) sizeList(num protowire.Number, fd protoreflect.FieldDescriptor, list protoreflect.List) (size int) { + sizeTag := protowire.SizeTag(num) + if fd.IsPacked() && list.Len() > 0 { content := 0 for i, llen := 0, list.Len(); i < llen; i++ { content += o.sizeSingular(num, fd.Kind(), list.Get(i)) } - return protowire.SizeTag(num) + protowire.SizeBytes(content) + return sizeTag + protowire.SizeBytes(content) } for i, llen := 0, list.Len(); i < llen; i++ { - size += protowire.SizeTag(num) + o.sizeSingular(num, fd.Kind(), list.Get(i)) + size += sizeTag + o.sizeSingular(num, fd.Kind(), list.Get(i)) } return size } func (o MarshalOptions) sizeMap(num protowire.Number, fd protoreflect.FieldDescriptor, mapv protoreflect.Map) (size int) { + sizeTag := protowire.SizeTag(num) + mapv.Range(func(key protoreflect.MapKey, value protoreflect.Value) bool { - size += protowire.SizeTag(num) + size += sizeTag size += protowire.SizeBytes(o.sizeField(fd.MapKey(), key.Value()) + o.sizeField(fd.MapValue(), value)) return true }) diff --git a/vendor/google.golang.org/protobuf/reflect/protoreflect/source_gen.go 
b/vendor/google.golang.org/protobuf/reflect/protoreflect/source_gen.go index 54ce326df..717b106f3 100644 --- a/vendor/google.golang.org/protobuf/reflect/protoreflect/source_gen.go +++ b/vendor/google.golang.org/protobuf/reflect/protoreflect/source_gen.go @@ -363,6 +363,8 @@ func (p *SourcePath) appendFieldOptions(b []byte) []byte { b = p.appendSingularField(b, "retention", nil) case 18: b = p.appendSingularField(b, "target", nil) + case 19: + b = p.appendRepeatedField(b, "targets", nil) case 999: b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption) } @@ -418,6 +420,10 @@ func (p *SourcePath) appendExtensionRangeOptions(b []byte) []byte { switch (*p)[0] { case 999: b = p.appendRepeatedField(b, "uninterpreted_option", (*SourcePath).appendUninterpretedOption) + case 2: + b = p.appendRepeatedField(b, "declaration", (*SourcePath).appendExtensionRangeOptions_Declaration) + case 3: + b = p.appendSingularField(b, "verification", nil) } return b } @@ -473,3 +479,24 @@ func (p *SourcePath) appendUninterpretedOption_NamePart(b []byte) []byte { } return b } + +func (p *SourcePath) appendExtensionRangeOptions_Declaration(b []byte) []byte { + if len(*p) == 0 { + return b + } + switch (*p)[0] { + case 1: + b = p.appendSingularField(b, "number", nil) + case 2: + b = p.appendSingularField(b, "full_name", nil) + case 3: + b = p.appendSingularField(b, "type", nil) + case 4: + b = p.appendSingularField(b, "is_repeated", nil) + case 5: + b = p.appendSingularField(b, "reserved", nil) + case 6: + b = p.appendSingularField(b, "repeated", nil) + } + return b +} diff --git a/vendor/google.golang.org/protobuf/types/descriptorpb/descriptor.pb.go b/vendor/google.golang.org/protobuf/types/descriptorpb/descriptor.pb.go index dac5671db..04c00f737 100644 --- a/vendor/google.golang.org/protobuf/types/descriptorpb/descriptor.pb.go +++ b/vendor/google.golang.org/protobuf/types/descriptorpb/descriptor.pb.go @@ -48,6 +48,64 @@ import ( sync "sync" ) +// The 
verification state of the extension range. +type ExtensionRangeOptions_VerificationState int32 + +const ( + // All the extensions of the range must be declared. + ExtensionRangeOptions_DECLARATION ExtensionRangeOptions_VerificationState = 0 + ExtensionRangeOptions_UNVERIFIED ExtensionRangeOptions_VerificationState = 1 +) + +// Enum value maps for ExtensionRangeOptions_VerificationState. +var ( + ExtensionRangeOptions_VerificationState_name = map[int32]string{ + 0: "DECLARATION", + 1: "UNVERIFIED", + } + ExtensionRangeOptions_VerificationState_value = map[string]int32{ + "DECLARATION": 0, + "UNVERIFIED": 1, + } +) + +func (x ExtensionRangeOptions_VerificationState) Enum() *ExtensionRangeOptions_VerificationState { + p := new(ExtensionRangeOptions_VerificationState) + *p = x + return p +} + +func (x ExtensionRangeOptions_VerificationState) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (ExtensionRangeOptions_VerificationState) Descriptor() protoreflect.EnumDescriptor { + return file_google_protobuf_descriptor_proto_enumTypes[0].Descriptor() +} + +func (ExtensionRangeOptions_VerificationState) Type() protoreflect.EnumType { + return &file_google_protobuf_descriptor_proto_enumTypes[0] +} + +func (x ExtensionRangeOptions_VerificationState) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Do not use. +func (x *ExtensionRangeOptions_VerificationState) UnmarshalJSON(b []byte) error { + num, err := protoimpl.X.UnmarshalJSONEnum(x.Descriptor(), b) + if err != nil { + return err + } + *x = ExtensionRangeOptions_VerificationState(num) + return nil +} + +// Deprecated: Use ExtensionRangeOptions_VerificationState.Descriptor instead. 
+func (ExtensionRangeOptions_VerificationState) EnumDescriptor() ([]byte, []int) { + return file_google_protobuf_descriptor_proto_rawDescGZIP(), []int{3, 0} +} + type FieldDescriptorProto_Type int32 const ( @@ -137,11 +195,11 @@ func (x FieldDescriptorProto_Type) String() string { } func (FieldDescriptorProto_Type) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_descriptor_proto_enumTypes[0].Descriptor() + return file_google_protobuf_descriptor_proto_enumTypes[1].Descriptor() } func (FieldDescriptorProto_Type) Type() protoreflect.EnumType { - return &file_google_protobuf_descriptor_proto_enumTypes[0] + return &file_google_protobuf_descriptor_proto_enumTypes[1] } func (x FieldDescriptorProto_Type) Number() protoreflect.EnumNumber { @@ -197,11 +255,11 @@ func (x FieldDescriptorProto_Label) String() string { } func (FieldDescriptorProto_Label) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_descriptor_proto_enumTypes[1].Descriptor() + return file_google_protobuf_descriptor_proto_enumTypes[2].Descriptor() } func (FieldDescriptorProto_Label) Type() protoreflect.EnumType { - return &file_google_protobuf_descriptor_proto_enumTypes[1] + return &file_google_protobuf_descriptor_proto_enumTypes[2] } func (x FieldDescriptorProto_Label) Number() protoreflect.EnumNumber { @@ -258,11 +316,11 @@ func (x FileOptions_OptimizeMode) String() string { } func (FileOptions_OptimizeMode) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_descriptor_proto_enumTypes[2].Descriptor() + return file_google_protobuf_descriptor_proto_enumTypes[3].Descriptor() } func (FileOptions_OptimizeMode) Type() protoreflect.EnumType { - return &file_google_protobuf_descriptor_proto_enumTypes[2] + return &file_google_protobuf_descriptor_proto_enumTypes[3] } func (x FileOptions_OptimizeMode) Number() protoreflect.EnumNumber { @@ -288,7 +346,13 @@ type FieldOptions_CType int32 const ( // Default mode. 
- FieldOptions_STRING FieldOptions_CType = 0 + FieldOptions_STRING FieldOptions_CType = 0 + // The option [ctype=CORD] may be applied to a non-repeated field of type + // "bytes". It indicates that in C++, the data should be stored in a Cord + // instead of a string. For very large strings, this may reduce memory + // fragmentation. It may also allow better performance when parsing from a + // Cord, or when parsing with aliasing enabled, as the parsed Cord may then + // alias the original buffer. FieldOptions_CORD FieldOptions_CType = 1 FieldOptions_STRING_PIECE FieldOptions_CType = 2 ) @@ -318,11 +382,11 @@ func (x FieldOptions_CType) String() string { } func (FieldOptions_CType) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_descriptor_proto_enumTypes[3].Descriptor() + return file_google_protobuf_descriptor_proto_enumTypes[4].Descriptor() } func (FieldOptions_CType) Type() protoreflect.EnumType { - return &file_google_protobuf_descriptor_proto_enumTypes[3] + return &file_google_protobuf_descriptor_proto_enumTypes[4] } func (x FieldOptions_CType) Number() protoreflect.EnumNumber { @@ -380,11 +444,11 @@ func (x FieldOptions_JSType) String() string { } func (FieldOptions_JSType) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_descriptor_proto_enumTypes[4].Descriptor() + return file_google_protobuf_descriptor_proto_enumTypes[5].Descriptor() } func (FieldOptions_JSType) Type() protoreflect.EnumType { - return &file_google_protobuf_descriptor_proto_enumTypes[4] + return &file_google_protobuf_descriptor_proto_enumTypes[5] } func (x FieldOptions_JSType) Number() protoreflect.EnumNumber { @@ -442,11 +506,11 @@ func (x FieldOptions_OptionRetention) String() string { } func (FieldOptions_OptionRetention) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_descriptor_proto_enumTypes[5].Descriptor() + return file_google_protobuf_descriptor_proto_enumTypes[6].Descriptor() } func 
(FieldOptions_OptionRetention) Type() protoreflect.EnumType { - return &file_google_protobuf_descriptor_proto_enumTypes[5] + return &file_google_protobuf_descriptor_proto_enumTypes[6] } func (x FieldOptions_OptionRetention) Number() protoreflect.EnumNumber { @@ -526,11 +590,11 @@ func (x FieldOptions_OptionTargetType) String() string { } func (FieldOptions_OptionTargetType) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_descriptor_proto_enumTypes[6].Descriptor() + return file_google_protobuf_descriptor_proto_enumTypes[7].Descriptor() } func (FieldOptions_OptionTargetType) Type() protoreflect.EnumType { - return &file_google_protobuf_descriptor_proto_enumTypes[6] + return &file_google_protobuf_descriptor_proto_enumTypes[7] } func (x FieldOptions_OptionTargetType) Number() protoreflect.EnumNumber { @@ -588,11 +652,11 @@ func (x MethodOptions_IdempotencyLevel) String() string { } func (MethodOptions_IdempotencyLevel) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_descriptor_proto_enumTypes[7].Descriptor() + return file_google_protobuf_descriptor_proto_enumTypes[8].Descriptor() } func (MethodOptions_IdempotencyLevel) Type() protoreflect.EnumType { - return &file_google_protobuf_descriptor_proto_enumTypes[7] + return &file_google_protobuf_descriptor_proto_enumTypes[8] } func (x MethodOptions_IdempotencyLevel) Number() protoreflect.EnumNumber { @@ -652,11 +716,11 @@ func (x GeneratedCodeInfo_Annotation_Semantic) String() string { } func (GeneratedCodeInfo_Annotation_Semantic) Descriptor() protoreflect.EnumDescriptor { - return file_google_protobuf_descriptor_proto_enumTypes[8].Descriptor() + return file_google_protobuf_descriptor_proto_enumTypes[9].Descriptor() } func (GeneratedCodeInfo_Annotation_Semantic) Type() protoreflect.EnumType { - return &file_google_protobuf_descriptor_proto_enumTypes[8] + return &file_google_protobuf_descriptor_proto_enumTypes[9] } func (x GeneratedCodeInfo_Annotation_Semantic) Number() 
protoreflect.EnumNumber { @@ -1015,7 +1079,21 @@ type ExtensionRangeOptions struct { // The parser stores options it doesn't recognize here. See above. UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` -} + // go/protobuf-stripping-extension-declarations + // Like Metadata, but we use a repeated field to hold all extension + // declarations. This should avoid the size increases of transforming a large + // extension range into small ranges in generated binaries. + Declaration []*ExtensionRangeOptions_Declaration `protobuf:"bytes,2,rep,name=declaration" json:"declaration,omitempty"` + // The verification state of the range. + // TODO(b/278783756): flip the default to DECLARATION once all empty ranges + // are marked as UNVERIFIED. + Verification *ExtensionRangeOptions_VerificationState `protobuf:"varint,3,opt,name=verification,enum=google.protobuf.ExtensionRangeOptions_VerificationState,def=1" json:"verification,omitempty"` +} + +// Default values for ExtensionRangeOptions fields. +const ( + Default_ExtensionRangeOptions_Verification = ExtensionRangeOptions_UNVERIFIED +) func (x *ExtensionRangeOptions) Reset() { *x = ExtensionRangeOptions{} @@ -1056,6 +1134,20 @@ func (x *ExtensionRangeOptions) GetUninterpretedOption() []*UninterpretedOption return nil } +func (x *ExtensionRangeOptions) GetDeclaration() []*ExtensionRangeOptions_Declaration { + if x != nil { + return x.Declaration + } + return nil +} + +func (x *ExtensionRangeOptions) GetVerification() ExtensionRangeOptions_VerificationState { + if x != nil && x.Verification != nil { + return *x.Verification + } + return Default_ExtensionRangeOptions_Verification +} + // Describes a field within a message. 
type FieldDescriptorProto struct { state protoimpl.MessageState @@ -2046,8 +2138,10 @@ type FieldOptions struct { // The ctype option instructs the C++ code generator to use a different // representation of the field than it normally would. See the specific - // options below. This option is not yet implemented in the open source - // release -- sorry, we'll try to include it in a future version! + // options below. This option is only implemented to support use of + // [ctype=CORD] and [ctype=STRING] (the default) on non-repeated fields of + // type "bytes" in the open source release -- sorry, we'll try to include + // other types in a future version! Ctype *FieldOptions_CType `protobuf:"varint,1,opt,name=ctype,enum=google.protobuf.FieldOptions_CType,def=0" json:"ctype,omitempty"` // The packed option can be enabled for repeated primitive fields to enable // a more efficient representation on the wire. Rather than repeatedly @@ -2111,9 +2205,11 @@ type FieldOptions struct { Weak *bool `protobuf:"varint,10,opt,name=weak,def=0" json:"weak,omitempty"` // Indicate that the field value should not be printed out when using debug // formats, e.g. when the field contains sensitive credentials. - DebugRedact *bool `protobuf:"varint,16,opt,name=debug_redact,json=debugRedact,def=0" json:"debug_redact,omitempty"` - Retention *FieldOptions_OptionRetention `protobuf:"varint,17,opt,name=retention,enum=google.protobuf.FieldOptions_OptionRetention" json:"retention,omitempty"` - Target *FieldOptions_OptionTargetType `protobuf:"varint,18,opt,name=target,enum=google.protobuf.FieldOptions_OptionTargetType" json:"target,omitempty"` + DebugRedact *bool `protobuf:"varint,16,opt,name=debug_redact,json=debugRedact,def=0" json:"debug_redact,omitempty"` + Retention *FieldOptions_OptionRetention `protobuf:"varint,17,opt,name=retention,enum=google.protobuf.FieldOptions_OptionRetention" json:"retention,omitempty"` + // Deprecated: Marked as deprecated in google/protobuf/descriptor.proto. 
+ Target *FieldOptions_OptionTargetType `protobuf:"varint,18,opt,name=target,enum=google.protobuf.FieldOptions_OptionTargetType" json:"target,omitempty"` + Targets []FieldOptions_OptionTargetType `protobuf:"varint,19,rep,name=targets,enum=google.protobuf.FieldOptions_OptionTargetType" json:"targets,omitempty"` // The parser stores options it doesn't recognize here. See above. UninterpretedOption []*UninterpretedOption `protobuf:"bytes,999,rep,name=uninterpreted_option,json=uninterpretedOption" json:"uninterpreted_option,omitempty"` } @@ -2224,6 +2320,7 @@ func (x *FieldOptions) GetRetention() FieldOptions_OptionRetention { return FieldOptions_RETENTION_UNKNOWN } +// Deprecated: Marked as deprecated in google/protobuf/descriptor.proto. func (x *FieldOptions) GetTarget() FieldOptions_OptionTargetType { if x != nil && x.Target != nil { return *x.Target @@ -2231,6 +2328,13 @@ func (x *FieldOptions) GetTarget() FieldOptions_OptionTargetType { return FieldOptions_TARGET_TYPE_UNKNOWN } +func (x *FieldOptions) GetTargets() []FieldOptions_OptionTargetType { + if x != nil { + return x.Targets + } + return nil +} + func (x *FieldOptions) GetUninterpretedOption() []*UninterpretedOption { if x != nil { return x.UninterpretedOption @@ -2960,6 +3064,108 @@ func (x *DescriptorProto_ReservedRange) GetEnd() int32 { return 0 } +type ExtensionRangeOptions_Declaration struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // The extension number declared within the extension range. + Number *int32 `protobuf:"varint,1,opt,name=number" json:"number,omitempty"` + // The fully-qualified name of the extension field. There must be a leading + // dot in front of the full name. + FullName *string `protobuf:"bytes,2,opt,name=full_name,json=fullName" json:"full_name,omitempty"` + // The fully-qualified type name of the extension field. Unlike + // Metadata.type, Declaration.type must have a leading dot for messages + // and enums. 
+ Type *string `protobuf:"bytes,3,opt,name=type" json:"type,omitempty"` + // Deprecated. Please use "repeated". + // + // Deprecated: Marked as deprecated in google/protobuf/descriptor.proto. + IsRepeated *bool `protobuf:"varint,4,opt,name=is_repeated,json=isRepeated" json:"is_repeated,omitempty"` + // If true, indicates that the number is reserved in the extension range, + // and any extension field with the number will fail to compile. Set this + // when a declared extension field is deleted. + Reserved *bool `protobuf:"varint,5,opt,name=reserved" json:"reserved,omitempty"` + // If true, indicates that the extension must be defined as repeated. + // Otherwise the extension must be defined as optional. + Repeated *bool `protobuf:"varint,6,opt,name=repeated" json:"repeated,omitempty"` +} + +func (x *ExtensionRangeOptions_Declaration) Reset() { + *x = ExtensionRangeOptions_Declaration{} + if protoimpl.UnsafeEnabled { + mi := &file_google_protobuf_descriptor_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ExtensionRangeOptions_Declaration) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ExtensionRangeOptions_Declaration) ProtoMessage() {} + +func (x *ExtensionRangeOptions_Declaration) ProtoReflect() protoreflect.Message { + mi := &file_google_protobuf_descriptor_proto_msgTypes[23] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ExtensionRangeOptions_Declaration.ProtoReflect.Descriptor instead. 
+func (*ExtensionRangeOptions_Declaration) Descriptor() ([]byte, []int) { + return file_google_protobuf_descriptor_proto_rawDescGZIP(), []int{3, 0} +} + +func (x *ExtensionRangeOptions_Declaration) GetNumber() int32 { + if x != nil && x.Number != nil { + return *x.Number + } + return 0 +} + +func (x *ExtensionRangeOptions_Declaration) GetFullName() string { + if x != nil && x.FullName != nil { + return *x.FullName + } + return "" +} + +func (x *ExtensionRangeOptions_Declaration) GetType() string { + if x != nil && x.Type != nil { + return *x.Type + } + return "" +} + +// Deprecated: Marked as deprecated in google/protobuf/descriptor.proto. +func (x *ExtensionRangeOptions_Declaration) GetIsRepeated() bool { + if x != nil && x.IsRepeated != nil { + return *x.IsRepeated + } + return false +} + +func (x *ExtensionRangeOptions_Declaration) GetReserved() bool { + if x != nil && x.Reserved != nil { + return *x.Reserved + } + return false +} + +func (x *ExtensionRangeOptions_Declaration) GetRepeated() bool { + if x != nil && x.Repeated != nil { + return *x.Repeated + } + return false +} + // Range of reserved numeric values. Reserved values may not be used by // entries in the same enum. Reserved ranges may not overlap. 
// @@ -2978,7 +3184,7 @@ type EnumDescriptorProto_EnumReservedRange struct { func (x *EnumDescriptorProto_EnumReservedRange) Reset() { *x = EnumDescriptorProto_EnumReservedRange{} if protoimpl.UnsafeEnabled { - mi := &file_google_protobuf_descriptor_proto_msgTypes[23] + mi := &file_google_protobuf_descriptor_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2991,7 +3197,7 @@ func (x *EnumDescriptorProto_EnumReservedRange) String() string { func (*EnumDescriptorProto_EnumReservedRange) ProtoMessage() {} func (x *EnumDescriptorProto_EnumReservedRange) ProtoReflect() protoreflect.Message { - mi := &file_google_protobuf_descriptor_proto_msgTypes[23] + mi := &file_google_protobuf_descriptor_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3038,7 +3244,7 @@ type UninterpretedOption_NamePart struct { func (x *UninterpretedOption_NamePart) Reset() { *x = UninterpretedOption_NamePart{} if protoimpl.UnsafeEnabled { - mi := &file_google_protobuf_descriptor_proto_msgTypes[24] + mi := &file_google_protobuf_descriptor_proto_msgTypes[25] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3051,7 +3257,7 @@ func (x *UninterpretedOption_NamePart) String() string { func (*UninterpretedOption_NamePart) ProtoMessage() {} func (x *UninterpretedOption_NamePart) ProtoReflect() protoreflect.Message { - mi := &file_google_protobuf_descriptor_proto_msgTypes[24] + mi := &file_google_protobuf_descriptor_proto_msgTypes[25] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3182,7 +3388,7 @@ type SourceCodeInfo_Location struct { func (x *SourceCodeInfo_Location) Reset() { *x = SourceCodeInfo_Location{} if protoimpl.UnsafeEnabled { - mi := &file_google_protobuf_descriptor_proto_msgTypes[25] + mi := 
&file_google_protobuf_descriptor_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3195,7 +3401,7 @@ func (x *SourceCodeInfo_Location) String() string { func (*SourceCodeInfo_Location) ProtoMessage() {} func (x *SourceCodeInfo_Location) ProtoReflect() protoreflect.Message { - mi := &file_google_protobuf_descriptor_proto_msgTypes[25] + mi := &file_google_protobuf_descriptor_proto_msgTypes[26] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3269,7 +3475,7 @@ type GeneratedCodeInfo_Annotation struct { func (x *GeneratedCodeInfo_Annotation) Reset() { *x = GeneratedCodeInfo_Annotation{} if protoimpl.UnsafeEnabled { - mi := &file_google_protobuf_descriptor_proto_msgTypes[26] + mi := &file_google_protobuf_descriptor_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -3282,7 +3488,7 @@ func (x *GeneratedCodeInfo_Annotation) String() string { func (*GeneratedCodeInfo_Annotation) ProtoMessage() {} func (x *GeneratedCodeInfo_Annotation) ProtoReflect() protoreflect.Message { - mi := &file_google_protobuf_descriptor_proto_msgTypes[26] + mi := &file_google_protobuf_descriptor_proto_msgTypes[27] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -3436,264 +3642,296 @@ var file_google_protobuf_descriptor_proto_rawDesc = []byte{ 0x65, 0x64, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x22, - 0x7c, 0x0a, 0x15, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x61, 0x6e, 0x67, - 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x58, 0x0a, 0x14, 0x75, 0x6e, 0x69, 0x6e, - 0x74, 
0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x18, 0xe7, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x55, 0x6e, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x13, 0x75, - 0x6e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x2a, 0x09, 0x08, 0xe8, 0x07, 0x10, 0x80, 0x80, 0x80, 0x80, 0x02, 0x22, 0xc1, 0x06, - 0x0a, 0x14, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, - 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6e, 0x75, - 0x6d, 0x62, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x6e, 0x75, 0x6d, 0x62, - 0x65, 0x72, 0x12, 0x41, 0x0a, 0x05, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x0e, 0x32, 0x2b, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, - 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x05, - 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x3e, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x05, 0x20, - 0x01, 0x28, 0x0e, 0x32, 0x2a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0xad, 0x04, 0x0a, 0x15, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x61, 0x6e, + 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x58, 0x0a, 0x14, 0x75, 0x6e, 0x69, + 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0xe7, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 
0x62, 0x75, 0x66, 0x2e, 0x55, 0x6e, 0x69, 0x6e, 0x74, + 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x13, + 0x75, 0x6e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x12, 0x59, 0x0a, 0x0b, 0x64, 0x65, 0x63, 0x6c, 0x61, 0x72, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x32, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x78, 0x74, 0x65, 0x6e, + 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x2e, 0x44, 0x65, 0x63, 0x6c, 0x61, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x42, 0x03, 0x88, 0x01, + 0x02, 0x52, 0x0b, 0x64, 0x65, 0x63, 0x6c, 0x61, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x68, + 0x0a, 0x0c, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0e, 0x32, 0x38, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x78, 0x74, 0x65, 0x6e, 0x73, 0x69, 0x6f, 0x6e, + 0x52, 0x61, 0x6e, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x56, 0x65, 0x72, + 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x3a, 0x0a, + 0x55, 0x4e, 0x56, 0x45, 0x52, 0x49, 0x46, 0x49, 0x45, 0x44, 0x52, 0x0c, 0x76, 0x65, 0x72, 0x69, + 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x1a, 0xb3, 0x01, 0x0a, 0x0b, 0x44, 0x65, 0x63, + 0x6c, 0x61, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x6e, 0x75, 0x6d, 0x62, + 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x6e, 0x75, 0x6d, 0x62, 0x65, 0x72, + 0x12, 0x1b, 0x0a, 0x09, 0x66, 0x75, 0x6c, 0x6c, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x75, 0x6c, 0x6c, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, + 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 
0x74, 0x79, 0x70, + 0x65, 0x12, 0x23, 0x0a, 0x0b, 0x69, 0x73, 0x5f, 0x72, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x42, 0x02, 0x18, 0x01, 0x52, 0x0a, 0x69, 0x73, 0x52, 0x65, + 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x72, 0x65, 0x70, 0x65, 0x61, 0x74, 0x65, 0x64, 0x22, 0x34, + 0x0a, 0x11, 0x56, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x12, 0x0f, 0x0a, 0x0b, 0x44, 0x45, 0x43, 0x4c, 0x41, 0x52, 0x41, 0x54, 0x49, + 0x4f, 0x4e, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x55, 0x4e, 0x56, 0x45, 0x52, 0x49, 0x46, 0x49, + 0x45, 0x44, 0x10, 0x01, 0x2a, 0x09, 0x08, 0xe8, 0x07, 0x10, 0x80, 0x80, 0x80, 0x80, 0x02, 0x22, + 0xc1, 0x06, 0x0a, 0x14, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, + 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x16, 0x0a, 0x06, + 0x6e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x6e, 0x75, + 0x6d, 0x62, 0x65, 0x72, 0x12, 0x41, 0x0a, 0x05, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x2b, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x44, 0x65, 0x73, 0x63, 0x72, - 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, - 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x79, 0x70, 0x65, 0x5f, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x79, 0x70, 0x65, 0x4e, 0x61, - 0x6d, 0x65, 0x12, 
0x1a, 0x0a, 0x08, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65, 0x65, 0x12, 0x23, - 0x0a, 0x0d, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x56, 0x61, - 0x6c, 0x75, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x5f, 0x69, 0x6e, 0x64, - 0x65, 0x78, 0x18, 0x09, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x49, - 0x6e, 0x64, 0x65, 0x78, 0x12, 0x1b, 0x0a, 0x09, 0x6a, 0x73, 0x6f, 0x6e, 0x5f, 0x6e, 0x61, 0x6d, - 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6a, 0x73, 0x6f, 0x6e, 0x4e, 0x61, 0x6d, - 0x65, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x27, 0x0a, 0x0f, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x18, 0x11, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x4f, 0x70, 0x74, 0x69, 0x6f, - 0x6e, 0x61, 0x6c, 0x22, 0xb6, 0x02, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0f, 0x0a, 0x0b, - 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x4f, 0x55, 0x42, 0x4c, 0x45, 0x10, 0x01, 0x12, 0x0e, 0x0a, - 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x46, 0x4c, 0x4f, 0x41, 0x54, 0x10, 0x02, 0x12, 0x0e, 0x0a, - 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x36, 0x34, 0x10, 0x03, 0x12, 0x0f, 0x0a, - 0x0b, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x49, 0x4e, 0x54, 0x36, 0x34, 0x10, 0x04, 0x12, 0x0e, - 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x33, 0x32, 0x10, 0x05, 0x12, 0x10, - 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x46, 0x49, 
0x58, 0x45, 0x44, 0x36, 0x34, 0x10, 0x06, - 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x46, 0x49, 0x58, 0x45, 0x44, 0x33, 0x32, - 0x10, 0x07, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x42, 0x4f, 0x4f, 0x4c, 0x10, - 0x08, 0x12, 0x0f, 0x0a, 0x0b, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x52, 0x49, 0x4e, 0x47, - 0x10, 0x09, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x47, 0x52, 0x4f, 0x55, 0x50, - 0x10, 0x0a, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x45, 0x53, 0x53, 0x41, - 0x47, 0x45, 0x10, 0x0b, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x42, 0x59, 0x54, - 0x45, 0x53, 0x10, 0x0c, 0x12, 0x0f, 0x0a, 0x0b, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x49, 0x4e, - 0x54, 0x33, 0x32, 0x10, 0x0d, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x45, 0x4e, - 0x55, 0x4d, 0x10, 0x0e, 0x12, 0x11, 0x0a, 0x0d, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x46, 0x49, - 0x58, 0x45, 0x44, 0x33, 0x32, 0x10, 0x0f, 0x12, 0x11, 0x0a, 0x0d, 0x54, 0x59, 0x50, 0x45, 0x5f, - 0x53, 0x46, 0x49, 0x58, 0x45, 0x44, 0x36, 0x34, 0x10, 0x10, 0x12, 0x0f, 0x0a, 0x0b, 0x54, 0x59, - 0x50, 0x45, 0x5f, 0x53, 0x49, 0x4e, 0x54, 0x33, 0x32, 0x10, 0x11, 0x12, 0x0f, 0x0a, 0x0b, 0x54, - 0x59, 0x50, 0x45, 0x5f, 0x53, 0x49, 0x4e, 0x54, 0x36, 0x34, 0x10, 0x12, 0x22, 0x43, 0x0a, 0x05, - 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x0e, 0x4c, 0x41, 0x42, 0x45, 0x4c, 0x5f, 0x4f, - 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x41, 0x4c, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x4c, 0x41, 0x42, - 0x45, 0x4c, 0x5f, 0x52, 0x45, 0x51, 0x55, 0x49, 0x52, 0x45, 0x44, 0x10, 0x02, 0x12, 0x12, 0x0a, - 0x0e, 0x4c, 0x41, 0x42, 0x45, 0x4c, 0x5f, 0x52, 0x45, 0x50, 0x45, 0x41, 0x54, 0x45, 0x44, 0x10, - 0x03, 0x22, 0x63, 0x0a, 0x14, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, - 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, - 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x37, 
0x0a, - 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, + 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x4c, 0x61, 0x62, 0x65, 0x6c, + 0x52, 0x05, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x3e, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x44, 0x65, 0x73, + 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, 0x79, 0x70, + 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x79, 0x70, 0x65, 0x5f, + 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x79, 0x70, 0x65, + 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65, 0x65, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x78, 0x74, 0x65, 0x6e, 0x64, 0x65, 0x65, + 0x12, 0x23, 0x0a, 0x0d, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, + 0x56, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x5f, 0x69, + 0x6e, 0x64, 0x65, 0x78, 0x18, 0x09, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x6f, 0x6e, 0x65, 0x6f, + 0x66, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x12, 0x1b, 0x0a, 0x09, 0x6a, 0x73, 0x6f, 0x6e, 0x5f, 0x6e, + 0x61, 0x6d, 0x65, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6a, 0x73, 0x6f, 0x6e, 0x4e, + 0x61, 0x6d, 0x65, 0x12, 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x08, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x27, 0x0a, 0x0f, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 
0x33, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x18, + 0x11, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x22, 0xb6, 0x02, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0f, + 0x0a, 0x0b, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x4f, 0x55, 0x42, 0x4c, 0x45, 0x10, 0x01, 0x12, + 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x46, 0x4c, 0x4f, 0x41, 0x54, 0x10, 0x02, 0x12, + 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x36, 0x34, 0x10, 0x03, 0x12, + 0x0f, 0x0a, 0x0b, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x49, 0x4e, 0x54, 0x36, 0x34, 0x10, 0x04, + 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x49, 0x4e, 0x54, 0x33, 0x32, 0x10, 0x05, + 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x46, 0x49, 0x58, 0x45, 0x44, 0x36, 0x34, + 0x10, 0x06, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x46, 0x49, 0x58, 0x45, 0x44, + 0x33, 0x32, 0x10, 0x07, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x42, 0x4f, 0x4f, + 0x4c, 0x10, 0x08, 0x12, 0x0f, 0x0a, 0x0b, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x54, 0x52, 0x49, + 0x4e, 0x47, 0x10, 0x09, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x47, 0x52, 0x4f, + 0x55, 0x50, 0x10, 0x0a, 0x12, 0x10, 0x0a, 0x0c, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x4d, 0x45, 0x53, + 0x53, 0x41, 0x47, 0x45, 0x10, 0x0b, 0x12, 0x0e, 0x0a, 0x0a, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x42, + 0x59, 0x54, 0x45, 0x53, 0x10, 0x0c, 0x12, 0x0f, 0x0a, 0x0b, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, + 0x49, 0x4e, 0x54, 0x33, 0x32, 0x10, 0x0d, 0x12, 0x0d, 0x0a, 0x09, 0x54, 0x59, 0x50, 0x45, 0x5f, + 0x45, 0x4e, 0x55, 0x4d, 0x10, 0x0e, 0x12, 0x11, 0x0a, 0x0d, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, + 0x46, 0x49, 0x58, 0x45, 0x44, 0x33, 0x32, 0x10, 0x0f, 0x12, 0x11, 0x0a, 0x0d, 0x54, 0x59, 0x50, + 0x45, 0x5f, 0x53, 0x46, 0x49, 0x58, 0x45, 0x44, 0x36, 0x34, 0x10, 0x10, 0x12, 0x0f, 0x0a, 0x0b, + 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x49, 0x4e, 0x54, 0x33, 0x32, 
0x10, 0x11, 0x12, 0x0f, 0x0a, + 0x0b, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x53, 0x49, 0x4e, 0x54, 0x36, 0x34, 0x10, 0x12, 0x22, 0x43, + 0x0a, 0x05, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x12, 0x0a, 0x0e, 0x4c, 0x41, 0x42, 0x45, 0x4c, + 0x5f, 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x41, 0x4c, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x4c, + 0x41, 0x42, 0x45, 0x4c, 0x5f, 0x52, 0x45, 0x51, 0x55, 0x49, 0x52, 0x45, 0x44, 0x10, 0x02, 0x12, + 0x12, 0x0a, 0x0e, 0x4c, 0x41, 0x42, 0x45, 0x4c, 0x5f, 0x52, 0x45, 0x50, 0x45, 0x41, 0x54, 0x45, + 0x44, 0x10, 0x03, 0x22, 0x63, 0x0a, 0x14, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x44, 0x65, 0x73, 0x63, + 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, + 0x37, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x75, 0x66, 0x2e, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, + 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xe3, 0x02, 0x0a, 0x13, 0x45, 0x6e, 0x75, + 0x6d, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x3f, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x44, + 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x52, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x36, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, + 0x70, 
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x5d, 0x0a, + 0x0e, 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x5f, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, + 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x36, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x44, 0x65, 0x73, 0x63, + 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x45, 0x6e, 0x75, 0x6d, + 0x52, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x0d, 0x72, + 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x23, 0x0a, 0x0d, + 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x05, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x4e, 0x61, 0x6d, + 0x65, 0x1a, 0x3b, 0x0a, 0x11, 0x45, 0x6e, 0x75, 0x6d, 0x52, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, + 0x64, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, + 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x22, 0x83, + 0x01, 0x0a, 0x18, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x44, 0x65, 0x73, 0x63, + 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, + 0x16, 0x0a, 0x06, 0x6e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x06, 0x6e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, 0x3b, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 
0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x56, + 0x61, 0x6c, 0x75, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xa7, 0x01, 0x0a, 0x16, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, + 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, + 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x12, 0x3e, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x02, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x26, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x44, 0x65, 0x73, 0x63, + 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x52, 0x06, 0x6d, 0x65, 0x74, + 0x68, 0x6f, 0x64, 0x12, 0x39, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x89, + 0x02, 0x0a, 0x15, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, + 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1d, 0x0a, 0x0a, + 0x69, 0x6e, 0x70, 0x75, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x09, 0x69, 0x6e, 0x70, 0x75, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, + 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0a, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x38, 0x0a, 0x07, + 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 
0x32, 0x1e, 0x2e, + 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, + 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x30, 0x0a, 0x10, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x5f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, + 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, + 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x12, 0x30, 0x0a, 0x10, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x5f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0f, 0x73, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x22, 0x91, 0x09, 0x0a, 0x0b, 0x46, + 0x69, 0x6c, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x6a, 0x61, + 0x76, 0x61, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x6a, 0x61, 0x76, 0x61, 0x50, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x12, 0x30, 0x0a, + 0x14, 0x6a, 0x61, 0x76, 0x61, 0x5f, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x5f, 0x63, 0x6c, 0x61, 0x73, + 0x73, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x6a, 0x61, 0x76, + 0x61, 0x4f, 0x75, 0x74, 0x65, 0x72, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x6e, 0x61, 0x6d, 0x65, 0x12, + 0x35, 0x0a, 0x13, 0x6a, 0x61, 0x76, 0x61, 0x5f, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, + 0x5f, 0x66, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, + 0x6c, 0x73, 0x65, 0x52, 0x11, 0x6a, 0x61, 0x76, 0x61, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, + 0x65, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x12, 0x44, 0x0a, 0x1d, 0x6a, 0x61, 0x76, 0x61, 0x5f, 0x67, + 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x5f, 0x65, 0x71, 0x75, 0x61, 0x6c, 0x73, 0x5f, 0x61, + 0x6e, 0x64, 0x5f, 
0x68, 0x61, 0x73, 0x68, 0x18, 0x14, 0x20, 0x01, 0x28, 0x08, 0x42, 0x02, 0x18, + 0x01, 0x52, 0x19, 0x6a, 0x61, 0x76, 0x61, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x45, + 0x71, 0x75, 0x61, 0x6c, 0x73, 0x41, 0x6e, 0x64, 0x48, 0x61, 0x73, 0x68, 0x12, 0x3a, 0x0a, 0x16, + 0x6a, 0x61, 0x76, 0x61, 0x5f, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x63, 0x68, 0x65, 0x63, + 0x6b, 0x5f, 0x75, 0x74, 0x66, 0x38, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, + 0x6c, 0x73, 0x65, 0x52, 0x13, 0x6a, 0x61, 0x76, 0x61, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x43, + 0x68, 0x65, 0x63, 0x6b, 0x55, 0x74, 0x66, 0x38, 0x12, 0x53, 0x0a, 0x0c, 0x6f, 0x70, 0x74, 0x69, + 0x6d, 0x69, 0x7a, 0x65, 0x5f, 0x66, 0x6f, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x29, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, - 0x2e, 0x4f, 0x6e, 0x65, 0x6f, 0x66, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, - 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0xe3, 0x02, 0x0a, 0x13, 0x45, 0x6e, 0x75, 0x6d, 0x44, - 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, - 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, - 0x6d, 0x65, 0x12, 0x3f, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x29, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x44, 0x65, 0x73, - 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x52, 0x05, 0x76, 0x61, - 0x6c, 0x75, 0x65, 0x12, 0x36, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x4f, 0x70, 0x74, 0x69, 0x6f, - 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 
0x73, 0x12, 0x5d, 0x0a, 0x0e, 0x72, - 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x5f, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x04, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x36, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, - 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x52, 0x65, - 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x52, 0x0d, 0x72, 0x65, 0x73, - 0x65, 0x72, 0x76, 0x65, 0x64, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x65, - 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x05, 0x20, 0x03, 0x28, - 0x09, 0x52, 0x0c, 0x72, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x4e, 0x61, 0x6d, 0x65, 0x1a, - 0x3b, 0x0a, 0x11, 0x45, 0x6e, 0x75, 0x6d, 0x52, 0x65, 0x73, 0x65, 0x72, 0x76, 0x65, 0x64, 0x52, - 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, - 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x22, 0x83, 0x01, 0x0a, - 0x18, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, 0x6c, 0x75, 0x65, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, - 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, - 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x16, 0x0a, - 0x06, 0x6e, 0x75, 0x6d, 0x62, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x6e, - 0x75, 0x6d, 0x62, 0x65, 0x72, 0x12, 0x3b, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x45, 0x6e, 0x75, 0x6d, 0x56, 0x61, 0x6c, - 0x75, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, - 
0x6e, 0x73, 0x22, 0xa7, 0x01, 0x0a, 0x16, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x44, 0x65, - 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x0a, - 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, - 0x65, 0x12, 0x3e, 0x0a, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x02, 0x20, 0x03, 0x28, - 0x0b, 0x32, 0x26, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, - 0x70, 0x74, 0x6f, 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x52, 0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, - 0x64, 0x12, 0x39, 0x0a, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x1f, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x89, 0x02, 0x0a, - 0x15, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, - 0x72, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x69, 0x6e, - 0x70, 0x75, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, - 0x69, 0x6e, 0x70, 0x75, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x6f, 0x75, 0x74, - 0x70, 0x75, 0x74, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, - 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x54, 0x79, 0x70, 0x65, 0x12, 0x38, 0x0a, 0x07, 0x6f, 0x70, - 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x67, 0x6f, - 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x4d, 0x65, - 0x74, 0x68, 0x6f, 0x64, 0x4f, 0x70, 
0x74, 0x69, 0x6f, 0x6e, 0x73, 0x52, 0x07, 0x6f, 0x70, 0x74, - 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x30, 0x0a, 0x10, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, - 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, - 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x72, - 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x12, 0x30, 0x0a, 0x10, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x5f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, - 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, - 0x74, 0x72, 0x65, 0x61, 0x6d, 0x69, 0x6e, 0x67, 0x22, 0x91, 0x09, 0x0a, 0x0b, 0x46, 0x69, 0x6c, - 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x6a, 0x61, 0x76, 0x61, - 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, - 0x6a, 0x61, 0x76, 0x61, 0x50, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x12, 0x30, 0x0a, 0x14, 0x6a, - 0x61, 0x76, 0x61, 0x5f, 0x6f, 0x75, 0x74, 0x65, 0x72, 0x5f, 0x63, 0x6c, 0x61, 0x73, 0x73, 0x6e, - 0x61, 0x6d, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x6a, 0x61, 0x76, 0x61, 0x4f, - 0x75, 0x74, 0x65, 0x72, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x35, 0x0a, - 0x13, 0x6a, 0x61, 0x76, 0x61, 0x5f, 0x6d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x5f, 0x66, - 0x69, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, - 0x65, 0x52, 0x11, 0x6a, 0x61, 0x76, 0x61, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x70, 0x6c, 0x65, 0x46, - 0x69, 0x6c, 0x65, 0x73, 0x12, 0x44, 0x0a, 0x1d, 0x6a, 0x61, 0x76, 0x61, 0x5f, 0x67, 0x65, 0x6e, - 0x65, 0x72, 0x61, 0x74, 0x65, 0x5f, 0x65, 0x71, 0x75, 0x61, 0x6c, 0x73, 0x5f, 0x61, 0x6e, 0x64, - 0x5f, 0x68, 0x61, 0x73, 0x68, 0x18, 0x14, 0x20, 0x01, 0x28, 0x08, 0x42, 0x02, 0x18, 0x01, 0x52, - 0x19, 0x6a, 0x61, 0x76, 0x61, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 
0x65, 0x45, 0x71, 0x75, - 0x61, 0x6c, 0x73, 0x41, 0x6e, 0x64, 0x48, 0x61, 0x73, 0x68, 0x12, 0x3a, 0x0a, 0x16, 0x6a, 0x61, - 0x76, 0x61, 0x5f, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x5f, 0x63, 0x68, 0x65, 0x63, 0x6b, 0x5f, - 0x75, 0x74, 0x66, 0x38, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, - 0x65, 0x52, 0x13, 0x6a, 0x61, 0x76, 0x61, 0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x43, 0x68, 0x65, - 0x63, 0x6b, 0x55, 0x74, 0x66, 0x38, 0x12, 0x53, 0x0a, 0x0c, 0x6f, 0x70, 0x74, 0x69, 0x6d, 0x69, - 0x7a, 0x65, 0x5f, 0x66, 0x6f, 0x72, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x29, 0x2e, 0x67, - 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, - 0x69, 0x6c, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6d, - 0x69, 0x7a, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x3a, 0x05, 0x53, 0x50, 0x45, 0x45, 0x44, 0x52, 0x0b, - 0x6f, 0x70, 0x74, 0x69, 0x6d, 0x69, 0x7a, 0x65, 0x46, 0x6f, 0x72, 0x12, 0x1d, 0x0a, 0x0a, 0x67, - 0x6f, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x09, 0x67, 0x6f, 0x50, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x12, 0x35, 0x0a, 0x13, 0x63, 0x63, - 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x73, 0x18, 0x10, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x11, - 0x63, 0x63, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, - 0x73, 0x12, 0x39, 0x0a, 0x15, 0x6a, 0x61, 0x76, 0x61, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x69, - 0x63, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, 0x08, - 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x13, 0x6a, 0x61, 0x76, 0x61, 0x47, 0x65, 0x6e, - 0x65, 0x72, 0x69, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x12, 0x35, 0x0a, 0x13, - 0x70, 0x79, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 
0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, - 0x52, 0x11, 0x70, 0x79, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x73, 0x12, 0x37, 0x0a, 0x14, 0x70, 0x68, 0x70, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, - 0x69, 0x63, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x18, 0x2a, 0x20, 0x01, 0x28, - 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x12, 0x70, 0x68, 0x70, 0x47, 0x65, 0x6e, - 0x65, 0x72, 0x69, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x12, 0x25, 0x0a, 0x0a, - 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, 0x64, 0x18, 0x17, 0x20, 0x01, 0x28, 0x08, - 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0a, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, - 0x74, 0x65, 0x64, 0x12, 0x2e, 0x0a, 0x10, 0x63, 0x63, 0x5f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x5f, 0x61, 0x72, 0x65, 0x6e, 0x61, 0x73, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x04, 0x74, - 0x72, 0x75, 0x65, 0x52, 0x0e, 0x63, 0x63, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x72, 0x65, - 0x6e, 0x61, 0x73, 0x12, 0x2a, 0x0a, 0x11, 0x6f, 0x62, 0x6a, 0x63, 0x5f, 0x63, 0x6c, 0x61, 0x73, - 0x73, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, - 0x6f, 0x62, 0x6a, 0x63, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, - 0x29, 0x0a, 0x10, 0x63, 0x73, 0x68, 0x61, 0x72, 0x70, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, - 0x61, 0x63, 0x65, 0x18, 0x25, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x63, 0x73, 0x68, 0x61, 0x72, - 0x70, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x73, 0x77, - 0x69, 0x66, 0x74, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x27, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0b, 0x73, 0x77, 0x69, 0x66, 0x74, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x28, 0x0a, - 0x10, 0x70, 0x68, 0x70, 0x5f, 0x63, 0x6c, 0x61, 0x73, 0x73, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, - 0x78, 0x18, 0x28, 0x20, 0x01, 0x28, 0x09, 0x52, 
0x0e, 0x70, 0x68, 0x70, 0x43, 0x6c, 0x61, 0x73, - 0x73, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x23, 0x0a, 0x0d, 0x70, 0x68, 0x70, 0x5f, 0x6e, - 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x29, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, - 0x70, 0x68, 0x70, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x34, 0x0a, 0x16, - 0x70, 0x68, 0x70, 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x5f, 0x6e, 0x61, 0x6d, - 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x2c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x14, 0x70, 0x68, - 0x70, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, - 0x63, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x75, 0x62, 0x79, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x61, - 0x67, 0x65, 0x18, 0x2d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x75, 0x62, 0x79, 0x50, 0x61, - 0x63, 0x6b, 0x61, 0x67, 0x65, 0x12, 0x58, 0x0a, 0x14, 0x75, 0x6e, 0x69, 0x6e, 0x74, 0x65, 0x72, - 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0xe7, 0x07, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x55, 0x6e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, - 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x13, 0x75, 0x6e, 0x69, 0x6e, - 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x22, - 0x3a, 0x0a, 0x0c, 0x4f, 0x70, 0x74, 0x69, 0x6d, 0x69, 0x7a, 0x65, 0x4d, 0x6f, 0x64, 0x65, 0x12, - 0x09, 0x0a, 0x05, 0x53, 0x50, 0x45, 0x45, 0x44, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, 0x43, 0x4f, - 0x44, 0x45, 0x5f, 0x53, 0x49, 0x5a, 0x45, 0x10, 0x02, 0x12, 0x10, 0x0a, 0x0c, 0x4c, 0x49, 0x54, - 0x45, 0x5f, 0x52, 0x55, 0x4e, 0x54, 0x49, 0x4d, 0x45, 0x10, 0x03, 0x2a, 0x09, 0x08, 0xe8, 0x07, - 0x10, 0x80, 0x80, 0x80, 0x80, 0x02, 0x4a, 0x04, 0x08, 0x26, 0x10, 0x27, 0x22, 0xbb, 0x03, 0x0a, - 0x0e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 
0x73, 0x12, - 0x3c, 0x0a, 0x17, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x74, 0x5f, 0x77, - 0x69, 0x72, 0x65, 0x5f, 0x66, 0x6f, 0x72, 0x6d, 0x61, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, - 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x14, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x53, 0x65, 0x74, 0x57, 0x69, 0x72, 0x65, 0x46, 0x6f, 0x72, 0x6d, 0x61, 0x74, 0x12, 0x4c, 0x0a, - 0x1f, 0x6e, 0x6f, 0x5f, 0x73, 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64, 0x5f, 0x64, 0x65, 0x73, - 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x5f, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f, 0x72, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x1c, 0x6e, - 0x6f, 0x53, 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64, 0x44, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, - 0x74, 0x6f, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f, 0x72, 0x12, 0x25, 0x0a, 0x0a, 0x64, - 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x3a, - 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0a, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, - 0x65, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x6d, 0x61, 0x70, 0x5f, 0x65, 0x6e, 0x74, 0x72, 0x79, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x6d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, - 0x56, 0x0a, 0x26, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x6c, 0x65, - 0x67, 0x61, 0x63, 0x79, 0x5f, 0x6a, 0x73, 0x6f, 0x6e, 0x5f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, - 0x63, 0x6f, 0x6e, 0x66, 0x6c, 0x69, 0x63, 0x74, 0x73, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x42, - 0x02, 0x18, 0x01, 0x52, 0x22, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, 0x64, 0x4c, - 0x65, 0x67, 0x61, 0x63, 0x79, 0x4a, 0x73, 0x6f, 0x6e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x43, 0x6f, - 0x6e, 0x66, 0x6c, 0x69, 0x63, 0x74, 0x73, 0x12, 0x58, 0x0a, 0x14, 0x75, 0x6e, 0x69, 0x6e, 0x74, + 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x4f, 0x70, 0x74, + 0x69, 0x6d, 0x69, 0x7a, 
0x65, 0x4d, 0x6f, 0x64, 0x65, 0x3a, 0x05, 0x53, 0x50, 0x45, 0x45, 0x44, + 0x52, 0x0b, 0x6f, 0x70, 0x74, 0x69, 0x6d, 0x69, 0x7a, 0x65, 0x46, 0x6f, 0x72, 0x12, 0x1d, 0x0a, + 0x0a, 0x67, 0x6f, 0x5f, 0x70, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x18, 0x0b, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x09, 0x67, 0x6f, 0x50, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x12, 0x35, 0x0a, 0x13, + 0x63, 0x63, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x73, 0x18, 0x10, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, + 0x52, 0x11, 0x63, 0x63, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x73, 0x12, 0x39, 0x0a, 0x15, 0x6a, 0x61, 0x76, 0x61, 0x5f, 0x67, 0x65, 0x6e, 0x65, + 0x72, 0x69, 0x63, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x18, 0x11, 0x20, 0x01, + 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x13, 0x6a, 0x61, 0x76, 0x61, 0x47, + 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x12, 0x35, + 0x0a, 0x13, 0x70, 0x79, 0x5f, 0x67, 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x5f, 0x73, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, + 0x73, 0x65, 0x52, 0x11, 0x70, 0x79, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x73, 0x12, 0x37, 0x0a, 0x14, 0x70, 0x68, 0x70, 0x5f, 0x67, 0x65, 0x6e, + 0x65, 0x72, 0x69, 0x63, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x18, 0x2a, 0x20, + 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x12, 0x70, 0x68, 0x70, 0x47, + 0x65, 0x6e, 0x65, 0x72, 0x69, 0x63, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x12, 0x25, + 0x0a, 0x0a, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, 0x64, 0x18, 0x17, 0x20, 0x01, + 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0a, 0x64, 0x65, 0x70, 0x72, 0x65, + 0x63, 0x61, 0x74, 0x65, 0x64, 0x12, 0x2e, 0x0a, 0x10, 0x63, 
0x63, 0x5f, 0x65, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x5f, 0x61, 0x72, 0x65, 0x6e, 0x61, 0x73, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x08, 0x3a, + 0x04, 0x74, 0x72, 0x75, 0x65, 0x52, 0x0e, 0x63, 0x63, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x41, + 0x72, 0x65, 0x6e, 0x61, 0x73, 0x12, 0x2a, 0x0a, 0x11, 0x6f, 0x62, 0x6a, 0x63, 0x5f, 0x63, 0x6c, + 0x61, 0x73, 0x73, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0f, 0x6f, 0x62, 0x6a, 0x63, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x50, 0x72, 0x65, 0x66, 0x69, + 0x78, 0x12, 0x29, 0x0a, 0x10, 0x63, 0x73, 0x68, 0x61, 0x72, 0x70, 0x5f, 0x6e, 0x61, 0x6d, 0x65, + 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x25, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x63, 0x73, 0x68, + 0x61, 0x72, 0x70, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x21, 0x0a, 0x0c, + 0x73, 0x77, 0x69, 0x66, 0x74, 0x5f, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, 0x27, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0b, 0x73, 0x77, 0x69, 0x66, 0x74, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, + 0x28, 0x0a, 0x10, 0x70, 0x68, 0x70, 0x5f, 0x63, 0x6c, 0x61, 0x73, 0x73, 0x5f, 0x70, 0x72, 0x65, + 0x66, 0x69, 0x78, 0x18, 0x28, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x70, 0x68, 0x70, 0x43, 0x6c, + 0x61, 0x73, 0x73, 0x50, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x23, 0x0a, 0x0d, 0x70, 0x68, 0x70, + 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x29, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0c, 0x70, 0x68, 0x70, 0x4e, 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x12, 0x34, + 0x0a, 0x16, 0x70, 0x68, 0x70, 0x5f, 0x6d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x5f, 0x6e, + 0x61, 0x6d, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x18, 0x2c, 0x20, 0x01, 0x28, 0x09, 0x52, 0x14, + 0x70, 0x68, 0x70, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0x4e, 0x61, 0x6d, 0x65, 0x73, + 0x70, 0x61, 0x63, 0x65, 0x12, 0x21, 0x0a, 0x0c, 0x72, 0x75, 0x62, 0x79, 0x5f, 0x70, 0x61, 0x63, + 0x6b, 0x61, 0x67, 0x65, 0x18, 0x2d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x72, 0x75, 0x62, 0x79, + 
0x50, 0x61, 0x63, 0x6b, 0x61, 0x67, 0x65, 0x12, 0x58, 0x0a, 0x14, 0x75, 0x6e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0xe7, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x55, 0x6e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x13, 0x75, 0x6e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, - 0x6e, 0x2a, 0x09, 0x08, 0xe8, 0x07, 0x10, 0x80, 0x80, 0x80, 0x80, 0x02, 0x4a, 0x04, 0x08, 0x04, - 0x10, 0x05, 0x4a, 0x04, 0x08, 0x05, 0x10, 0x06, 0x4a, 0x04, 0x08, 0x06, 0x10, 0x07, 0x4a, 0x04, - 0x08, 0x08, 0x10, 0x09, 0x4a, 0x04, 0x08, 0x09, 0x10, 0x0a, 0x22, 0xb7, 0x08, 0x0a, 0x0c, 0x46, - 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x41, 0x0a, 0x05, 0x63, - 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x23, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, - 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x43, 0x54, 0x79, 0x70, 0x65, 0x3a, - 0x06, 0x53, 0x54, 0x52, 0x49, 0x4e, 0x47, 0x52, 0x05, 0x63, 0x74, 0x79, 0x70, 0x65, 0x12, 0x16, - 0x0a, 0x06, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, - 0x70, 0x61, 0x63, 0x6b, 0x65, 0x64, 0x12, 0x47, 0x0a, 0x06, 0x6a, 0x73, 0x74, 0x79, 0x70, 0x65, - 0x18, 0x06, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x24, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, - 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x4a, 0x53, 0x54, 0x79, 0x70, 0x65, 0x3a, 0x09, 0x4a, 0x53, - 0x5f, 0x4e, 0x4f, 0x52, 0x4d, 0x41, 0x4c, 0x52, 0x06, 0x6a, 0x73, 0x74, 0x79, 0x70, 0x65, 0x12, - 0x19, 0x0a, 0x04, 0x6c, 0x61, 0x7a, 0x79, 0x18, 
0x05, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, - 0x61, 0x6c, 0x73, 0x65, 0x52, 0x04, 0x6c, 0x61, 0x7a, 0x79, 0x12, 0x2e, 0x0a, 0x0f, 0x75, 0x6e, - 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x65, 0x64, 0x5f, 0x6c, 0x61, 0x7a, 0x79, 0x18, 0x0f, 0x20, - 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0e, 0x75, 0x6e, 0x76, 0x65, - 0x72, 0x69, 0x66, 0x69, 0x65, 0x64, 0x4c, 0x61, 0x7a, 0x79, 0x12, 0x25, 0x0a, 0x0a, 0x64, 0x65, - 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, - 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0a, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, - 0x64, 0x12, 0x19, 0x0a, 0x04, 0x77, 0x65, 0x61, 0x6b, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x3a, - 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x04, 0x77, 0x65, 0x61, 0x6b, 0x12, 0x28, 0x0a, 0x0c, - 0x64, 0x65, 0x62, 0x75, 0x67, 0x5f, 0x72, 0x65, 0x64, 0x61, 0x63, 0x74, 0x18, 0x10, 0x20, 0x01, - 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0b, 0x64, 0x65, 0x62, 0x75, 0x67, - 0x52, 0x65, 0x64, 0x61, 0x63, 0x74, 0x12, 0x4b, 0x0a, 0x09, 0x72, 0x65, 0x74, 0x65, 0x6e, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2d, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, - 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, - 0x65, 0x74, 0x65, 0x6e, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x72, 0x65, 0x74, 0x65, 0x6e, 0x74, - 0x69, 0x6f, 0x6e, 0x12, 0x46, 0x0a, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x12, 0x20, - 0x01, 0x28, 0x0e, 0x32, 0x2e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, - 0x6e, 0x73, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x54, - 0x79, 0x70, 0x65, 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x58, 0x0a, 
0x14, 0x75, + 0x6e, 0x22, 0x3a, 0x0a, 0x0c, 0x4f, 0x70, 0x74, 0x69, 0x6d, 0x69, 0x7a, 0x65, 0x4d, 0x6f, 0x64, + 0x65, 0x12, 0x09, 0x0a, 0x05, 0x53, 0x50, 0x45, 0x45, 0x44, 0x10, 0x01, 0x12, 0x0d, 0x0a, 0x09, + 0x43, 0x4f, 0x44, 0x45, 0x5f, 0x53, 0x49, 0x5a, 0x45, 0x10, 0x02, 0x12, 0x10, 0x0a, 0x0c, 0x4c, + 0x49, 0x54, 0x45, 0x5f, 0x52, 0x55, 0x4e, 0x54, 0x49, 0x4d, 0x45, 0x10, 0x03, 0x2a, 0x09, 0x08, + 0xe8, 0x07, 0x10, 0x80, 0x80, 0x80, 0x80, 0x02, 0x4a, 0x04, 0x08, 0x26, 0x10, 0x27, 0x22, 0xbb, + 0x03, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x12, 0x3c, 0x0a, 0x17, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x5f, 0x73, 0x65, 0x74, + 0x5f, 0x77, 0x69, 0x72, 0x65, 0x5f, 0x66, 0x6f, 0x72, 0x6d, 0x61, 0x74, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x14, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x53, 0x65, 0x74, 0x57, 0x69, 0x72, 0x65, 0x46, 0x6f, 0x72, 0x6d, 0x61, 0x74, 0x12, + 0x4c, 0x0a, 0x1f, 0x6e, 0x6f, 0x5f, 0x73, 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64, 0x5f, 0x64, + 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x6f, 0x72, 0x5f, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x6f, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, + 0x1c, 0x6e, 0x6f, 0x53, 0x74, 0x61, 0x6e, 0x64, 0x61, 0x72, 0x64, 0x44, 0x65, 0x73, 0x63, 0x72, + 0x69, 0x70, 0x74, 0x6f, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f, 0x72, 0x12, 0x25, 0x0a, + 0x0a, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0a, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, + 0x61, 0x74, 0x65, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x6d, 0x61, 0x70, 0x5f, 0x65, 0x6e, 0x74, 0x72, + 0x79, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x6d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, + 0x79, 0x12, 0x56, 0x0a, 0x26, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, 0x64, 0x5f, + 0x6c, 0x65, 0x67, 0x61, 
0x63, 0x79, 0x5f, 0x6a, 0x73, 0x6f, 0x6e, 0x5f, 0x66, 0x69, 0x65, 0x6c, + 0x64, 0x5f, 0x63, 0x6f, 0x6e, 0x66, 0x6c, 0x69, 0x63, 0x74, 0x73, 0x18, 0x0b, 0x20, 0x01, 0x28, + 0x08, 0x42, 0x02, 0x18, 0x01, 0x52, 0x22, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, + 0x64, 0x4c, 0x65, 0x67, 0x61, 0x63, 0x79, 0x4a, 0x73, 0x6f, 0x6e, 0x46, 0x69, 0x65, 0x6c, 0x64, + 0x43, 0x6f, 0x6e, 0x66, 0x6c, 0x69, 0x63, 0x74, 0x73, 0x12, 0x58, 0x0a, 0x14, 0x75, 0x6e, 0x69, + 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0xe7, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x55, 0x6e, 0x69, 0x6e, 0x74, + 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x13, + 0x75, 0x6e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x2a, 0x09, 0x08, 0xe8, 0x07, 0x10, 0x80, 0x80, 0x80, 0x80, 0x02, 0x4a, 0x04, + 0x08, 0x04, 0x10, 0x05, 0x4a, 0x04, 0x08, 0x05, 0x10, 0x06, 0x4a, 0x04, 0x08, 0x06, 0x10, 0x07, + 0x4a, 0x04, 0x08, 0x08, 0x10, 0x09, 0x4a, 0x04, 0x08, 0x09, 0x10, 0x0a, 0x22, 0x85, 0x09, 0x0a, + 0x0c, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x41, 0x0a, + 0x05, 0x63, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x23, 0x2e, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, + 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x43, 0x54, 0x79, 0x70, + 0x65, 0x3a, 0x06, 0x53, 0x54, 0x52, 0x49, 0x4e, 0x47, 0x52, 0x05, 0x63, 0x74, 0x79, 0x70, 0x65, + 0x12, 0x16, 0x0a, 0x06, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x06, 0x70, 0x61, 0x63, 0x6b, 0x65, 0x64, 0x12, 0x47, 0x0a, 0x06, 0x6a, 0x73, 0x74, 0x79, + 0x70, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x24, 
0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, + 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x4a, 0x53, 0x54, 0x79, 0x70, 0x65, 0x3a, 0x09, + 0x4a, 0x53, 0x5f, 0x4e, 0x4f, 0x52, 0x4d, 0x41, 0x4c, 0x52, 0x06, 0x6a, 0x73, 0x74, 0x79, 0x70, + 0x65, 0x12, 0x19, 0x0a, 0x04, 0x6c, 0x61, 0x7a, 0x79, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x3a, + 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x04, 0x6c, 0x61, 0x7a, 0x79, 0x12, 0x2e, 0x0a, 0x0f, + 0x75, 0x6e, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x65, 0x64, 0x5f, 0x6c, 0x61, 0x7a, 0x79, 0x18, + 0x0f, 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0e, 0x75, 0x6e, + 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x65, 0x64, 0x4c, 0x61, 0x7a, 0x79, 0x12, 0x25, 0x0a, 0x0a, + 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, + 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0a, 0x64, 0x65, 0x70, 0x72, 0x65, 0x63, 0x61, + 0x74, 0x65, 0x64, 0x12, 0x19, 0x0a, 0x04, 0x77, 0x65, 0x61, 0x6b, 0x18, 0x0a, 0x20, 0x01, 0x28, + 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x04, 0x77, 0x65, 0x61, 0x6b, 0x12, 0x28, + 0x0a, 0x0c, 0x64, 0x65, 0x62, 0x75, 0x67, 0x5f, 0x72, 0x65, 0x64, 0x61, 0x63, 0x74, 0x18, 0x10, + 0x20, 0x01, 0x28, 0x08, 0x3a, 0x05, 0x66, 0x61, 0x6c, 0x73, 0x65, 0x52, 0x0b, 0x64, 0x65, 0x62, + 0x75, 0x67, 0x52, 0x65, 0x64, 0x61, 0x63, 0x74, 0x12, 0x4b, 0x0a, 0x09, 0x72, 0x65, 0x74, 0x65, + 0x6e, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2d, 0x2e, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, + 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, + 0x6e, 0x52, 0x65, 0x74, 0x65, 0x6e, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x72, 0x65, 0x74, 0x65, + 0x6e, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x4a, 0x0a, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, + 
0x12, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x61, 0x72, 0x67, 0x65, + 0x74, 0x54, 0x79, 0x70, 0x65, 0x42, 0x02, 0x18, 0x01, 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, + 0x74, 0x12, 0x48, 0x0a, 0x07, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x73, 0x18, 0x13, 0x20, 0x03, + 0x28, 0x0e, 0x32, 0x2e, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x2e, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x61, 0x72, 0x67, 0x65, 0x74, 0x54, 0x79, + 0x70, 0x65, 0x52, 0x07, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x73, 0x12, 0x58, 0x0a, 0x14, 0x75, 0x6e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, 0x72, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0xe7, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x24, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x55, 0x6e, 0x69, @@ -3885,98 +4123,103 @@ func file_google_protobuf_descriptor_proto_rawDescGZIP() []byte { return file_google_protobuf_descriptor_proto_rawDescData } -var file_google_protobuf_descriptor_proto_enumTypes = make([]protoimpl.EnumInfo, 9) -var file_google_protobuf_descriptor_proto_msgTypes = make([]protoimpl.MessageInfo, 27) +var file_google_protobuf_descriptor_proto_enumTypes = make([]protoimpl.EnumInfo, 10) +var file_google_protobuf_descriptor_proto_msgTypes = make([]protoimpl.MessageInfo, 28) var file_google_protobuf_descriptor_proto_goTypes = []interface{}{ - (FieldDescriptorProto_Type)(0), // 0: google.protobuf.FieldDescriptorProto.Type - (FieldDescriptorProto_Label)(0), // 1: google.protobuf.FieldDescriptorProto.Label - (FileOptions_OptimizeMode)(0), // 2: google.protobuf.FileOptions.OptimizeMode - 
(FieldOptions_CType)(0), // 3: google.protobuf.FieldOptions.CType - (FieldOptions_JSType)(0), // 4: google.protobuf.FieldOptions.JSType - (FieldOptions_OptionRetention)(0), // 5: google.protobuf.FieldOptions.OptionRetention - (FieldOptions_OptionTargetType)(0), // 6: google.protobuf.FieldOptions.OptionTargetType - (MethodOptions_IdempotencyLevel)(0), // 7: google.protobuf.MethodOptions.IdempotencyLevel - (GeneratedCodeInfo_Annotation_Semantic)(0), // 8: google.protobuf.GeneratedCodeInfo.Annotation.Semantic - (*FileDescriptorSet)(nil), // 9: google.protobuf.FileDescriptorSet - (*FileDescriptorProto)(nil), // 10: google.protobuf.FileDescriptorProto - (*DescriptorProto)(nil), // 11: google.protobuf.DescriptorProto - (*ExtensionRangeOptions)(nil), // 12: google.protobuf.ExtensionRangeOptions - (*FieldDescriptorProto)(nil), // 13: google.protobuf.FieldDescriptorProto - (*OneofDescriptorProto)(nil), // 14: google.protobuf.OneofDescriptorProto - (*EnumDescriptorProto)(nil), // 15: google.protobuf.EnumDescriptorProto - (*EnumValueDescriptorProto)(nil), // 16: google.protobuf.EnumValueDescriptorProto - (*ServiceDescriptorProto)(nil), // 17: google.protobuf.ServiceDescriptorProto - (*MethodDescriptorProto)(nil), // 18: google.protobuf.MethodDescriptorProto - (*FileOptions)(nil), // 19: google.protobuf.FileOptions - (*MessageOptions)(nil), // 20: google.protobuf.MessageOptions - (*FieldOptions)(nil), // 21: google.protobuf.FieldOptions - (*OneofOptions)(nil), // 22: google.protobuf.OneofOptions - (*EnumOptions)(nil), // 23: google.protobuf.EnumOptions - (*EnumValueOptions)(nil), // 24: google.protobuf.EnumValueOptions - (*ServiceOptions)(nil), // 25: google.protobuf.ServiceOptions - (*MethodOptions)(nil), // 26: google.protobuf.MethodOptions - (*UninterpretedOption)(nil), // 27: google.protobuf.UninterpretedOption - (*SourceCodeInfo)(nil), // 28: google.protobuf.SourceCodeInfo - (*GeneratedCodeInfo)(nil), // 29: google.protobuf.GeneratedCodeInfo - 
(*DescriptorProto_ExtensionRange)(nil), // 30: google.protobuf.DescriptorProto.ExtensionRange - (*DescriptorProto_ReservedRange)(nil), // 31: google.protobuf.DescriptorProto.ReservedRange - (*EnumDescriptorProto_EnumReservedRange)(nil), // 32: google.protobuf.EnumDescriptorProto.EnumReservedRange - (*UninterpretedOption_NamePart)(nil), // 33: google.protobuf.UninterpretedOption.NamePart - (*SourceCodeInfo_Location)(nil), // 34: google.protobuf.SourceCodeInfo.Location - (*GeneratedCodeInfo_Annotation)(nil), // 35: google.protobuf.GeneratedCodeInfo.Annotation + (ExtensionRangeOptions_VerificationState)(0), // 0: google.protobuf.ExtensionRangeOptions.VerificationState + (FieldDescriptorProto_Type)(0), // 1: google.protobuf.FieldDescriptorProto.Type + (FieldDescriptorProto_Label)(0), // 2: google.protobuf.FieldDescriptorProto.Label + (FileOptions_OptimizeMode)(0), // 3: google.protobuf.FileOptions.OptimizeMode + (FieldOptions_CType)(0), // 4: google.protobuf.FieldOptions.CType + (FieldOptions_JSType)(0), // 5: google.protobuf.FieldOptions.JSType + (FieldOptions_OptionRetention)(0), // 6: google.protobuf.FieldOptions.OptionRetention + (FieldOptions_OptionTargetType)(0), // 7: google.protobuf.FieldOptions.OptionTargetType + (MethodOptions_IdempotencyLevel)(0), // 8: google.protobuf.MethodOptions.IdempotencyLevel + (GeneratedCodeInfo_Annotation_Semantic)(0), // 9: google.protobuf.GeneratedCodeInfo.Annotation.Semantic + (*FileDescriptorSet)(nil), // 10: google.protobuf.FileDescriptorSet + (*FileDescriptorProto)(nil), // 11: google.protobuf.FileDescriptorProto + (*DescriptorProto)(nil), // 12: google.protobuf.DescriptorProto + (*ExtensionRangeOptions)(nil), // 13: google.protobuf.ExtensionRangeOptions + (*FieldDescriptorProto)(nil), // 14: google.protobuf.FieldDescriptorProto + (*OneofDescriptorProto)(nil), // 15: google.protobuf.OneofDescriptorProto + (*EnumDescriptorProto)(nil), // 16: google.protobuf.EnumDescriptorProto + (*EnumValueDescriptorProto)(nil), // 17: 
google.protobuf.EnumValueDescriptorProto + (*ServiceDescriptorProto)(nil), // 18: google.protobuf.ServiceDescriptorProto + (*MethodDescriptorProto)(nil), // 19: google.protobuf.MethodDescriptorProto + (*FileOptions)(nil), // 20: google.protobuf.FileOptions + (*MessageOptions)(nil), // 21: google.protobuf.MessageOptions + (*FieldOptions)(nil), // 22: google.protobuf.FieldOptions + (*OneofOptions)(nil), // 23: google.protobuf.OneofOptions + (*EnumOptions)(nil), // 24: google.protobuf.EnumOptions + (*EnumValueOptions)(nil), // 25: google.protobuf.EnumValueOptions + (*ServiceOptions)(nil), // 26: google.protobuf.ServiceOptions + (*MethodOptions)(nil), // 27: google.protobuf.MethodOptions + (*UninterpretedOption)(nil), // 28: google.protobuf.UninterpretedOption + (*SourceCodeInfo)(nil), // 29: google.protobuf.SourceCodeInfo + (*GeneratedCodeInfo)(nil), // 30: google.protobuf.GeneratedCodeInfo + (*DescriptorProto_ExtensionRange)(nil), // 31: google.protobuf.DescriptorProto.ExtensionRange + (*DescriptorProto_ReservedRange)(nil), // 32: google.protobuf.DescriptorProto.ReservedRange + (*ExtensionRangeOptions_Declaration)(nil), // 33: google.protobuf.ExtensionRangeOptions.Declaration + (*EnumDescriptorProto_EnumReservedRange)(nil), // 34: google.protobuf.EnumDescriptorProto.EnumReservedRange + (*UninterpretedOption_NamePart)(nil), // 35: google.protobuf.UninterpretedOption.NamePart + (*SourceCodeInfo_Location)(nil), // 36: google.protobuf.SourceCodeInfo.Location + (*GeneratedCodeInfo_Annotation)(nil), // 37: google.protobuf.GeneratedCodeInfo.Annotation } var file_google_protobuf_descriptor_proto_depIdxs = []int32{ - 10, // 0: google.protobuf.FileDescriptorSet.file:type_name -> google.protobuf.FileDescriptorProto - 11, // 1: google.protobuf.FileDescriptorProto.message_type:type_name -> google.protobuf.DescriptorProto - 15, // 2: google.protobuf.FileDescriptorProto.enum_type:type_name -> google.protobuf.EnumDescriptorProto - 17, // 3: 
google.protobuf.FileDescriptorProto.service:type_name -> google.protobuf.ServiceDescriptorProto - 13, // 4: google.protobuf.FileDescriptorProto.extension:type_name -> google.protobuf.FieldDescriptorProto - 19, // 5: google.protobuf.FileDescriptorProto.options:type_name -> google.protobuf.FileOptions - 28, // 6: google.protobuf.FileDescriptorProto.source_code_info:type_name -> google.protobuf.SourceCodeInfo - 13, // 7: google.protobuf.DescriptorProto.field:type_name -> google.protobuf.FieldDescriptorProto - 13, // 8: google.protobuf.DescriptorProto.extension:type_name -> google.protobuf.FieldDescriptorProto - 11, // 9: google.protobuf.DescriptorProto.nested_type:type_name -> google.protobuf.DescriptorProto - 15, // 10: google.protobuf.DescriptorProto.enum_type:type_name -> google.protobuf.EnumDescriptorProto - 30, // 11: google.protobuf.DescriptorProto.extension_range:type_name -> google.protobuf.DescriptorProto.ExtensionRange - 14, // 12: google.protobuf.DescriptorProto.oneof_decl:type_name -> google.protobuf.OneofDescriptorProto - 20, // 13: google.protobuf.DescriptorProto.options:type_name -> google.protobuf.MessageOptions - 31, // 14: google.protobuf.DescriptorProto.reserved_range:type_name -> google.protobuf.DescriptorProto.ReservedRange - 27, // 15: google.protobuf.ExtensionRangeOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption - 1, // 16: google.protobuf.FieldDescriptorProto.label:type_name -> google.protobuf.FieldDescriptorProto.Label - 0, // 17: google.protobuf.FieldDescriptorProto.type:type_name -> google.protobuf.FieldDescriptorProto.Type - 21, // 18: google.protobuf.FieldDescriptorProto.options:type_name -> google.protobuf.FieldOptions - 22, // 19: google.protobuf.OneofDescriptorProto.options:type_name -> google.protobuf.OneofOptions - 16, // 20: google.protobuf.EnumDescriptorProto.value:type_name -> google.protobuf.EnumValueDescriptorProto - 23, // 21: google.protobuf.EnumDescriptorProto.options:type_name -> 
google.protobuf.EnumOptions - 32, // 22: google.protobuf.EnumDescriptorProto.reserved_range:type_name -> google.protobuf.EnumDescriptorProto.EnumReservedRange - 24, // 23: google.protobuf.EnumValueDescriptorProto.options:type_name -> google.protobuf.EnumValueOptions - 18, // 24: google.protobuf.ServiceDescriptorProto.method:type_name -> google.protobuf.MethodDescriptorProto - 25, // 25: google.protobuf.ServiceDescriptorProto.options:type_name -> google.protobuf.ServiceOptions - 26, // 26: google.protobuf.MethodDescriptorProto.options:type_name -> google.protobuf.MethodOptions - 2, // 27: google.protobuf.FileOptions.optimize_for:type_name -> google.protobuf.FileOptions.OptimizeMode - 27, // 28: google.protobuf.FileOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption - 27, // 29: google.protobuf.MessageOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption - 3, // 30: google.protobuf.FieldOptions.ctype:type_name -> google.protobuf.FieldOptions.CType - 4, // 31: google.protobuf.FieldOptions.jstype:type_name -> google.protobuf.FieldOptions.JSType - 5, // 32: google.protobuf.FieldOptions.retention:type_name -> google.protobuf.FieldOptions.OptionRetention - 6, // 33: google.protobuf.FieldOptions.target:type_name -> google.protobuf.FieldOptions.OptionTargetType - 27, // 34: google.protobuf.FieldOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption - 27, // 35: google.protobuf.OneofOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption - 27, // 36: google.protobuf.EnumOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption - 27, // 37: google.protobuf.EnumValueOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption - 27, // 38: google.protobuf.ServiceOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption - 7, // 39: google.protobuf.MethodOptions.idempotency_level:type_name -> 
google.protobuf.MethodOptions.IdempotencyLevel - 27, // 40: google.protobuf.MethodOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption - 33, // 41: google.protobuf.UninterpretedOption.name:type_name -> google.protobuf.UninterpretedOption.NamePart - 34, // 42: google.protobuf.SourceCodeInfo.location:type_name -> google.protobuf.SourceCodeInfo.Location - 35, // 43: google.protobuf.GeneratedCodeInfo.annotation:type_name -> google.protobuf.GeneratedCodeInfo.Annotation - 12, // 44: google.protobuf.DescriptorProto.ExtensionRange.options:type_name -> google.protobuf.ExtensionRangeOptions - 8, // 45: google.protobuf.GeneratedCodeInfo.Annotation.semantic:type_name -> google.protobuf.GeneratedCodeInfo.Annotation.Semantic - 46, // [46:46] is the sub-list for method output_type - 46, // [46:46] is the sub-list for method input_type - 46, // [46:46] is the sub-list for extension type_name - 46, // [46:46] is the sub-list for extension extendee - 0, // [0:46] is the sub-list for field type_name + 11, // 0: google.protobuf.FileDescriptorSet.file:type_name -> google.protobuf.FileDescriptorProto + 12, // 1: google.protobuf.FileDescriptorProto.message_type:type_name -> google.protobuf.DescriptorProto + 16, // 2: google.protobuf.FileDescriptorProto.enum_type:type_name -> google.protobuf.EnumDescriptorProto + 18, // 3: google.protobuf.FileDescriptorProto.service:type_name -> google.protobuf.ServiceDescriptorProto + 14, // 4: google.protobuf.FileDescriptorProto.extension:type_name -> google.protobuf.FieldDescriptorProto + 20, // 5: google.protobuf.FileDescriptorProto.options:type_name -> google.protobuf.FileOptions + 29, // 6: google.protobuf.FileDescriptorProto.source_code_info:type_name -> google.protobuf.SourceCodeInfo + 14, // 7: google.protobuf.DescriptorProto.field:type_name -> google.protobuf.FieldDescriptorProto + 14, // 8: google.protobuf.DescriptorProto.extension:type_name -> google.protobuf.FieldDescriptorProto + 12, // 9: 
google.protobuf.DescriptorProto.nested_type:type_name -> google.protobuf.DescriptorProto + 16, // 10: google.protobuf.DescriptorProto.enum_type:type_name -> google.protobuf.EnumDescriptorProto + 31, // 11: google.protobuf.DescriptorProto.extension_range:type_name -> google.protobuf.DescriptorProto.ExtensionRange + 15, // 12: google.protobuf.DescriptorProto.oneof_decl:type_name -> google.protobuf.OneofDescriptorProto + 21, // 13: google.protobuf.DescriptorProto.options:type_name -> google.protobuf.MessageOptions + 32, // 14: google.protobuf.DescriptorProto.reserved_range:type_name -> google.protobuf.DescriptorProto.ReservedRange + 28, // 15: google.protobuf.ExtensionRangeOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption + 33, // 16: google.protobuf.ExtensionRangeOptions.declaration:type_name -> google.protobuf.ExtensionRangeOptions.Declaration + 0, // 17: google.protobuf.ExtensionRangeOptions.verification:type_name -> google.protobuf.ExtensionRangeOptions.VerificationState + 2, // 18: google.protobuf.FieldDescriptorProto.label:type_name -> google.protobuf.FieldDescriptorProto.Label + 1, // 19: google.protobuf.FieldDescriptorProto.type:type_name -> google.protobuf.FieldDescriptorProto.Type + 22, // 20: google.protobuf.FieldDescriptorProto.options:type_name -> google.protobuf.FieldOptions + 23, // 21: google.protobuf.OneofDescriptorProto.options:type_name -> google.protobuf.OneofOptions + 17, // 22: google.protobuf.EnumDescriptorProto.value:type_name -> google.protobuf.EnumValueDescriptorProto + 24, // 23: google.protobuf.EnumDescriptorProto.options:type_name -> google.protobuf.EnumOptions + 34, // 24: google.protobuf.EnumDescriptorProto.reserved_range:type_name -> google.protobuf.EnumDescriptorProto.EnumReservedRange + 25, // 25: google.protobuf.EnumValueDescriptorProto.options:type_name -> google.protobuf.EnumValueOptions + 19, // 26: google.protobuf.ServiceDescriptorProto.method:type_name -> google.protobuf.MethodDescriptorProto + 26, // 
27: google.protobuf.ServiceDescriptorProto.options:type_name -> google.protobuf.ServiceOptions + 27, // 28: google.protobuf.MethodDescriptorProto.options:type_name -> google.protobuf.MethodOptions + 3, // 29: google.protobuf.FileOptions.optimize_for:type_name -> google.protobuf.FileOptions.OptimizeMode + 28, // 30: google.protobuf.FileOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption + 28, // 31: google.protobuf.MessageOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption + 4, // 32: google.protobuf.FieldOptions.ctype:type_name -> google.protobuf.FieldOptions.CType + 5, // 33: google.protobuf.FieldOptions.jstype:type_name -> google.protobuf.FieldOptions.JSType + 6, // 34: google.protobuf.FieldOptions.retention:type_name -> google.protobuf.FieldOptions.OptionRetention + 7, // 35: google.protobuf.FieldOptions.target:type_name -> google.protobuf.FieldOptions.OptionTargetType + 7, // 36: google.protobuf.FieldOptions.targets:type_name -> google.protobuf.FieldOptions.OptionTargetType + 28, // 37: google.protobuf.FieldOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption + 28, // 38: google.protobuf.OneofOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption + 28, // 39: google.protobuf.EnumOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption + 28, // 40: google.protobuf.EnumValueOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption + 28, // 41: google.protobuf.ServiceOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption + 8, // 42: google.protobuf.MethodOptions.idempotency_level:type_name -> google.protobuf.MethodOptions.IdempotencyLevel + 28, // 43: google.protobuf.MethodOptions.uninterpreted_option:type_name -> google.protobuf.UninterpretedOption + 35, // 44: google.protobuf.UninterpretedOption.name:type_name -> google.protobuf.UninterpretedOption.NamePart + 36, // 45: 
google.protobuf.SourceCodeInfo.location:type_name -> google.protobuf.SourceCodeInfo.Location + 37, // 46: google.protobuf.GeneratedCodeInfo.annotation:type_name -> google.protobuf.GeneratedCodeInfo.Annotation + 13, // 47: google.protobuf.DescriptorProto.ExtensionRange.options:type_name -> google.protobuf.ExtensionRangeOptions + 9, // 48: google.protobuf.GeneratedCodeInfo.Annotation.semantic:type_name -> google.protobuf.GeneratedCodeInfo.Annotation.Semantic + 49, // [49:49] is the sub-list for method output_type + 49, // [49:49] is the sub-list for method input_type + 49, // [49:49] is the sub-list for extension type_name + 49, // [49:49] is the sub-list for extension extendee + 0, // [0:49] is the sub-list for field type_name } func init() { file_google_protobuf_descriptor_proto_init() } @@ -4280,7 +4523,7 @@ func file_google_protobuf_descriptor_proto_init() { } } file_google_protobuf_descriptor_proto_msgTypes[23].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*EnumDescriptorProto_EnumReservedRange); i { + switch v := v.(*ExtensionRangeOptions_Declaration); i { case 0: return &v.state case 1: @@ -4292,7 +4535,7 @@ func file_google_protobuf_descriptor_proto_init() { } } file_google_protobuf_descriptor_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*UninterpretedOption_NamePart); i { + switch v := v.(*EnumDescriptorProto_EnumReservedRange); i { case 0: return &v.state case 1: @@ -4304,7 +4547,7 @@ func file_google_protobuf_descriptor_proto_init() { } } file_google_protobuf_descriptor_proto_msgTypes[25].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SourceCodeInfo_Location); i { + switch v := v.(*UninterpretedOption_NamePart); i { case 0: return &v.state case 1: @@ -4316,6 +4559,18 @@ func file_google_protobuf_descriptor_proto_init() { } } file_google_protobuf_descriptor_proto_msgTypes[26].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SourceCodeInfo_Location); 
i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_google_protobuf_descriptor_proto_msgTypes[27].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*GeneratedCodeInfo_Annotation); i { case 0: return &v.state @@ -4333,8 +4588,8 @@ func file_google_protobuf_descriptor_proto_init() { File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_google_protobuf_descriptor_proto_rawDesc, - NumEnums: 9, - NumMessages: 27, + NumEnums: 10, + NumMessages: 28, NumExtensions: 0, NumServices: 0, }, diff --git a/vendor/google.golang.org/protobuf/types/dynamicpb/types.go b/vendor/google.golang.org/protobuf/types/dynamicpb/types.go new file mode 100644 index 000000000..5a8010f18 --- /dev/null +++ b/vendor/google.golang.org/protobuf/types/dynamicpb/types.go @@ -0,0 +1,177 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package dynamicpb + +import ( + "fmt" + "strings" + "sync" + "sync/atomic" + + "google.golang.org/protobuf/internal/errors" + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/reflect/protoregistry" +) + +type extField struct { + name protoreflect.FullName + number protoreflect.FieldNumber +} + +// A Types is a collection of dynamically constructed descriptors. +// Its methods are safe for concurrent use. +// +// Types implements protoregistry.MessageTypeResolver and protoregistry.ExtensionTypeResolver. +// A Types may be used as a proto.UnmarshalOptions.Resolver. +type Types struct { + files *protoregistry.Files + + extMu sync.Mutex + atomicExtFiles uint64 + extensionsByMessage map[extField]protoreflect.ExtensionDescriptor +} + +// NewTypes creates a new Types registry with the provided files. +// The Files registry is retained, and changes to Files will be reflected in Types. 
+// It is not safe to concurrently change the Files while calling Types methods. +func NewTypes(f *protoregistry.Files) *Types { + return &Types{ + files: f, + } +} + +// FindEnumByName looks up an enum by its full name; +// e.g., "google.protobuf.Field.Kind". +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindEnumByName(name protoreflect.FullName) (protoreflect.EnumType, error) { + d, err := t.files.FindDescriptorByName(name) + if err != nil { + return nil, err + } + ed, ok := d.(protoreflect.EnumDescriptor) + if !ok { + return nil, errors.New("found wrong type: got %v, want enum", descName(d)) + } + return NewEnumType(ed), nil +} + +// FindExtensionByName looks up an extension field by the field's full name. +// Note that this is the full name of the field as determined by +// where the extension is declared and is unrelated to the full name of the +// message being extended. +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindExtensionByName(name protoreflect.FullName) (protoreflect.ExtensionType, error) { + d, err := t.files.FindDescriptorByName(name) + if err != nil { + return nil, err + } + xd, ok := d.(protoreflect.ExtensionDescriptor) + if !ok { + return nil, errors.New("found wrong type: got %v, want extension", descName(d)) + } + return NewExtensionType(xd), nil +} + +// FindExtensionByNumber looks up an extension field by the field number +// within some parent message, identified by full name. +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { + // Construct the extension number map lazily, since not every user will need it. + // Update the map if new files are added to the registry. 
+ if atomic.LoadUint64(&t.atomicExtFiles) != uint64(t.files.NumFiles()) { + t.updateExtensions() + } + xd := t.extensionsByMessage[extField{message, field}] + if xd == nil { + return nil, protoregistry.NotFound + } + return NewExtensionType(xd), nil +} + +// FindMessageByName looks up a message by its full name; +// e.g. "google.protobuf.Any". +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindMessageByName(name protoreflect.FullName) (protoreflect.MessageType, error) { + d, err := t.files.FindDescriptorByName(name) + if err != nil { + return nil, err + } + md, ok := d.(protoreflect.MessageDescriptor) + if !ok { + return nil, errors.New("found wrong type: got %v, want message", descName(d)) + } + return NewMessageType(md), nil +} + +// FindMessageByURL looks up a message by a URL identifier. +// See documentation on google.protobuf.Any.type_url for the URL format. +// +// This returns (nil, protoregistry.NotFound) if not found. +func (t *Types) FindMessageByURL(url string) (protoreflect.MessageType, error) { + // This function is similar to FindMessageByName but + // truncates anything before and including '/' in the URL. 
+ message := protoreflect.FullName(url) + if i := strings.LastIndexByte(url, '/'); i >= 0 { + message = message[i+len("/"):] + } + return t.FindMessageByName(message) +} + +func (t *Types) updateExtensions() { + t.extMu.Lock() + defer t.extMu.Unlock() + if atomic.LoadUint64(&t.atomicExtFiles) == uint64(t.files.NumFiles()) { + return + } + defer atomic.StoreUint64(&t.atomicExtFiles, uint64(t.files.NumFiles())) + t.files.RangeFiles(func(fd protoreflect.FileDescriptor) bool { + t.registerExtensions(fd.Extensions()) + t.registerExtensionsInMessages(fd.Messages()) + return true + }) +} + +func (t *Types) registerExtensionsInMessages(mds protoreflect.MessageDescriptors) { + count := mds.Len() + for i := 0; i < count; i++ { + md := mds.Get(i) + t.registerExtensions(md.Extensions()) + t.registerExtensionsInMessages(md.Messages()) + } +} + +func (t *Types) registerExtensions(xds protoreflect.ExtensionDescriptors) { + count := xds.Len() + for i := 0; i < count; i++ { + xd := xds.Get(i) + field := xd.Number() + message := xd.ContainingMessage().FullName() + if t.extensionsByMessage == nil { + t.extensionsByMessage = make(map[extField]protoreflect.ExtensionDescriptor) + } + t.extensionsByMessage[extField{message, field}] = xd + } +} + +func descName(d protoreflect.Descriptor) string { + switch d.(type) { + case protoreflect.EnumDescriptor: + return "enum" + case protoreflect.EnumValueDescriptor: + return "enum value" + case protoreflect.MessageDescriptor: + return "message" + case protoreflect.ExtensionDescriptor: + return "extension" + case protoreflect.ServiceDescriptor: + return "service" + default: + return fmt.Sprintf("%T", d) + } +} diff --git a/vendor/google.golang.org/protobuf/types/known/anypb/any.pb.go b/vendor/google.golang.org/protobuf/types/known/anypb/any.pb.go index a6c7a33f3..580b232f4 100644 --- a/vendor/google.golang.org/protobuf/types/known/anypb/any.pb.go +++ b/vendor/google.golang.org/protobuf/types/known/anypb/any.pb.go @@ -142,39 +142,39 @@ import ( // 
// Example 2: Pack and unpack a message in Java. // -// Foo foo = ...; -// Any any = Any.pack(foo); -// ... -// if (any.is(Foo.class)) { -// foo = any.unpack(Foo.class); -// } -// // or ... -// if (any.isSameTypeAs(Foo.getDefaultInstance())) { -// foo = any.unpack(Foo.getDefaultInstance()); -// } -// -// Example 3: Pack and unpack a message in Python. -// -// foo = Foo(...) -// any = Any() -// any.Pack(foo) -// ... -// if any.Is(Foo.DESCRIPTOR): -// any.Unpack(foo) -// ... -// -// Example 4: Pack and unpack a message in Go -// -// foo := &pb.Foo{...} -// any, err := anypb.New(foo) -// if err != nil { -// ... -// } -// ... -// foo := &pb.Foo{} -// if err := any.UnmarshalTo(foo); err != nil { -// ... -// } +// Foo foo = ...; +// Any any = Any.pack(foo); +// ... +// if (any.is(Foo.class)) { +// foo = any.unpack(Foo.class); +// } +// // or ... +// if (any.isSameTypeAs(Foo.getDefaultInstance())) { +// foo = any.unpack(Foo.getDefaultInstance()); +// } +// +// Example 3: Pack and unpack a message in Python. +// +// foo = Foo(...) +// any = Any() +// any.Pack(foo) +// ... +// if any.Is(Foo.DESCRIPTOR): +// any.Unpack(foo) +// ... +// +// Example 4: Pack and unpack a message in Go +// +// foo := &pb.Foo{...} +// any, err := anypb.New(foo) +// if err != nil { +// ... +// } +// ... +// foo := &pb.Foo{} +// if err := any.UnmarshalTo(foo); err != nil { +// ... +// } // // The pack methods provided by protobuf library will by default use // 'type.googleapis.com/full.type.name' as the type URL and the unpack @@ -182,8 +182,8 @@ import ( // in the type URL, for example "foo.bar.com/x/y.z" will yield type // name "y.z". // -// # JSON -// +// JSON +// ==== // The JSON representation of an `Any` value uses the regular // representation of the deserialized, embedded message, with an // additional field `@type` which contains the type URL. 
Example: diff --git a/vendor/google.golang.org/protobuf/types/known/structpb/struct.pb.go b/vendor/google.golang.org/protobuf/types/known/structpb/struct.pb.go index 9577ed593..d2bac8b88 100644 --- a/vendor/google.golang.org/protobuf/types/known/structpb/struct.pb.go +++ b/vendor/google.golang.org/protobuf/types/known/structpb/struct.pb.go @@ -132,7 +132,7 @@ import ( // `NullValue` is a singleton enumeration to represent the null value for the // `Value` type union. // -// The JSON representation for `NullValue` is JSON `null`. +// The JSON representation for `NullValue` is JSON `null`. type NullValue int32 const ( diff --git a/vendor/google.golang.org/protobuf/types/known/timestamppb/timestamp.pb.go b/vendor/google.golang.org/protobuf/types/known/timestamppb/timestamp.pb.go index 61f69fc11..81511a336 100644 --- a/vendor/google.golang.org/protobuf/types/known/timestamppb/timestamp.pb.go +++ b/vendor/google.golang.org/protobuf/types/known/timestamppb/timestamp.pb.go @@ -167,7 +167,7 @@ import ( // [`strftime`](https://docs.python.org/2/library/time.html#time.strftime) with // the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use // the Joda Time's [`ISODateTimeFormat.dateTime()`]( -// http://www.joda.org/joda-time/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime%2D%2D +// http://joda-time.sourceforge.net/apidocs/org/joda/time/format/ISODateTimeFormat.html#dateTime() // ) to obtain a formatter capable of generating timestamps in this format. 
type Timestamp struct { state protoimpl.MessageState diff --git a/vendor/modules.txt b/vendor/modules.txt index ead391bb7..4cb0f8cee 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -8,7 +8,7 @@ github.com/stoewer/go-strcase ## explicit; go 1.18 golang.org/x/exp/constraints golang.org/x/exp/slices -# golang.org/x/text v0.8.0 +# golang.org/x/text v0.9.0 ## explicit; go 1.17 golang.org/x/text/feature/plural golang.org/x/text/internal @@ -24,13 +24,13 @@ golang.org/x/text/message golang.org/x/text/message/catalog golang.org/x/text/transform golang.org/x/text/width -# google.golang.org/genproto/googleapis/api v0.0.0-20230525234035-dd9d682886f9 +# google.golang.org/genproto/googleapis/api v0.0.0-20230803162519-f966b187b2e5 ## explicit; go 1.19 google.golang.org/genproto/googleapis/api/expr/v1alpha1 -# google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19 +# google.golang.org/genproto/googleapis/rpc v0.0.0-20230803162519-f966b187b2e5 ## explicit; go 1.19 google.golang.org/genproto/googleapis/rpc/status -# google.golang.org/protobuf v1.30.0 +# google.golang.org/protobuf v1.31.0 ## explicit; go 1.11 google.golang.org/protobuf/encoding/protojson google.golang.org/protobuf/encoding/prototext From 5db3640443d76752fb201c575f0ce06e4c202459 Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 1 Sep 2023 12:25:44 -0700 Subject: [PATCH 31/31] Inlining optimizer (#827) Inliner with support for identifiers and simple select expressions Note: optional field selections are not supported for inlining matches at this time. 
--- cel/inlining.go | 220 +++++++++++++++++ cel/inlining_test.go | 577 +++++++++++++++++++++++++++++++++++++++++++ cel/optimizer.go | 73 ++++++ 3 files changed, 870 insertions(+) create mode 100644 cel/inlining.go create mode 100644 cel/inlining_test.go diff --git a/cel/inlining.go b/cel/inlining.go new file mode 100644 index 000000000..9fc3be278 --- /dev/null +++ b/cel/inlining.go @@ -0,0 +1,220 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel + +import ( + "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/containers" + "github.com/google/cel-go/common/operators" + "github.com/google/cel-go/common/overloads" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/traits" +) + +// InlineVariable holds a variable name to be matched and an AST representing +// the expression graph which should be used to replace it. +type InlineVariable struct { + name string + alias string + def *ast.AST +} + +// Name returns the qualified variable or field selection to replace. +func (v *InlineVariable) Name() string { + return v.name +} + +// Alias returns the alias to use when performing cel.bind() calls during inlining. +func (v *InlineVariable) Alias() string { + return v.alias +} + +// Expr returns the inlined expression value. +func (v *InlineVariable) Expr() ast.Expr { + return v.def.Expr() +} + +// Type indicates the inlined expression type. 
+func (v *InlineVariable) Type() *Type { + return v.def.GetType(v.def.Expr().ID()) +} + +// NewInlineVariable declares a variable name to be replaced by a checked expression. +func NewInlineVariable(name string, definition *Ast) *InlineVariable { + return NewInlineVariableWithAlias(name, name, definition) +} + +// NewInlineVariableWithAlias declares a variable name to be replaced by a checked expression. +// If the variable occurs more than once, the provided alias will be used to replace the expressions +// where the variable name occurs. +func NewInlineVariableWithAlias(name, alias string, definition *Ast) *InlineVariable { + return &InlineVariable{name: name, alias: alias, def: definition.impl} +} + +// NewInliningOptimizer creates an optimizer which replaces variables with expression definitions. +// +// If a variable occurs one time, the variable is replaced by the inline definition. If the +// variable occurs more than once, the variable occurrences are replaced by a cel.bind() call. +func NewInliningOptimizer(inlineVars ...*InlineVariable) ASTOptimizer { + return &inliningOptimizer{variables: inlineVars} +} + +type inliningOptimizer struct { + variables []*InlineVariable +} + +func (opt *inliningOptimizer) Optimize(ctx *OptimizerContext, a *ast.AST) *ast.AST { + root := ast.NavigateAST(a) + for _, inlineVar := range opt.variables { + matches := ast.MatchDescendants(root, opt.matchVariable(inlineVar.Name())) + // Skip cases where the variable isn't in the expression graph + if len(matches) == 0 { + continue + } + + // For a single match, do a direct replacement of the expression sub-graph. 
+ if len(matches) == 1 { + opt.inlineExpr(ctx, matches[0], ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type()) + continue + } + + if !isBindable(matches, inlineVar.Expr(), inlineVar.Type()) { + for _, match := range matches { + opt.inlineExpr(ctx, match, ctx.CopyExpr(inlineVar.Expr()), inlineVar.Type()) + } + continue + } + // For multiple matches, find the least common ancestor (lca) and insert the + // variable as a cel.bind() macro. + var lca ast.NavigableExpr = nil + ancestors := map[int64]bool{} + for _, match := range matches { + // Update the identifier matches with the provided alias. + aliasExpr := ctx.NewIdent(inlineVar.Alias()) + opt.inlineExpr(ctx, match, aliasExpr, inlineVar.Type()) + parent, found := match, true + for found { + _, hasAncestor := ancestors[parent.ID()] + if hasAncestor && (lca == nil || lca.Depth() < parent.Depth()) { + lca = parent + } + ancestors[parent.ID()] = true + parent, found = parent.Parent() + } + } + + // Update the least common ancestor by inserting a cel.bind() call to the alias. + inlined := ctx.NewBindMacro(lca.ID(), inlineVar.Alias(), inlineVar.Expr(), lca) + opt.inlineExpr(ctx, lca, inlined, inlineVar.Type()) + } + return a +} + +// inlineExpr replaces the current expression with the inlined one, unless the location of the inlining +// happens within a presence test, e.g. has(a.b.c) -> inline alpha for a.b.c in which case an attempt is +// made to determine whether the inlined value can be presence or existence tested. 
+func (opt *inliningOptimizer) inlineExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) { + switch prev.Kind() { + case ast.SelectKind: + sel := prev.AsSelect() + if !sel.IsTestOnly() { + prev.SetKindCase(inlined) + return + } + opt.rewritePresenceExpr(ctx, prev, inlined, inlinedType) + default: + prev.SetKindCase(inlined) + } +} + +// rewritePresenceExpr converts the inlined expression, when it occurs within a has() macro, to a type-safe +// expression appropriate for the inlined type, if possible. +// +// If the rewrite is not possible an error is reported at the inline expression site. +func (opt *inliningOptimizer) rewritePresenceExpr(ctx *OptimizerContext, prev, inlined ast.Expr, inlinedType *Type) { + // If the input inlined expression is not a select expression it won't work with the has() + // macro. Attempt to rewrite the presence test in terms of the typed input, otherwise error. + ctx.sourceInfo.ClearMacroCall(prev.ID()) + if inlined.Kind() == ast.SelectKind { + inlinedSel := inlined.AsSelect() + prev.SetKindCase( + ctx.NewPresenceTest(prev.ID(), inlinedSel.Operand(), inlinedSel.FieldName())) + return + } + if inlinedType.IsAssignableType(NullType) { + prev.SetKindCase( + ctx.NewCall(operators.NotEquals, + inlined, + ctx.NewLiteral(types.NullValue), + )) + return + } + if inlinedType.HasTrait(traits.SizerType) { + prev.SetKindCase( + ctx.NewCall(operators.NotEquals, + ctx.NewMemberCall(overloads.Size, inlined), + ctx.NewLiteral(types.IntZero), + )) + return + } + ctx.ReportErrorAtID(prev.ID(), "unable to inline expression type %v into presence test", inlinedType) +} + +// isBindable indicates whether the inlined type can be used within a cel.bind() if the expression +// being replaced occurs within a presence test. Value types with a size() method or field selection +// support can be bound. 
+// +// In future iterations, support may also be added for indexer types which can be rewritten as an `in` +// expression; however, this would imply a rewrite of the inlined expression that may not be necessary +// in most cases. +func isBindable(matches []ast.NavigableExpr, inlined ast.Expr, inlinedType *Type) bool { + if inlinedType.IsAssignableType(NullType) || + inlinedType.HasTrait(traits.SizerType) || + inlinedType.HasTrait(traits.FieldTesterType) { + return true + } + for _, m := range matches { + if m.Kind() != ast.SelectKind { + continue + } + sel := m.AsSelect() + if sel.IsTestOnly() { + return false + } + } + return true +} + +// matchVariable matches simple identifiers, select expressions, and presence test expressions +// which match the (potentially) qualified variable name provided as input. +// +// Note, this function does not support inlining against select expressions which include optional +// field selection. This may be a future refinement. +func (opt *inliningOptimizer) matchVariable(varName string) ast.ExprMatcher { + return func(e ast.NavigableExpr) bool { + if e.Kind() == ast.IdentKind && e.AsIdent() == varName { + return true + } + if e.Kind() == ast.SelectKind { + sel := e.AsSelect() + // While the `ToQualifiedName` call could take the select directly, this + // would skip presence tests from possible matches, which we would like + // to include. + qualName, found := containers.ToQualifiedName(sel.Operand()) + return found && qualName+"."+sel.FieldName() == varName + } + return false + } +} diff --git a/cel/inlining_test.go b/cel/inlining_test.go new file mode 100644 index 000000000..389e46258 --- /dev/null +++ b/cel/inlining_test.go @@ -0,0 +1,577 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cel_test + +import ( + "testing" + + "github.com/google/cel-go/cel" + + proto3pb "github.com/google/cel-go/test/proto3pb" +) + +func TestInliningOptimizer(t *testing.T) { + type varExpr struct { + name string + alias string + t *cel.Type + expr string + } + tests := []struct { + expr string + vars []varExpr + inlined string + folded string + }{ + { + expr: `a || b`, + vars: []varExpr{ + { + name: "a", + t: cel.BoolType, + }, + { + name: "b", + alias: "bravo", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + }, + inlined: `a || "hello".contains("lo")`, + folded: `true`, + }, + { + expr: `a + [a]`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.DynType, + expr: `dyn([1, 2])`, + }, + }, + inlined: `cel.bind(alpha, dyn([1, 2]), alpha + [alpha])`, + folded: `[1, 2, [1, 2]]`, + }, + { + expr: `a && (a || b)`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + { + name: "b", + t: cel.BoolType, + }, + }, + inlined: `cel.bind(alpha, "hello".contains("lo"), alpha && (alpha || b))`, + folded: `true`, + }, + { + expr: `a && b && a`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + { + name: "b", + t: cel.BoolType, + }, + }, + inlined: `cel.bind(alpha, "hello".contains("lo"), alpha && b && alpha)`, + folded: `cel.bind(alpha, true, alpha && b && alpha)`, + }, + { + expr: `(c || d) || (a && (a || b))`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: 
`'hello'.contains('lo')`, + }, + { + name: "b", + t: cel.BoolType, + }, + { + name: "c", + t: cel.BoolType, + }, + { + name: "d", + t: cel.BoolType, + expr: "!false", + }, + }, + inlined: `c || !false || cel.bind(alpha, "hello".contains("lo"), alpha && (alpha || b))`, + folded: `true`, + }, + { + expr: `a && (a || b)`, + vars: []varExpr{ + { + name: "a", + t: cel.BoolType, + }, + { + name: "b", + alias: "bravo", + t: cel.BoolType, + expr: `'hello'.contains('lo')`, + }, + }, + inlined: `a && (a || "hello".contains("lo"))`, + folded: `a`, + }, + { + expr: `a && b`, + vars: []varExpr{ + { + name: "a", + alias: "alpha", + t: cel.BoolType, + expr: `!'hello'.contains('lo')`, + }, + { + name: "b", + alias: "bravo", + t: cel.BoolType, + }, + }, + inlined: `!"hello".contains("lo") && b`, + folded: `false`, + }, + { + expr: `operation.system.consumers + operation.destination_consumers`, + vars: []varExpr{ + { + name: "operation.system", + t: cel.DynType, + }, + { + name: "operation.destination_consumers", + t: cel.ListType(cel.IntType), + expr: `productsToConsumers(operation.destination_products)`, + }, + { + name: "operation.destination_products", + t: cel.ListType(cel.IntType), + expr: `operation.system.products`, + }, + }, + inlined: `operation.system.consumers + productsToConsumers(operation.system.products)`, + folded: `operation.system.consumers + productsToConsumers(operation.system.products)`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + opts := []cel.EnvOption{cel.OptionalTypes(), + cel.EnableMacroCallTracking(), + cel.Function("productsToConsumers", + cel.Overload("productsToConsumers_list", + []*cel.Type{cel.ListType(cel.IntType)}, + cel.ListType(cel.IntType)))} + + varDecls := make([]cel.EnvOption, len(tc.vars)) + for i, v := range tc.vars { + varDecls[i] = cel.Variable(v.name, v.t) + } + e, err := cel.NewEnv(append(varDecls, opts...)...) 
+ if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + inlinedVars := []*cel.InlineVariable{} + for _, v := range tc.vars { + if v.expr == "" { + continue + } + checked, iss := e.Compile(v.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", v.expr, iss.Err()) + } + if v.alias == "" { + inlinedVars = append(inlinedVars, cel.NewInlineVariable(v.name, checked)) + } else { + inlinedVars = append(inlinedVars, cel.NewInlineVariableWithAlias(v.name, v.alias, checked)) + } + } + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + + opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + inlined, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if inlined != tc.inlined { + t.Errorf("got %q, wanted %q", inlined, tc.inlined) + } + folder, err := cel.NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt = cel.NewStaticOptimizer(folder) + optimized, iss = opt.Optimize(e, optimized) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} + +func TestInliningOptimizerMultiStage(t *testing.T) { + type varDecl struct { + name string + t *cel.Type + } + type inlineVarExpr struct { + name string + alias string + t *cel.Type + expr string + } + tests := []struct { + expr string + vars []varDecl + inlineVars []inlineVarExpr + inlined string + folded string + }{ + { + expr: `has(a.b) ? 
a.b : 'default'`, + vars: []varDecl{ + { + name: "a", + t: cel.MapType(cel.StringType, cel.StringType), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "a.b", + alias: "alpha", + t: cel.StringType, + expr: `'hello'`, + }, + }, + inlined: `cel.bind(alpha, "hello", (alpha.size() != 0) ? alpha : "default")`, + folded: `"hello"`, + }, + { + expr: `has(a.b) ? a.b : ['default']`, + vars: []varDecl{ + { + name: "a", + t: cel.MapType(cel.StringType, cel.ListType(cel.StringType)), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "a.b", + alias: "alpha", + t: cel.StringType, + expr: `['hello']`, + }, + }, + inlined: `cel.bind(alpha, ["hello"], (alpha.size() != 0) ? alpha : ["default"])`, + folded: `["hello"]`, + }, + { + expr: `0 in msg.map_int64_nested_type`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + { + name: "nested_map", + t: cel.MapType(cel.IntType, cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes")), + }, + }, + inlineVars: []inlineVarExpr{ + + { + name: "msg.map_int64_nested_type", + t: cel.MapType(cel.IntType, cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes")), + expr: `nested_map`, + }, + }, + inlined: `0 in nested_map`, + folded: `0 in nested_map`, + }, + { + expr: `has(msg.single_any)`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + { + name: "unpacked_wrapper", + t: cel.NullableType(cel.StringType), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_any", + t: cel.NullableType(cel.StringType), + expr: `unpacked_wrapper`, + }, + }, + inlined: `unpacked_wrapper != null`, + folded: `unpacked_wrapper != null`, + }, + { + expr: `has(msg.single_any) ? 
msg.single_any : '10'`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + { + name: "unpacked_wrapper", + t: cel.NullableType(cel.StringType), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_any", + t: cel.NullableType(cel.StringType), + alias: "wrapped", + expr: `unpacked_wrapper`, + }, + }, + inlined: `cel.bind(wrapped, unpacked_wrapper, (wrapped != null) ? wrapped : "10")`, + folded: `cel.bind(wrapped, unpacked_wrapper, (wrapped != null) ? wrapped : "10")`, + }, + { + expr: `has(msg.child.payload.single_int32_wrapper)`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + }, + { + name: "unpacked_child", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.child.payload", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + alias: "payload", + expr: `unpacked_child.payload`, + }, + }, + inlined: `has(unpacked_child.payload.single_int32_wrapper)`, + folded: `has(unpacked_child.payload.single_int32_wrapper)`, + }, + { + expr: `has(msg.child.payload.single_int32_wrapper)`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + }, + { + name: "unpacked_payload", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.child.payload.single_int32_wrapper", + t: cel.NullableType(cel.IntType), + alias: "payload", + expr: `unpacked_payload.single_int32_wrapper`, + }, + }, + inlined: `has(unpacked_payload.single_int32_wrapper)`, + folded: `has(unpacked_payload.single_int32_wrapper)`, + }, + { + expr: `has(msg.child.payload.single_int32_wrapper) ? 
msg.child.payload.single_int32_wrapper : 1`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.NestedTestAllTypes"), + }, + { + name: "unpacked_payload", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.child.payload.single_int32_wrapper", + t: cel.NullableType(cel.IntType), + alias: "nullable_int", + expr: `unpacked_payload.single_int32_wrapper`, + }, + }, + inlined: `cel.bind(nullable_int, unpacked_payload.single_int32_wrapper, (nullable_int != null) ? nullable_int : 1)`, + folded: `cel.bind(nullable_int, unpacked_payload.single_int32_wrapper, (nullable_int != null) ? nullable_int : 1)`, + }, + { + expr: `has(msg.single_value) ? msg.single_value : null`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_value", + t: cel.NullableType(cel.DoubleType), + alias: "nullable_float", + expr: `dyn(1.5)`, + }, + }, + inlined: `cel.bind(nullable_float, dyn(1.5), (nullable_float != null) ? nullable_float : null)`, + folded: `1.5`, + }, + { + expr: `has(msg.single_any) ? msg.single_any : 42`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_any", + t: cel.IntType, + alias: "unpacked_nested", + expr: `google.expr.proto3.test.NestedTestAllTypes{}.payload.single_int32`, + }, + }, + inlined: `has(google.expr.proto3.test.NestedTestAllTypes{}.payload.single_int32) ? google.expr.proto3.test.NestedTestAllTypes{}.payload.single_int32 : 42`, + folded: `42`, + }, + { + expr: `has(msg.single_any.single_int32) ? 
msg.single_any.single_int32 : 42`, + vars: []varDecl{ + { + name: "msg", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + { + name: "unpacked_any", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + }, + }, + inlineVars: []inlineVarExpr{ + { + name: "msg.single_any", + t: cel.ObjectType("google.expr.proto3.test.TestAllTypes"), + alias: "unpacked_any", + expr: `google.expr.proto3.test.NestedTestAllTypes{}.payload`, + }, + }, + inlined: `cel.bind(unpacked_any, google.expr.proto3.test.NestedTestAllTypes{}.payload, has(unpacked_any.single_int32) ? unpacked_any.single_int32 : 42)`, + folded: `42`, + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.expr, func(t *testing.T) { + opts := []cel.EnvOption{ + cel.Container("google.expr"), + cel.Types(&proto3pb.TestAllTypes{}), + cel.OptionalTypes(), + cel.EnableMacroCallTracking()} + + varDecls := make([]cel.EnvOption, len(tc.vars)) + for i, v := range tc.vars { + varDecls[i] = cel.Variable(v.name, v.t) + } + e, err := cel.NewEnv(append(varDecls, opts...)...) 
+ if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + inlinedVars := []*cel.InlineVariable{} + for _, v := range tc.inlineVars { + if v.expr == "" { + continue + } + checked, iss := e.Compile(v.expr) + if iss.Err() != nil { + t.Fatalf("Compile(%q) failed: %v", v.expr, iss.Err()) + } + if v.alias == "" { + inlinedVars = append(inlinedVars, cel.NewInlineVariable(v.name, checked)) + } else { + inlinedVars = append(inlinedVars, cel.NewInlineVariableWithAlias(v.name, v.alias, checked)) + } + } + checked, iss := e.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("Compile() failed: %v", iss.Err()) + } + + opt := cel.NewStaticOptimizer(cel.NewInliningOptimizer(inlinedVars...)) + optimized, iss := opt.Optimize(e, checked) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + inlined, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if inlined != tc.inlined { + t.Errorf("got %q, wanted %q", inlined, tc.inlined) + } + folder, err := cel.NewConstantFoldingOptimizer() + if err != nil { + t.Fatalf("NewConstantFoldingOptimizer() failed: %v", err) + } + opt = cel.NewStaticOptimizer(folder) + optimized, iss = opt.Optimize(e, optimized) + if iss.Err() != nil { + t.Fatalf("Optimize() generated an invalid AST: %v", iss.Err()) + } + folded, err := cel.AstToString(optimized) + if err != nil { + t.Fatalf("cel.AstToString() failed: %v", err) + } + if folded != tc.folded { + t.Errorf("got %q, wanted %q", folded, tc.folded) + } + }) + } +} diff --git a/cel/optimizer.go b/cel/optimizer.go index a2023ed7f..9422a7eb6 100644 --- a/cel/optimizer.go +++ b/cel/optimizer.go @@ -17,6 +17,7 @@ package cel import ( "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" + "github.com/google/cel-go/common/types" "github.com/google/cel-go/common/types/ref" ) @@ -197,6 +198,57 @@ type optimizerExprFactory struct { sourceInfo *ast.SourceInfo } +// CopyExpr copies the structure of 
the input ast.Expr and renumbers the identifiers in a manner +// consistent with the CEL parser / checker. +func (opt *optimizerExprFactory) CopyExpr(e ast.Expr) ast.Expr { + copy := opt.fac.CopyExpr(e) + copy.RenumberIDs(opt.renumberID) + return copy +} + +// NewBindMacro creates a cel.bind() call with a variable name, initialization expression, and remaining expression. +// +// Note: the macroID indicates the insertion point, the call id that matched the macro signature, which will be used +// for coordinating macro metadata with the bind call. This piece of data is what makes it possible to unparse +// optimized expressions which use the bind() call. +// +// Example: +// +// cel.bind(myVar, a && b || c, !myVar || (myVar && d)) +// - varName: myVar +// - varInit: a && b || c +// - remaining: !myVar || (myVar && d) +func (opt *optimizerExprFactory) NewBindMacro(macroID int64, varName string, varInit, remaining ast.Expr) ast.Expr { + bindID := opt.nextID() + varID := opt.nextID() + + varInit = opt.CopyExpr(varInit) + varInit.RenumberIDs(opt.renumberID) + + remaining = opt.fac.CopyExpr(remaining) + remaining.RenumberIDs(opt.renumberID) + + // Place the expanded macro form in the macro calls list so that the inlined + // call can be unparsed. + opt.sourceInfo.SetMacroCall(macroID, + opt.fac.NewMemberCall(0, "bind", + opt.fac.NewIdent(opt.nextID(), "cel"), + opt.fac.NewIdent(varID, varName), + varInit, + remaining)) + + // Replace the parent node with the intercepted inlining using cel.bind()-like + // generated comprehension AST. + return opt.fac.NewComprehension(bindID, + opt.fac.NewList(opt.nextID(), []ast.Expr{}, []int32{}), + "#unused", + varName, + opt.fac.CopyExpr(varInit), + opt.fac.NewLiteral(opt.nextID(), types.False), + opt.fac.NewIdent(varID, varName), + opt.fac.CopyExpr(remaining)) +} + // NewCall creates a global function call invocation expression. 
// // Example: @@ -277,6 +329,27 @@ func (opt *optimizerExprFactory) NewMapEntry(key, value ast.Expr, isOptional boo return opt.fac.NewMapEntry(opt.nextID(), key, value, isOptional) } +// NewPresenceTest creates a new presence test macro call. +// +// Example: +// +// has(msg.field_name) +// - operand: msg +// - field: field_name +func (opt *optimizerExprFactory) NewPresenceTest(macroID int64, operand ast.Expr, field string) ast.Expr { + // Copy the input operand and renumber it. + operand = opt.CopyExpr(operand) + operand.RenumberIDs(opt.renumberID) + + // Place the expanded macro form in the macro calls list so that the inlined call can be unparsed. + opt.sourceInfo.SetMacroCall(macroID, + opt.fac.NewCall(0, "has", + opt.fac.NewSelect(opt.nextID(), operand, field))) + + // Generate a new presence test macro. + return opt.fac.NewPresenceTest(opt.nextID(), opt.CopyExpr(operand), field) +} + // NewSelect creates a select expression where a field value is selected from an operand. // // Example: