Skip to content

Commit 62cdd42

Browse files
committed
Refactor operator override to be purely patcher pkg
1 parent 95084fb commit 62cdd42

File tree

6 files changed

+122
-131
lines changed

6 files changed

+122
-131
lines changed

conf/config.go

Lines changed: 9 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,28 @@ import (
99
"github.com/expr-lang/expr/vm/runtime"
1010
)
1111

12-
type FunctionTable map[string]*builtin.Function
13-
14-
// OperatorsTable maps binary operators to corresponding list of functions.
15-
// Functions should be provided in the environment to allow operator overloading.
16-
type OperatorsTable map[string][]string
12+
type FunctionsTable map[string]*builtin.Function
1713

1814
type Config struct {
1915
Env any
2016
Types TypesTable
2117
MapEnv bool
2218
DefaultType reflect.Type
23-
Operators OperatorsTable
2419
Expect reflect.Kind
2520
ExpectAny bool
2621
Optimize bool
2722
Strict bool
2823
ConstFns map[string]reflect.Value
2924
Visitors []ast.Visitor
30-
Functions map[string]*builtin.Function
31-
Builtins map[string]*builtin.Function
25+
Functions FunctionsTable
26+
Builtins FunctionsTable
3227
Disabled map[string]bool // disabled builtins
3328
}
3429

