diff --git a/cel/env.go b/cel/env.go index f04900c4c..a7aa6db34 100644 --- a/cel/env.go +++ b/cel/env.go @@ -141,6 +141,7 @@ type Env struct { provider types.Provider features map[int]bool appliedFeatures map[int]bool + limits map[limitID]int libraries map[string]SingletonLibrary validators []ASTValidator costOptions []checker.CostOption @@ -288,6 +289,15 @@ func (e *Env) ToConfig(name string) (*env.Config, error) { conf.AddFeatures(env.NewFeature(featName, enabled)) } + for id, val := range e.limits { + limitName, found := limitNameByID(id) + if !found || val == 0 { + // skip if explicitly defaulted or not supported in config + continue + } + conf.AddLimits(env.NewLimit(limitName, val)) + } + // Sort repeated fields in config where reasonable to make the export // stable. slices.SortFunc(conf.Imports, func(a *env.Import, b *env.Import) int { @@ -314,6 +324,10 @@ func (e *Env) ToConfig(name string) (*env.Config, error) { return strings.Compare(a.Name, b.Name) }) + slices.SortFunc(conf.Limits, func(a *env.Limit, b *env.Limit) int { + return strings.Compare(a.Name, b.Name) + }) + return conf, nil } @@ -361,6 +375,7 @@ func NewCustomEnv(opts ...EnvOption) (*Env, error) { provider: registry, features: map[int]bool{}, appliedFeatures: map[int]bool{}, + limits: map[limitID]int{}, libraries: map[string]SingletonLibrary{}, validators: []ASTValidator{}, progOpts: []ProgramOption{}, @@ -525,6 +540,10 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { for k, v := range e.appliedFeatures { appliedFeaturesCopy[k] = v } + limitsCopy := make(map[limitID]int, len(e.limits)) + for k, v := range e.limits { + limitsCopy[k] = v + } funcsCopy := make(map[string]*decls.FunctionDecl, len(e.functions)) for k, v := range e.functions { funcsCopy[k] = v @@ -547,6 +566,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { progOpts: progOptsCopy, adapter: adapter, features: featuresCopy, + limits: limitsCopy, appliedFeatures: appliedFeaturesCopy, libraries: libsCopy, validators: validatorsCopy, @@ -813,6 +833,15 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) { if e.HasFeature(featureIdentEscapeSyntax) { prsrOpts = append(prsrOpts, parser.EnableIdentEscapeSyntax(true)) } + if l := e.limits[limitParseErrorRecovery]; l != 0 { + prsrOpts = append(prsrOpts, parser.ErrorRecoveryLimit(l)) + } + if l := e.limits[limitCodePointSize]; l != 0 { + prsrOpts = append(prsrOpts, parser.ExpressionSizeCodePointLimit(l)) + } + if l := e.limits[limitParseRecursionDepth]; l != 0 { + prsrOpts = append(prsrOpts, parser.MaxRecursionDepth(l)) + } e.prsr, err = parser.NewParser(prsrOpts...) if err != nil { return nil, err diff --git a/cel/env_test.go b/cel/env_test.go index fa568257c..38322caa1 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -734,6 +734,52 @@ func TestEnvFromConfig(t *testing.T) { }, }, }, + { + name: "limits_recursion", + conf: env.NewConfig("limits"). + AddLimits(env.NewLimit("cel.limit.parse_recursion_depth", 3)), + exprs: []exprCase{ + { + name: "under limit", + expr: "1 + 2", + out: types.Int(3), + }, + { + name: "over limit", + expr: "1 + 2 + 3 + 4 + 5", + iss: errors.New("max recursion depth exceeded"), + }, + }, + }, + { + name: "limits_codepoints", + conf: env.NewConfig("limits"). + AddLimits(env.NewLimit("cel.limit.expression_code_points", 10)), + exprs: []exprCase{ + { + name: "under limit", + expr: "'12345'", + out: types.String("12345"), + }, + { + name: "over limit", + expr: "'1234567890'", + iss: errors.New("code point size exceeds limit: size: 12, limit 10"), + }, + }, + }, + { + name: "limits_error_recovery", + conf: env.NewConfig("limits"). + AddLimits(env.NewLimit("cel.limit.parse_error_recovery", 4)), + exprs: []exprCase{ + { + name: "over limit", + expr: "a ? b ((?))", + iss: errors.New("Syntax error: error recovery attempt limit exceeded: 4"), + }, + }, + }, { name: "validators", conf: env.NewConfig("validators"). diff --git a/cel/options.go b/cel/options.go index fee67323c..f36188921 100644 --- a/cel/options.go +++ b/cel/options.go @@ -93,6 +93,40 @@ func featureIDByName(name string) (int, bool) { return 0, false } +// limitID is used as a key for configurable limits. These are options that +// support exporting to YAML environment config. +type limitID int + +const ( + _ = limitID(iota) + // The number of recursive calls permitted in parsing. + limitParseRecursionDepth + // The number of code points permitted in an input expression string. + limitCodePointSize + // The number of attempts to recover from a parse error. + limitParseErrorRecovery +) + +var limitIDsToNames = map[limitID]string{ + limitCodePointSize: "cel.limit.expression_code_points", + limitParseErrorRecovery: "cel.limit.parse_error_recovery", + limitParseRecursionDepth: "cel.limit.parse_recursion_depth", +} + +func limitNameByID(id limitID) (string, bool) { + v, ok := limitIDsToNames[id] + return v, ok +} + +func limitIDByName(name string) (limitID, bool) { + for k, v := range limitIDsToNames { + if v == name { + return k, true + } + } + return limitID(0), false +} + // EnvOption is a functional interface for configuring the environment. type EnvOption func(e *Env) (*Env, error) @@ -564,7 +598,7 @@ func configToEnvOptions(config *env.Config, provider types.Provider, optFactorie envOpts = append(envOpts, FunctionDecls(funcs...)) } - // Configure features + // Configure features and common limits. for _, feat := range config.Features { // Note, if a feature is not found, it is skipped as it is possible the feature // is not intended to be supported publicly. In the future, a refinement of @@ -575,6 +609,12 @@ func configToEnvOptions(config *env.Config, provider types.Provider, optFactorie } } + for _, limit := range config.Limits { + if id, found := limitIDByName(limit.Name); found { + envOpts = append(envOpts, setLimit(id, limit.Value)) + } + } + // Configure validators for _, val := range config.Validators { if fac, found := astValidatorFactories[val.Name]; found { @@ -847,22 +887,32 @@ func features(flag int, enabled bool) EnvOption { } } -// ParserRecursionLimit adjusts the AST depth the parser will tolerate. -// Defaults defined in the parser package. -func ParserRecursionLimit(limit int) EnvOption { +func setLimit(id limitID, limit int) EnvOption { + if limit < 0 { + limit = -1 + } return func(e *Env) (*Env, error) { - e.prsrOpts = append(e.prsrOpts, parser.MaxRecursionDepth(limit)) + e.limits[id] = limit return e, nil } } -// ParserExpressionSizeLimit adjusts the number of code points the expression parser is allowed to parse. +// ParserRecursionLimit adjusts the AST depth the parser will tolerate. +// Defaults defined in the parser package. +func ParserRecursionLimit(limit int) EnvOption { + return setLimit(limitParseRecursionDepth, limit) +} + +// ParserRecursionLimit adjusts the AST depth the parser will tolerate. // Defaults defined in the parser package. +func ParserErrorRecoveryLimit(limit int) EnvOption { + return setLimit(limitParseErrorRecovery, limit) +} + +// ParserExpressionSizeLimit adjusts the number of code points the expression parser is allowed to parse. +// Defaults are defined in the parser package. A negative value means unbounded. func ParserExpressionSizeLimit(limit int) EnvOption { - return func(e *Env) (*Env, error) { - e.prsrOpts = append(e.prsrOpts, parser.ExpressionSizeCodePointLimit(limit)) - return e, nil - } + return setLimit(limitCodePointSize, limit) } // EnableHiddenAccumulatorName sets the parser to use the identifier '@result' for accumulators diff --git a/common/env/env.go b/common/env/env.go index deff85766..7f5732e94 100644 --- a/common/env/env.go +++ b/common/env/env.go @@ -50,6 +50,7 @@ type Config struct { Functions []*Function `yaml:"functions,omitempty"` Validators []*Validator `yaml:"validators,omitempty"` Features []*Feature `yaml:"features,omitempty"` + Limits []*Limit `yaml:"limits,omitempty"` } // Validate validates the whole configuration is well-formed. @@ -92,6 +93,11 @@ func (c *Config) Validate() error { errs = append(errs, err) } } + for _, limit := range c.Limits { + if err := limit.Validate(); err != nil { + errs = append(errs, err) + } + } for _, val := range c.Validators { if err := val.Validate(); err != nil { errs = append(errs, err) @@ -206,6 +212,12 @@ func (c *Config) AddFeatures(feats ...*Feature) *Config { return c } +// AddLimits appends one or more limits to the config. +func (c *Config) AddLimits(limits ...*Limit) *Config { + c.Limits = append(c.Limits, limits...) + return c +} + // NewImport returns a serializable import value from the qualified type name. func NewImport(name string) *Import { return &Import{Name: name} @@ -734,6 +746,25 @@ func (feat *Feature) Validate() error { return nil } +type Limit struct { + Name string `yaml:"name"` + Value int `yaml:"value"` +} + +func NewLimit(name string, value int) *Limit { + return &Limit{name, value} +} + +func (l *Limit) Validate() error { + if l == nil { + return errors.New("invalid limit: nil") + } + if l.Name == "" { + return errors.New("invalid limit: missing name") + } + return nil +} + // NewTypeDesc describes a simple or complex type with parameters. func NewTypeDesc(typeName string, params ...*TypeDesc) *TypeDesc { return &TypeDesc{TypeName: typeName, Params: params} diff --git a/common/env/env_test.go b/common/env/env_test.go index b22a8a444..683ca3cf4 100644 --- a/common/env/env_test.go +++ b/common/env/env_test.go @@ -102,6 +102,8 @@ func TestConfig(t *testing.T) { ), ).AddFeatures( NewFeature("cel.feature.macro_call_tracking", true), + ).AddLimits( + NewLimit("cel.limit.parse_recursion_depth", 7), ).AddValidators( NewValidator("cel.validator.duration"), NewValidator("cel.validator.matches"), @@ -207,6 +209,19 @@ func TestConfig(t *testing.T) { } } } + if len(got.Limits) != len(tc.want.Limits) { + t.Errorf("Limits count got %d, wanted %d", len(got.Limits), len(tc.want.Limits)) + } else { + for i, l := range got.Limits { + wl := tc.want.Limits[i] + if l.Name != wl.Name { + t.Errorf("Limits[%d] got name %s, wanted %s", i, l.Name, wl.Name) + } + if l.Value != wl.Value { + t.Errorf("Limits[%d] got enabled %d, wanted %d", i, l.Value, wl.Value) + } + } + } if len(got.Validators) != len(tc.want.Validators) { t.Errorf("Validators count got %d, wanted %d", len(got.Validators), len(tc.want.Validators)) } else { diff --git a/common/env/testdata/extended_env.yaml b/common/env/testdata/extended_env.yaml index fed70b16c..9796f0853 100644 --- a/common/env/testdata/extended_env.yaml +++ b/common/env/testdata/extended_env.yaml @@ -59,3 +59,6 @@ validators: features: - name: cel.feature.macro_call_tracking enabled: true +limits: + - name: cel.limit.parse_recursion_depth + value: 7