Skip to content

Commit 4cac5f6

Browse files
feat: Operator overload from Function (#408)
* feat: rewritten Operator overload from Function * fix: add test
1 parent 1719809 commit 4cac5f6

File tree

5 files changed

+246
-17
lines changed

5 files changed

+246
-17
lines changed

checker/checker.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func (v *checker) BinaryNode(node *ast.BinaryNode) (reflect.Type, info) {
240240

241241
// check operator overloading
242242
if fns, ok := v.config.Operators[node.Operator]; ok {
243-
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, l, r)
243+
t, _, ok := conf.FindSuitableOperatorOverload(fns, v.config.Types, v.config.Functions, l, r)
244244
if ok {
245245
return t, info{}
246246
}

conf/config.go

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"github.com/expr-lang/expr/vm/runtime"
1010
)
1111

12+
type FunctionTable map[string]*builtin.Function
13+
1214
type Config struct {
1315
Env any
1416
Types TypesTable
@@ -85,21 +87,43 @@ func (c *Config) ConstExpr(name string) {
8587
func (c *Config) Check() {
8688
for operator, fns := range c.Operators {
8789
for _, fn := range fns {
88-
fnType, ok := c.Types[fn]
89-
if !ok || fnType.Type.Kind() != reflect.Func {
90+
fnType, foundType := c.Types[fn]
91+
fnFunc, foundFunc := c.Functions[fn]
92+
if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) {
9093
panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, operator))
9194
}
92-
requiredNumIn := 2
93-
if fnType.Method {
94-
requiredNumIn = 3 // As first argument of method is receiver.
95+
96+
if foundType {
97+
checkType(fnType, fn, operator)
9598
}
96-
if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
97-
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
99+
if foundFunc {
100+
checkFunc(fnFunc, fn, operator)
98101
}
99102
}
100103
}
101104
}
102105

106+
func checkType(fnType Tag, fn string, operator string) {
107+
requiredNumIn := 2
108+
if fnType.Method {
109+
requiredNumIn = 3 // As first argument of method is receiver.
110+
}
111+
if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
112+
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
113+
}
114+
}
115+
116+
func checkFunc(fn *builtin.Function, name string, operator string) {
117+
if len(fn.Types) == 0 {
118+
panic(fmt.Errorf("function %s for %s operator misses types", name, operator))
119+
}
120+
for _, t := range fn.Types {
121+
if t.NumIn() != 2 || t.NumOut() != 1 {
122+
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", name, operator))
123+
}
124+
}
125+
}
126+
103127
func (c *Config) IsOverridden(name string) bool {
104128
if _, ok := c.Functions[name]; ok {
105129
return true

conf/operators.go

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,65 @@ import (
1010
// Functions should be provided in the environment to allow operator overloading.
1111
type OperatorsTable map[string][]string
1212

13-
func FindSuitableOperatorOverload(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) {
13+
func FindSuitableOperatorOverload(fns []string, types TypesTable, funcs FunctionTable, l, r reflect.Type) (reflect.Type, string, bool) {
14+
t, fn, ok := FindSuitableOperatorOverloadInFunctions(fns, funcs, l, r)
15+
if !ok {
16+
t, fn, ok = FindSuitableOperatorOverloadInTypes(fns, types, l, r)
17+
}
18+
return t, fn, ok
19+
}
20+
21+
func FindSuitableOperatorOverloadInTypes(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) {
1422
for _, fn := range fns {
15-
fnType := types[fn]
23+
fnType, ok := types[fn]
24+
if !ok {
25+
continue
26+
}
1627
firstInIndex := 0
1728
if fnType.Method {
1829
firstInIndex = 1 // As first argument to method is receiver.
1930
}
20-
firstArgType := fnType.Type.In(firstInIndex)
21-
secondArgType := fnType.Type.In(firstInIndex + 1)
31+
ret, done := checkTypeSuits(fnType.Type, l, r, firstInIndex)
32+
if done {
33+
return ret, fn, true
34+
}
35+
}
36+
return nil, "", false
37+
}
2238

23-
firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
24-
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
25-
if firstArgumentFit && secondArgumentFit {
26-
return fnType.Type.Out(0), fn, true
39+
func FindSuitableOperatorOverloadInFunctions(fns []string, funcs FunctionTable, l, r reflect.Type) (reflect.Type, string, bool) {
40+
for _, fn := range fns {
41+
fnType, ok := funcs[fn]
42+
if !ok {
43+
continue
44+
}
45+
firstInIndex := 0
46+
for _, overload := range fnType.Types {
47+
ret, done := checkTypeSuits(overload, l, r, firstInIndex)
48+
if done {
49+
return ret, fn, true
50+
}
2751
}
2852
}
2953
return nil, "", false
3054
}
3155

56+
func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) {
57+
firstArgType := t.In(firstInIndex)
58+
secondArgType := t.In(firstInIndex + 1)
59+
60+
firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
61+
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
62+
if firstArgumentFit && secondArgumentFit {
63+
return t.Out(0), true
64+
}
65+
return nil, false
66+
}
67+
3268
type OperatorPatcher struct {
3369
Operators OperatorsTable
3470
Types TypesTable
71+
Functions FunctionTable
3572
}
3673

3774
func (p *OperatorPatcher) Visit(node *ast.Node) {
@@ -48,7 +85,7 @@ func (p *OperatorPatcher) Visit(node *ast.Node) {
4885
leftType := binaryNode.Left.Type()
4986
rightType := binaryNode.Right.Type()
5087

51-
ret, fn, ok := FindSuitableOperatorOverload(fns, p.Types, leftType, rightType)
88+
ret, fn, ok := FindSuitableOperatorOverload(fns, p.Types, p.Functions, leftType, rightType)
5289
if ok {
5390
newNode := &ast.CallNode{
5491
Callee: &ast.IdentifierNode{Value: fn},

expr.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ func Compile(input string, ops ...Option) (*vm.Program, error) {
192192
config.Visitors = append(config.Visitors, &conf.OperatorPatcher{
193193
Operators: config.Operators,
194194
Types: config.Types,
195+
Functions: config.Functions,
195196
})
196197
}
197198

test/operator/operator_test.go

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package operator_test
22

33
import (
4+
"fmt"
45
"testing"
56
"time"
67

@@ -55,3 +56,169 @@ func TestOperator_interface(t *testing.T) {
5556
require.NoError(t, err)
5657
require.Equal(t, true, output)
5758
}
59+
60+
type Value struct {
61+
Int int
62+
}
63+
64+
func TestOperator_Function(t *testing.T) {
65+
env := map[string]interface{}{
66+
"foo": Value{1},
67+
"bar": Value{2},
68+
}
69+
70+
tests := []struct {
71+
input string
72+
want int
73+
}{
74+
{
75+
input: `foo + bar`,
76+
want: 3,
77+
},
78+
{
79+
input: `2 + 4`,
80+
want: 6,
81+
},
82+
}
83+
84+
for _, tt := range tests {
85+
t.Run(fmt.Sprintf(`opertor function helper test %s`, tt.input), func(t *testing.T) {
86+
program, err := expr.Compile(
87+
tt.input,
88+
expr.Env(env),
89+
expr.Operator("+", "Add", "AddInt"),
90+
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
91+
return args[0].(Value).Int + args[1].(Value).Int, nil
92+
},
93+
new(func(_ Value, __ Value) int),
94+
),
95+
expr.Function("AddInt", func(args ...interface{}) (interface{}, error) {
96+
return args[0].(int) + args[1].(int), nil
97+
},
98+
new(func(_ int, __ int) int),
99+
),
100+
)
101+
require.NoError(t, err)
102+
103+
output, err := expr.Run(program, env)
104+
require.NoError(t, err)
105+
require.Equal(t, tt.want, output)
106+
})
107+
}
108+
109+
}
110+
111+
func TestOperator_Function_WithTypes(t *testing.T) {
112+
env := map[string]interface{}{
113+
"foo": Value{1},
114+
"bar": Value{2},
115+
}
116+
117+
require.PanicsWithError(t, `function Add for + operator misses types`, func() {
118+
_, _ = expr.Compile(
119+
`foo + bar`,
120+
expr.Env(env),
121+
expr.Operator("+", "Add", "AddInt"),
122+
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
123+
return args[0].(Value).Int + args[1].(Value).Int, nil
124+
}),
125+
)
126+
})
127+
128+
require.PanicsWithError(t, `function Add for + operator does not have a correct signature`, func() {
129+
_, _ = expr.Compile(
130+
`foo + bar`,
131+
expr.Env(env),
132+
expr.Operator("+", "Add", "AddInt"),
133+
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
134+
return args[0].(Value).Int + args[1].(Value).Int, nil
135+
},
136+
new(func(_ Value) int),
137+
),
138+
)
139+
})
140+
141+
}
142+
143+
func TestOperator_FunctionOverTypesPrecedence(t *testing.T) {
144+
env := struct {
145+
Add func(a, b int) int
146+
}{
147+
Add: func(a, b int) int {
148+
return a + b
149+
},
150+
}
151+
152+
program, err := expr.Compile(
153+
`1 + 2`,
154+
expr.Env(env),
155+
expr.Operator("+", "Add"),
156+
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
157+
// Wierd function that returns 100 + a + b in testing purposes.
158+
return args[0].(int) + args[1].(int) + 100, nil
159+
},
160+
new(func(_ int, __ int) int),
161+
),
162+
)
163+
require.NoError(t, err)
164+
165+
output, err := expr.Run(program, env)
166+
require.NoError(t, err)
167+
require.Equal(t, 103, output)
168+
}
169+
170+
func TestOperator_CanBeDefinedEitherInTypesOrInFunctions(t *testing.T) {
171+
env := struct {
172+
Add func(a, b int) int
173+
}{
174+
Add: func(a, b int) int {
175+
return a + b
176+
},
177+
}
178+
179+
program, err := expr.Compile(
180+
`1 + 2`,
181+
expr.Env(env),
182+
expr.Operator("+", "Add", "AddValues"),
183+
expr.Function("AddValues", func(args ...interface{}) (interface{}, error) {
184+
return args[0].(Value).Int + args[1].(Value).Int, nil
185+
},
186+
new(func(_ Value, __ Value) int),
187+
),
188+
)
189+
require.NoError(t, err)
190+
191+
output, err := expr.Run(program, env)
192+
require.NoError(t, err)
193+
require.Equal(t, 3, output)
194+
}
195+
196+
func TestOperator_Polymorphic(t *testing.T) {
197+
env := struct {
198+
Add func(a, b int) int
199+
Foo Value
200+
Bar Value
201+
}{
202+
Add: func(a, b int) int {
203+
return a + b
204+
},
205+
Foo: Value{1},
206+
Bar: Value{2},
207+
}
208+
209+
program, err := expr.Compile(
210+
`1 + 2 + (Foo + Bar)`,
211+
expr.Env(env),
212+
expr.Operator("+", "Add", "AddValues"),
213+
expr.Function("AddValues", func(args ...interface{}) (interface{}, error) {
214+
return args[0].(Value).Int + args[1].(Value).Int, nil
215+
},
216+
new(func(_ Value, __ Value) int),
217+
),
218+
)
219+
require.NoError(t, err)
220+
221+
output, err := expr.Run(program, env)
222+
require.NoError(t, err)
223+
require.Equal(t, 6, output)
224+
}

0 commit comments

Comments
 (0)