Skip to content

Allow to override builtins #522

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 59 additions & 25 deletions builtin/builtin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,37 +284,71 @@ func TestBuiltin_memory_limits(t *testing.T) {
}
}

func TestBuiltin_disallow_builtins_override(t *testing.T) {
t.Run("via env", func(t *testing.T) {
env := map[string]any{
"len": func() int { return 42 },
"repeat": func(a string) string {
return a
},
func TestBuiltin_allow_builtins_override(t *testing.T) {
t.Run("via env var", func(t *testing.T) {
for _, name := range builtin.Names {
t.Run(name, func(t *testing.T) {
env := map[string]any{
name: "hello world",
}
program, err := expr.Compile(name, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
assert.Equal(t, "hello world", out)
})
}
})
t.Run("via env func", func(t *testing.T) {
for _, name := range builtin.Names {
t.Run(name, func(t *testing.T) {
env := map[string]any{
name: func() int { return 1 },
}
program, err := expr.Compile(fmt.Sprintf("%s()", name), expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
assert.Equal(t, 1, out)
})
}
assert.Panics(t, func() {
_, _ = expr.Compile(`string(len("foo")) + repeat("0", 2)`, expr.Env(env))
})
})
t.Run("via expr.Function", func(t *testing.T) {
length := expr.Function("len",
func(params ...any) (any, error) {
return 42, nil
},
new(func() int),
)
repeat := expr.Function("repeat",
func(params ...any) (any, error) {
return params[0], nil
},
new(func(string) string),
)
assert.Panics(t, func() {
_, _ = expr.Compile(`string(len("foo")) + repeat("0", 2)`, length, repeat)
})
for _, name := range builtin.Names {
t.Run(name, func(t *testing.T) {
fn := expr.Function(name,
func(params ...any) (any, error) {
return 42, nil
},
new(func() int),
)
program, err := expr.Compile(fmt.Sprintf("%s()", name), fn)
require.NoError(t, err)

out, err := expr.Run(program, nil)
require.NoError(t, err)
assert.Equal(t, 42, out)
})
}
})
}

func TestBuiltin_override_and_still_accessible(t *testing.T) {
env := map[string]any{
"len": func() int { return 42 },
"all": []int{1, 2, 3},
}

program, err := expr.Compile(`::all(all, #>0) && len() == 42 && ::len(all) == 3`, expr.Env(env))
require.NoError(t, err)

out, err := expr.Run(program, env)
require.NoError(t, err)
assert.Equal(t, true, out)
}

func TestBuiltin_DisableBuiltin(t *testing.T) {
t.Run("via env", func(t *testing.T) {
for _, b := range builtin.Builtins {
Expand Down
24 changes: 13 additions & 11 deletions checker/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,24 +156,25 @@ func (v *checker) IdentifierNode(node *ast.IdentifierNode) (reflect.Type, info)
if node.Value == "$env" {
return mapType, info{}
}
if fn, ok := v.config.Builtins[node.Value]; ok {
return functionType, info{fn: fn}
}
if fn, ok := v.config.Functions[node.Value]; ok {
return functionType, info{fn: fn}
}
return v.env(node, node.Value, true)
return v.ident(node, node.Value, true, true)
}

// env method returns type of environment variable. env only lookups for
// environment variables, no builtins, no custom functions.
func (v *checker) env(node ast.Node, name string, strict bool) (reflect.Type, info) {
// ident method returns type of environment variable, builtin or function.
func (v *checker) ident(node ast.Node, name string, strict, builtins bool) (reflect.Type, info) {
if t, ok := v.config.Types[name]; ok {
if t.Ambiguous {
return v.error(node, "ambiguous identifier %v", name)
}
return t.Type, info{method: t.Method}
}
if builtins {
if fn, ok := v.config.Functions[name]; ok {
return functionType, info{fn: fn}
}
if fn, ok := v.config.Builtins[name]; ok {
return functionType, info{fn: fn}
}
}
if v.config.Strict && strict {
return v.error(node, "unknown name %v", name)
}
Expand Down Expand Up @@ -433,6 +434,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
base, _ := v.visit(node.Node)
prop, _ := v.visit(node.Property)

// $env variable
if an, ok := node.Node.(*ast.IdentifierNode); ok && an.Value == "$env" {
if name, ok := node.Property.(*ast.StringNode); ok {
strict := v.config.Strict
Expand All @@ -443,7 +445,7 @@ func (v *checker) MemberNode(node *ast.MemberNode) (reflect.Type, info) {
// should throw error if field is not found & v.config.Strict.
strict = false
}
return v.env(node, name.Value, strict)
return v.ident(node, name.Value, strict, false /* no builtins and no functions */)
}
return anyType, info{}
}
Expand Down
22 changes: 8 additions & 14 deletions conf/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,20 +98,14 @@ func (c *Config) Check() {
}
}
}
for fnName, t := range c.Types {
if kind(t.Type) == reflect.Func {
for _, b := range c.Builtins {
if b.Name == fnName {
panic(fmt.Errorf(`cannot override builtin %s(): use expr.DisableBuiltin("%s") to override`, b.Name, b.Name))
}
}
}
}

func (c *Config) IsOverridden(name string) bool {
if _, ok := c.Functions[name]; ok {
return true
}
for _, f := range c.Functions {
for _, b := range c.Builtins {
if b.Name == f.Name {
panic(fmt.Errorf(`cannot override builtin %s(); use expr.DisableBuiltin("%s") to override`, f.Name, f.Name))
}
}
if _, ok := c.Types[name]; ok {
return true
}
return false
}
8 changes: 8 additions & 0 deletions parser/lexer/lexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,14 @@ func TestLex(t *testing.T) {
{Kind: EOF},
},
},
{
`: ::`,
[]Token{
{Kind: Operator, Value: ":"},
{Kind: Operator, Value: "::"},
{Kind: EOF},
},
},
}

for _, test := range tests {
Expand Down
5 changes: 4 additions & 1 deletion parser/lexer/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ func root(l *lexer) stateFn {
case r == '|':
l.accept("|")
l.emit(Operator)
case r == ':':
l.accept(":")
l.emit(Operator)
case strings.ContainsRune("([{", r):
l.emit(Bracket)
case strings.ContainsRune(")]}", r):
l.emit(Bracket)
case strings.ContainsRune(",:;%+-^", r): // single rune operator
case strings.ContainsRune(",;%+-^", r): // single rune operator
l.emit(Operator)
case strings.ContainsRune("&!=*<>", r): // possible double rune operator
l.accept("&=*")
Expand Down
21 changes: 15 additions & 6 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,13 @@ func (p *parser) parsePrimary() Node {
}
}

if token.Is(Operator, "::") {
p.next()
token = p.current
p.expect(Identifier)
return p.parsePostfixExpression(p.parseCall(token, false))
}

return p.parseSecondary()
}

Expand All @@ -300,7 +307,7 @@ func (p *parser) parseSecondary() Node {
node.SetLocation(token.Location)
return node
default:
node = p.parseCall(token)
node = p.parseCall(token, true)
}

case Number:
Expand Down Expand Up @@ -379,15 +386,17 @@ func (p *parser) toFloatNode(number float64) Node {
return &FloatNode{Value: number}
}

func (p *parser) parseCall(token Token) Node {
func (p *parser) parseCall(token Token, checkOverrides bool) Node {
var node Node
if p.current.Is(Bracket, "(") {
var arguments []Node

if b, ok := predicates[token.Value]; ok {
p.expect(Bracket, "(")
isOverridden := p.config.IsOverridden(token.Value)
isOverridden = isOverridden && checkOverrides

// TODO: Refactor parser to use builtin.Builtins instead of predicates map.
// TODO: Refactor parser to use builtin.Builtins instead of predicates map.
if b, ok := predicates[token.Value]; ok && !isOverridden {
p.expect(Bracket, "(")

if b.arity == 1 {
arguments = make([]Node, 1)
Expand Down Expand Up @@ -417,7 +426,7 @@ func (p *parser) parseCall(token Token) Node {
Arguments: arguments,
}
node.SetLocation(token.Location)
} else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] {
} else if _, ok := builtin.Index[token.Value]; ok && !p.config.Disabled[token.Value] && !isOverridden {
node = &BuiltinNode{
Name: token.Value,
Arguments: p.parseArguments(),
Expand Down
23 changes: 23 additions & 0 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,29 @@ world`},
},
},
},
{
`::split("a,b,c", ",")`,
&BuiltinNode{
Name: "split",
Arguments: []Node{
&StringNode{Value: "a,b,c"},
&StringNode{Value: ","},
},
},
},
{
`::split("a,b,c", ",")[0]`,
&MemberNode{
Node: &BuiltinNode{
Name: "split",
Arguments: []Node{
&StringNode{Value: "a,b,c"},
&StringNode{Value: ","},
},
},
Property: &IntegerNode{Value: 0},
},
},
}
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
Expand Down