3530
// CreateNew creates new config with default values.
3631
func CreateNew() *Config {
3732
c := &Config{
3833
Optimize: true,
39-
Operators: make(map[string][]string),
4034
ConstFns: make(map[string]reflect.Value),
4135
Functions: make(map[string]*builtin.Function),
4236
Builtins: make(map[string]*builtin.Function),
@@ -73,10 +67,6 @@ func (c *Config) WithEnv(env any) {
7367
c.Strict = true
7468
}
7569

76-
func (c *Config) Operator(operator string, fns ...string) {
77-
c.Operators[operator] = append(c.Operators[operator], fns...)
78-
}
79-
8070
func (c *Config) ConstExpr(name string) {
8171
if c.Env == nil {
8272
panic("no environment is specified for ConstExpr()")
@@ -88,42 +78,14 @@ func (c *Config) ConstExpr(name string) {
8878
c.ConstFns[name] = fn
8979
}
9080

91-
func (c *Config) Check() {
92-
for operator, fns := range c.Operators {
93-
for _, fn := range fns {
94-
fnType, foundType := c.Types[fn]
95-
fnFunc, foundFunc := c.Functions[fn]
96-
if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) {
97-
panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, operator))
98-
}
99-
100-
if foundType {
101-
checkType(fnType, fn, operator)
102-
}
103-
if foundFunc {
104-
checkFunc(fnFunc, fn, operator)
105-
}
106-
}
107-
}
81+
type Checker interface {
82+
Check()
10883
}
10984

110-
func checkType(fnType Tag, fn string, operator string) {
111-
requiredNumIn := 2
112-
if fnType.Method {
113-
requiredNumIn = 3 // As first argument of method is receiver.
114-
}
115-
if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
116-
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
117-
}
118-
}
119-
120-
func checkFunc(fn *builtin.Function, name string, operator string) {
121-
if len(fn.Types) == 0 {
122-
panic(fmt.Errorf("function %s for %s operator misses types", name, operator))
123-
}
124-
for _, t := range fn.Types {
125-
if t.NumIn() != 2 || t.NumOut() != 1 {
126-
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", name, operator))
85+
func (c *Config) Check() {
86+
for _, v := range c.Visitors {
87+
if c, ok := v.(Checker); ok {
88+
c.Check()
12789
}
12890
}
12991
}

conf/functions.go

Lines changed: 0 additions & 1 deletion
This file was deleted.

conf/operators.go

Lines changed: 0 additions & 60 deletions
This file was deleted.

expr.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,13 @@ func AllowUndefinedVariables() Option {
4242
// Operator allows to replace a binary operator with a function.
4343
func Operator(operator string, fn ...string) Option {
4444
return func(c *conf.Config) {
45-
c.Operator(operator, fn...)
45+
p := &patcher.OperatorOverride{
46+
Operator: operator,
47+
Overrides: fn,
48+
Types: c.Types,
49+
Functions: c.Functions,
50+
}
51+
c.Visitors = append(c.Visitors, p)
4652
}
4753
}
4854

@@ -188,14 +194,6 @@ func Compile(input string, ops ...Option) (*vm.Program, error) {
188194
}
189195
config.Check()
190196

191-
if len(config.Operators) > 0 {
192-
config.Visitors = append(config.Visitors, &patcher.Operator{
193-
Operators: config.Operators,
194-
Types: config.Types,
195-
Functions: config.Functions,
196-
})
197-
}
198-
199197
tree, err := parser.ParseWithConfig(input, config)
200198
if err != nil {
201199
return nil, err

patcher/operator_override.go

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,35 @@
11
package patcher
22

33
import (
4+
"fmt"
5+
"reflect"
6+
47
"github.com/expr-lang/expr/ast"
8+
"github.com/expr-lang/expr/builtin"
59
"github.com/expr-lang/expr/conf"
610
)
711

8-
type Operator struct {
9-
Operators conf.OperatorsTable
10-
Types conf.TypesTable
11-
Functions conf.FunctionTable
12+
type OperatorOverride struct {
13+
Operator string // Operator token to override.
14+
Overrides []string // List of function names to override operator with.
15+
Types conf.TypesTable // Env types.
16+
Functions conf.FunctionsTable // Env functions.
1217
}
1318

14-
func (p *Operator) Visit(node *ast.Node) {
19+
func (p *OperatorOverride) Visit(node *ast.Node) {
1520
binaryNode, ok := (*node).(*ast.BinaryNode)
1621
if !ok {
1722
return
1823
}
1924

20-
fns, ok := p.Operators[binaryNode.Operator]
21-
if !ok {
25+
if binaryNode.Operator != p.Operator {
2226
return
2327
}
2428

2529
leftType := binaryNode.Left.Type()
2630
rightType := binaryNode.Right.Type()
2731

28-
ret, fn, ok := conf.FindSuitableOperatorOverload(fns, p.Types, p.Functions, leftType, rightType)
32+
ret, fn, ok := p.FindSuitableOperatorOverload(leftType, rightType)
2933
if ok {
3034
newNode := &ast.CallNode{
3135
Callee: &ast.IdentifierNode{Value: fn},
@@ -35,3 +39,97 @@ func (p *Operator) Visit(node *ast.Node) {
3539
ast.Patch(node, newNode)
3640
}
3741
}
42+
43+
func (p *OperatorOverride) FindSuitableOperatorOverload(l, r reflect.Type) (reflect.Type, string, bool) {
44+
t, fn, ok := p.findSuitableOperatorOverloadInFunctions(l, r)
45+
if !ok {
46+
t, fn, ok = p.findSuitableOperatorOverloadInTypes(l, r)
47+
}
48+
return t, fn, ok
49+
}
50+
51+
func (p *OperatorOverride) findSuitableOperatorOverloadInTypes(l, r reflect.Type) (reflect.Type, string, bool) {
52+
for _, fn := range p.Overrides {
53+
fnType, ok := p.Types[fn]
54+
if !ok {
55+
continue
56+
}
57+
firstInIndex := 0
58+
if fnType.Method {
59+
firstInIndex = 1 // As first argument to method is receiver.
60+
}
61+
ret, done := checkTypeSuits(fnType.Type, l, r, firstInIndex)
62+
if done {
63+
return ret, fn, true
64+
}
65+
}
66+
return nil, "", false
67+
}
68+
69+
func (p *OperatorOverride) findSuitableOperatorOverloadInFunctions(l, r reflect.Type) (reflect.Type, string, bool) {
70+
for _, fn := range p.Overrides {
71+
fnType, ok := p.Functions[fn]
72+
if !ok {
73+
continue
74+
}
75+
firstInIndex := 0
76+
for _, overload := range fnType.Types {
77+
ret, done := checkTypeSuits(overload, l, r, firstInIndex)
78+
if done {
79+
return ret, fn, true
80+
}
81+
}
82+
}
83+
return nil, "", false
84+
}
85+
86+
func checkTypeSuits(t reflect.Type, l reflect.Type, r reflect.Type, firstInIndex int) (reflect.Type, bool) {
87+
firstArgType := t.In(firstInIndex)
88+
secondArgType := t.In(firstInIndex + 1)
89+
90+
firstArgumentFit := l == firstArgType || (firstArgType.Kind() == reflect.Interface && (l == nil || l.Implements(firstArgType)))
91+
secondArgumentFit := r == secondArgType || (secondArgType.Kind() == reflect.Interface && (r == nil || r.Implements(secondArgType)))
92+
if firstArgumentFit && secondArgumentFit {
93+
return t.Out(0), true
94+
}
95+
return nil, false
96+
}
97+
98+
func (p *OperatorOverride) Check() {
99+
for _, fn := range p.Overrides {
100+
fnType, foundType := p.Types[fn]
101+
fnFunc, foundFunc := p.Functions[fn]
102+
if !foundFunc && (!foundType || fnType.Type.Kind() != reflect.Func) {
103+
panic(fmt.Errorf("function %s for %s operator does not exist in the environment", fn, p.Operator))
104+
}
105+
106+
if foundType {
107+
checkType(fnType, fn, p.Operator)
108+
}
109+
110+
if foundFunc {
111+
checkFunc(fnFunc, fn, p.Operator)
112+
}
113+
}
114+
}
115+
116+
func checkType(fnType conf.Tag, fn string, operator string) {
117+
requiredNumIn := 2
118+
if fnType.Method {
119+
requiredNumIn = 3 // As first argument of method is receiver.
120+
}
121+
if fnType.Type.NumIn() != requiredNumIn || fnType.Type.NumOut() != 1 {
122+
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", fn, operator))
123+
}
124+
}
125+
126+
func checkFunc(fn *builtin.Function, name string, operator string) {
127+
if len(fn.Types) == 0 {
128+
panic(fmt.Errorf("function %s for %s operator misses types", name, operator))
129+
}
130+
for _, t := range fn.Types {
131+
if t.NumIn() != 2 || t.NumOut() != 1 {
132+
panic(fmt.Errorf("function %s for %s operator does not have a correct signature", name, operator))
133+
}
134+
}
135+
}

test/operator/operator_test.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,6 @@ func TestOperator_struct(t *testing.T) {
2626
require.Equal(t, true, output)
2727
}
2828

29-
func TestOperator_options_another_order(t *testing.T) {
30-
code := `Time == "2017-10-23"`
31-
_, err := expr.Compile(code, expr.Operator("==", "TimeEqualString"), expr.Env(mock.Env{}))
32-
require.NoError(t, err)
33-
}
34-
3529
func TestOperator_no_env(t *testing.T) {
3630
code := `Time == "2017-10-23"`
3731
require.Panics(t, func() {

0 commit comments

Comments
 (0)