@@ -36,7 +36,8 @@ pub struct LiteralLookup<T: Debug> {
36
36
// Catch all for unhashable types like list
37
37
expected_py_values : Option < Vec < ( Py < PyAny > , usize ) > > ,
38
38
// Fallback for ints, bools, and strings to use Python hash and equality checks
39
- expected_py_primitives : Option < Vec < ( Py < PyAny > , usize ) > > ,
39
+ // which we can't mix with `expected_py_dict`, see tests/test_validators/test_literal.py::test_mix_int_enum_with_int
40
+ expected_py_primitives : Option < Py < PyDict > > ,
40
41
41
42
pub values : Vec < T > ,
42
43
}
@@ -48,7 +49,7 @@ impl<T: Debug> LiteralLookup<T> {
48
49
let mut expected_str: AHashMap < String , usize > = AHashMap :: new ( ) ;
49
50
let expected_py_dict = PyDict :: new_bound ( py) ;
50
51
let mut expected_py_values = Vec :: new ( ) ;
51
- let mut expected_py_primitives = Vec :: new ( ) ;
52
+ let expected_py_primitives = PyDict :: new_bound ( py ) ;
52
53
let mut values = Vec :: new ( ) ;
53
54
for ( k, v) in expected {
54
55
let id = values. len ( ) ;
@@ -60,12 +61,12 @@ impl<T: Debug> LiteralLookup<T> {
60
61
} else {
61
62
expected_bool. false_id = Some ( id) ;
62
63
}
63
- expected_py_primitives. push ( ( k . as_unbound ( ) . clone_ref ( py ) , id) ) ;
64
+ expected_py_primitives. set_item ( & k , id) ? ;
64
65
}
65
66
if k. is_exact_instance_of :: < PyInt > ( ) {
66
67
if let Ok ( int_64) = k. extract :: < i64 > ( ) {
67
68
expected_int. insert ( int_64, id) ;
68
- expected_py_primitives. push ( ( k . as_unbound ( ) . clone_ref ( py ) , id) ) ;
69
+ expected_py_primitives. set_item ( & k , id) ? ;
69
70
} else {
70
71
// cover the case of an int that's > i64::MAX etc.
71
72
expected_py_dict. set_item ( k, id) ?;
@@ -75,7 +76,7 @@ impl<T: Debug> LiteralLookup<T> {
75
76
. as_cow ( )
76
77
. map_err ( |_| py_schema_error_type ! ( "error extracting str {:?}" , k) ) ?;
77
78
expected_str. insert ( str. to_string ( ) , id) ;
78
- expected_py_primitives. push ( ( k . as_unbound ( ) . clone_ref ( py ) , id) ) ;
79
+ expected_py_primitives. set_item ( & k , id) ? ;
79
80
} else if expected_py_dict. set_item ( & k, id) . is_err ( ) {
80
81
expected_py_values. push ( ( k. as_unbound ( ) . clone_ref ( py) , id) ) ;
81
82
}
@@ -88,7 +89,7 @@ impl<T: Debug> LiteralLookup<T> {
88
89
expected_str : ( !expected_str. is_empty ( ) ) . then_some ( expected_str) ,
89
90
expected_py_dict : ( !expected_py_dict. is_empty ( ) ) . then_some ( expected_py_dict. into ( ) ) ,
90
91
expected_py_values : ( !expected_py_values. is_empty ( ) ) . then_some ( expected_py_values) ,
91
- expected_py_primitives : ( !expected_py_primitives. is_empty ( ) ) . then_some ( expected_py_primitives) ,
92
+ expected_py_primitives : ( !expected_py_primitives. is_empty ( ) ) . then_some ( expected_py_primitives. into ( ) ) ,
92
93
values,
93
94
} )
94
95
}
@@ -157,20 +158,16 @@ impl<T: Debug> LiteralLookup<T> {
157
158
}
158
159
} ;
159
160
161
+ // this one must be last to avoid conflicts with the other lookups, think of this
162
+ // almost as a lax fallback
160
163
if let Some ( expected_py_primitives) = & self . expected_py_primitives {
161
164
let py_input = py_input. get_or_insert_with ( || input. to_object ( py) ) ;
162
- let py_input_bound = py_input. bind ( py) ;
163
-
164
- for ( k, id) in expected_py_primitives {
165
- let bound_k = k. bind ( py) ;
166
- if bound_k. eq ( & * py_input) . unwrap_or ( false ) {
167
- match ( bound_k. hash ( ) , py_input_bound. hash ( ) ) {
168
- ( Ok ( k_hash) , Ok ( input_hash) ) if k_hash == input_hash => {
169
- return Ok ( Some ( ( input, & self . values [ * id] ) ) ) ;
170
- }
171
- _ => continue , // Skip to the next item on hash failure or mismatch
172
- }
173
- }
165
+ // We don't use ? to unpack the result of `get_item` in the next line because unhashable
166
+ // inputs will produce a TypeError, which in this case we just want to treat equivalently
167
+ // to a failed lookup
168
+ if let Ok ( Some ( v) ) = expected_py_primitives. bind ( py) . get_item ( & * py_input) {
169
+ let id: usize = v. extract ( ) . unwrap ( ) ;
170
+ return Ok ( Some ( ( input, & self . values [ id] ) ) ) ;
174
171
}
175
172
} ;
176
173
Ok ( None )
0 commit comments