Skip to content

Commit e213a66

Browse files
authored
Fix context patcher (#526)
1 parent 523a091 commit e213a66

File tree

6 files changed

+57
-15
lines changed

6 files changed

+57
-15
lines changed

builtin/builtin.go

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,6 @@ import (
1212
"github.com/expr-lang/expr/vm/runtime"
1313
)
1414

15-
type Function struct {
16-
Name string
17-
Func func(args ...any) (any, error)
18-
Fast func(arg any) any
19-
ValidateArgs func(args ...any) (any, error)
20-
Types []reflect.Type
21-
Validate func(args []reflect.Type) (reflect.Type, error)
22-
Predicate bool
23-
}
24-
2515
var (
2616
Index map[string]int
2717
Names []string

builtin/function.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package builtin
2+
3+
import (
4+
"reflect"
5+
)
6+
7+
type Function struct {
8+
Name string
9+
Func func(args ...any) (any, error)
10+
Fast func(arg any) any
11+
ValidateArgs func(args ...any) (any, error)
12+
Types []reflect.Type
13+
Validate func(args []reflect.Type) (reflect.Type, error)
14+
Predicate bool
15+
}
16+
17+
func (f *Function) Type() reflect.Type {
18+
if len(f.Types) > 0 {
19+
return f.Types[0]
20+
}
21+
return reflect.TypeOf(f.Func)
22+
}

checker/checker.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,10 +169,10 @@ func (v *checker) ident(node ast.Node, name string, strict, builtins bool) (refl
169169
}
170170
if builtins {
171171
if fn, ok := v.config.Functions[name]; ok {
172-
return functionType, info{fn: fn}
172+
return fn.Type(), info{fn: fn}
173173
}
174174
if fn, ok := v.config.Builtins[name]; ok {
175-
return functionType, info{fn: fn}
175+
return fn.Type(), info{fn: fn}
176176
}
177177
}
178178
if v.config.Strict && strict {
@@ -833,7 +833,7 @@ func (v *checker) checkFunction(f *builtin.Function, node ast.Node, arguments []
833833
}
834834
return t, info{}
835835
} else if len(f.Types) == 0 {
836-
t, err := v.checkArguments(f.Name, functionType, false, arguments, node)
836+
t, err := v.checkArguments(f.Name, f.Type(), false, arguments, node)
837837
if err != nil {
838838
if v.err == nil {
839839
v.err = err

checker/checker_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ func TestCheck_builtin_without_call(t *testing.T) {
994994
err string
995995
}{
996996
{`len + 1`, "invalid operation: + (mismatched types func(...interface {}) (interface {}, error) and int) (1:5)\n | len + 1\n | ....^"},
997-
{`string.A`, "type func(...interface {}) (interface {}, error)[string] is undefined (1:8)\n | string.A\n | .......^"},
997+
{`string.A`, "type func(interface {}) string[string] is undefined (1:8)\n | string.A\n | .......^"},
998998
}
999999

10001000
for _, test := range tests {

checker/types.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ var (
1818
anyType = reflect.TypeOf(new(any)).Elem()
1919
timeType = reflect.TypeOf(time.Time{})
2020
durationType = reflect.TypeOf(time.Duration(0))
21-
functionType = reflect.TypeOf(new(func(...any) (any, error))).Elem()
2221
)
2322

2423
func combined(a, b reflect.Type) reflect.Type {

patcher/with_context_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,34 @@ func TestWithContext(t *testing.T) {
3030
require.NoError(t, err)
3131
require.Equal(t, 42, output)
3232
}
33+
34+
func TestWithContext_with_env_Function(t *testing.T) {
35+
env := map[string]any{
36+
"ctx": context.TODO(),
37+
}
38+
39+
fn := expr.Function("fn",
40+
func(params ...any) (any, error) {
41+
ctx := params[0].(context.Context)
42+
a := params[1].(int)
43+
44+
return ctx.Value("value").(int) + a, nil
45+
},
46+
new(func(context.Context, int) int),
47+
)
48+
49+
program, err := expr.Compile(
50+
`fn(40)`,
51+
expr.Env(env),
52+
expr.WithContext("ctx"),
53+
fn,
54+
)
55+
require.NoError(t, err)
56+
57+
ctx := context.WithValue(context.Background(), "value", 2)
58+
env["ctx"] = ctx
59+
60+
output, err := expr.Run(program, env)
61+
require.NoError(t, err)
62+
require.Equal(t, 42, output)
63+
}

0 commit comments

Comments
 (0)