Skip to content

Commit 0c7f720

Browse files
committed
Add types.Equal
1 parent 55f453b commit 0c7f720

File tree

2 files changed

+141
-5
lines changed

2 files changed

+141
-5
lines changed

types/types.go

Lines changed: 94 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
package types
22

33
import (
4+
"fmt"
45
"reflect"
6+
"strings"
57

68
. "github.com/expr-lang/expr/checker/nature"
79
)
810

9-
func TypeOf(v any) Type {
10-
return rtype{t: reflect.TypeOf(v)}
11+
// Type is a type that can be used to represent a value.
12+
type Type interface {
13+
Nature() Nature
14+
Equal(Type) bool
15+
String() string
1116
}
1217

1318
var (
@@ -28,9 +33,11 @@ var (
2833
Nil = nilType{}
2934
)
3035

31-
// Type is a type that can be used to represent a value.
32-
type Type interface {
33-
Nature() Nature
36+
func TypeOf(v any) Type {
37+
if v == nil {
38+
return Nil
39+
}
40+
return rtype{t: reflect.TypeOf(v)}
3441
}
3542

3643
type nilType struct{}
@@ -39,6 +46,14 @@ func (nilType) Nature() Nature {
3946
return Nature{Nil: true}
4047
}
4148

49+
func (nilType) Equal(t Type) bool {
50+
return t == Nil
51+
}
52+
53+
func (nilType) String() string {
54+
return "nil"
55+
}
56+
4257
type rtype struct {
4358
t reflect.Type
4459
}
@@ -47,6 +62,17 @@ func (r rtype) Nature() Nature {
4762
return Nature{Type: r.t}
4863
}
4964

65+
func (r rtype) Equal(t Type) bool {
66+
if rt, ok := t.(rtype); ok {
67+
return r.t.String() == rt.t.String()
68+
}
69+
return false
70+
}
71+
72+
func (r rtype) String() string {
73+
return r.t.String()
74+
}
75+
5076
// Map returns a type that represents a map of the given type.
5177
// The map is not strict, meaning that it can contain keys not defined in the map.
5278
type Map map[string]Type
@@ -62,6 +88,30 @@ func (m Map) Nature() Nature {
6288
return nt
6389
}
6490

91+
func (m Map) Equal(t Type) bool {
92+
mt, ok := t.(Map)
93+
if !ok {
94+
return false
95+
}
96+
if len(m) != len(mt) {
97+
return false
98+
}
99+
for k, v := range m {
100+
if !v.Equal(mt[k]) {
101+
return false
102+
}
103+
}
104+
return true
105+
}
106+
107+
func (m Map) String() string {
108+
pairs := make([]string, 0, len(m))
109+
for k, v := range m {
110+
pairs = append(pairs, fmt.Sprintf("%s: %s", k, v.String()))
111+
}
112+
return fmt.Sprintf("Map{%s}", strings.Join(pairs, ", "))
113+
}
114+
65115
// StrictMap returns a type that represents a map of the given type.
66116
// The map is strict, meaning that it can only contain keys defined in the map.
67117
type StrictMap map[string]Type
@@ -78,6 +128,30 @@ func (m StrictMap) Nature() Nature {
78128
return nt
79129
}
80130

131+
func (m StrictMap) Equal(t Type) bool {
132+
mt, ok := t.(StrictMap)
133+
if !ok {
134+
return false
135+
}
136+
if len(m) != len(mt) {
137+
return false
138+
}
139+
for k, v := range m {
140+
if !v.Equal(mt[k]) {
141+
return false
142+
}
143+
}
144+
return true
145+
}
146+
147+
func (m StrictMap) String() string {
148+
pairs := make([]string, 0, len(m))
149+
for k, v := range m {
150+
pairs = append(pairs, fmt.Sprintf("%s: %s", k, v.String()))
151+
}
152+
return fmt.Sprintf("StrictMap{%s}", strings.Join(pairs, ", "))
153+
}
154+
81155
// Array returns a type that represents an array of the given type.
82156
func Array(of Type) Type {
83157
return array{of}
@@ -95,3 +169,18 @@ func (a array) Nature() Nature {
95169
ArrayOf: &of,
96170
}
97171
}
172+
173+
func (a array) Equal(t Type) bool {
174+
at, ok := t.(array)
175+
if !ok {
176+
return false
177+
}
178+
if a.of.Equal(at.of) {
179+
return true
180+
}
181+
return false
182+
}
183+
184+
func (a array) String() string {
185+
return fmt.Sprintf("Array{%s}", a.of.String())
186+
}

types/types_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package types_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/expr-lang/expr/internal/testify/require"
7+
. "github.com/expr-lang/expr/types"
8+
)
9+
10+
func TestType_Equal(t *testing.T) {
11+
tests := []struct {
12+
index string
13+
a, b Type
14+
want bool
15+
}{
16+
{"1", Int, Int, true},
17+
{"2", Int, Int8, false},
18+
{"3", Int, Uint, false},
19+
{"4", Int, Float, false},
20+
{"5", Int, String, false},
21+
{"6", Int, Bool, false},
22+
{"7", Int, Nil, false},
23+
{"8", Int, Array(Int), false},
24+
{"9", Int, Map{"foo": Int}, false},
25+
{"10", Int, StrictMap{"foo": Int}, false},
26+
{"11", Int, Array(Int), false},
27+
{"12", Array(Int), Array(Int), true},
28+
{"13", Array(Int), Array(Float), false},
29+
{"14", Map{"foo": Int}, Map{"foo": Int}, true},
30+
{"15", Map{"foo": Int}, Map{"foo": Float}, false},
31+
{"16", Map{"foo": Int}, StrictMap{"foo": Int}, false},
32+
{"17", StrictMap{"foo": Int}, StrictMap{"foo": Int}, true},
33+
{"18", StrictMap{"foo": Int}, StrictMap{"foo": Float}, false},
34+
{"19", Map{"foo": Map{"bar": Int}}, Map{"foo": Map{"bar": Int}}, true},
35+
{"20", Map{"foo": Map{"bar": Int}}, Map{"foo": Map{"bar": Float}}, false},
36+
}
37+
38+
for _, tt := range tests {
39+
t.Run(tt.index, func(t *testing.T) {
40+
if tt.want {
41+
require.True(t, tt.a.Equal(tt.b), tt.a.String()+" == "+tt.b.String())
42+
} else {
43+
require.False(t, tt.a.Equal(tt.b), tt.a.String()+" == "+tt.b.String())
44+
}
45+
})
46+
}
47+
}

0 commit comments

Comments
 (0)