Skip to content

Commit 71e88c2

Browse files
feat: use types provided in Function to discriminate proper operator overload
1 parent fc691e3 commit 71e88c2

File tree

4 files changed

+65
-28
lines changed

4 files changed

+65
-28
lines changed

conf/operators.go

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,40 @@ type OperatorsTable map[string][]string
1313
func FindSuitableOperatorOverload(fns []string, types TypesTable, l, r reflect.Type) (reflect.Type, string, bool) {
1414
for _, fn := range fns {
1515
fnType := types[fn]
16+
1617
firstInIndex := 0
1718
if fnType.Method {
1819
firstInIndex = 1 // As first argument to method is receiver.
1920
}
20-
var firstArgType reflect.Type
21-
var secondArgType reflect.Type
22-
23-
if fnType.Type.NumIn() == 1 && fnType.Type.In(0).Kind() == reflect.Slice {
24-
firstArgType = fnType.Type.In(0).Elem()
25-
secondArgType = fnType.Type.In(0).Elem()
21+
if fnType.Overloads != nil {
22+
for _, overload := range fnType.Overloads {
23+
ret, done := checkType(overload, l, r, firstInIndex)
24+
if done {
25+
return ret, fn, true
26+
}
27+
}
2628
} else {
27-
firstArgType = fnType.Type.In(firstInIndex)
28-
secondArgType = fnType.Type.In(firstInIndex + 1)
29-
}
30-
31-
firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
32-
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
33-
if firstArgumentFit && secondArgumentFit {
34-
return fnType.Type.Out(0), fn, true
29+
ret, done := checkType(fnType.Type, l, r, firstInIndex)
30+
if done {
31+
return ret, fn, true
32+
}
3533
}
3634
}
3735
return nil, "", false
3836
}
3937

38+
func checkType(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) {
39+
firstArgType := t.In(firstInIndex)
40+
secondArgType := t.In(firstInIndex + 1)
41+
42+
firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
43+
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
44+
if firstArgumentFit && secondArgumentFit {
45+
return t.Out(0), true
46+
}
47+
return nil, false
48+
}
49+
4050
type OperatorPatcher struct {
4151
Operators OperatorsTable
4252
Types TypesTable

conf/types_table.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ type Tag struct {
1010
FieldIndex []int
1111
Method bool
1212
MethodIndex int
13+
Overloads []reflect.Type
1314
}
1415

1516
type TypesTable map[string]Tag

expr.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ func Function(name string, fn func(params ...interface{}) (interface{}, error),
128128
Types: ts,
129129
}
130130
c.Types[name] = conf.Tag{
131-
Type: reflect.TypeOf(fn),
131+
Type: reflect.TypeOf(fn),
132+
Overloads: ts,
132133
}
133134
}
134135
}

test/operator/operator_test.go

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package operator_test
22

33
import (
4+
"fmt"
45
"testing"
56
"time"
67

@@ -65,19 +66,43 @@ func TestOperator_Function(t *testing.T) {
6566
"bar": Value{2},
6667
}
6768

68-
program, err := expr.Compile(
69-
`foo + bar`,
70-
expr.Env(env),
71-
expr.Operator("+", "Add"),
72-
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
73-
return args[0].(Value).Int + args[1].(Value).Int, nil
69+
tests := []struct {
70+
input string
71+
want int
72+
}{
73+
{
74+
input: `foo + bar`,
75+
want: 3,
7476
},
75-
new(func(_ Value, __ Value) int),
76-
),
77-
)
78-
require.NoError(t, err)
77+
{
78+
input: `2 + 4`,
79+
want: 6,
80+
},
81+
}
82+
83+
for _, tt := range tests {
84+
t.Run(fmt.Sprintf(`opertor function helper test %s`, tt.input), func(t *testing.T) {
85+
program, err := expr.Compile(
86+
tt.input,
87+
expr.Env(env),
88+
expr.Operator("+", "Add", "AddInt"),
89+
expr.Function("Add", func(args ...interface{}) (interface{}, error) {
90+
return args[0].(Value).Int + args[1].(Value).Int, nil
91+
},
92+
new(func(_ Value, __ Value) int),
93+
),
94+
expr.Function("AddInt", func(args ...interface{}) (interface{}, error) {
95+
return args[0].(int) + args[1].(int), nil
96+
},
97+
new(func(_ int, __ int) int),
98+
),
99+
)
100+
require.NoError(t, err)
101+
102+
output, err := expr.Run(program, env)
103+
require.NoError(t, err)
104+
require.Equal(t, tt.want, output)
105+
})
106+
}
79107

80-
output, err := expr.Run(program, env)
81-
require.NoError(t, err)
82-
require.Equal(t, 3, output)
83108
}

0 commit comments

Comments
 (0)