Skip to content

Commit 1b886c4

Browse files
committed
use dict structure
1 parent caa9db3 commit 1b886c4

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

src/validators/literal.rs

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ pub struct LiteralLookup<T: Debug> {
3636
// Catch all for unhashable types like list
3737
expected_py_values: Option<Vec<(Py<PyAny>, usize)>>,
3838
// Fallback for ints, bools, and strings to use Python hash and equality checks
39-
expected_py_primitives: Option<Vec<(Py<PyAny>, usize)>>,
39+
// which we can't mix with `expected_py_dict`, see tests/test_validators/test_literal.py::test_mix_int_enum_with_int
40+
expected_py_primitives: Option<Py<PyDict>>,
4041

4142
pub values: Vec<T>,
4243
}
@@ -48,7 +49,7 @@ impl<T: Debug> LiteralLookup<T> {
4849
let mut expected_str: AHashMap<String, usize> = AHashMap::new();
4950
let expected_py_dict = PyDict::new_bound(py);
5051
let mut expected_py_values = Vec::new();
51-
let mut expected_py_primitives = Vec::new();
52+
let expected_py_primitives = PyDict::new_bound(py);
5253
let mut values = Vec::new();
5354
for (k, v) in expected {
5455
let id = values.len();
@@ -60,12 +61,12 @@ impl<T: Debug> LiteralLookup<T> {
6061
} else {
6162
expected_bool.false_id = Some(id);
6263
}
63-
expected_py_primitives.push((k.as_unbound().clone_ref(py), id));
64+
expected_py_primitives.set_item(&k, id)?;
6465
}
6566
if k.is_exact_instance_of::<PyInt>() {
6667
if let Ok(int_64) = k.extract::<i64>() {
6768
expected_int.insert(int_64, id);
68-
expected_py_primitives.push((k.as_unbound().clone_ref(py), id));
69+
expected_py_primitives.set_item(&k, id)?;
6970
} else {
7071
// cover the case of an int that's > i64::MAX etc.
7172
expected_py_dict.set_item(k, id)?;
@@ -75,7 +76,7 @@ impl<T: Debug> LiteralLookup<T> {
7576
.as_cow()
7677
.map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
7778
expected_str.insert(str.to_string(), id);
78-
expected_py_primitives.push((k.as_unbound().clone_ref(py), id));
79+
expected_py_primitives.set_item(&k, id)?;
7980
} else if expected_py_dict.set_item(&k, id).is_err() {
8081
expected_py_values.push((k.as_unbound().clone_ref(py), id));
8182
}
@@ -88,7 +89,7 @@ impl<T: Debug> LiteralLookup<T> {
8889
expected_str: (!expected_str.is_empty()).then_some(expected_str),
8990
expected_py_dict: (!expected_py_dict.is_empty()).then_some(expected_py_dict.into()),
9091
expected_py_values: (!expected_py_values.is_empty()).then_some(expected_py_values),
91-
expected_py_primitives: (!expected_py_primitives.is_empty()).then_some(expected_py_primitives),
92+
expected_py_primitives: (!expected_py_primitives.is_empty()).then_some(expected_py_primitives.into()),
9293
values,
9394
})
9495
}
@@ -157,20 +158,16 @@ impl<T: Debug> LiteralLookup<T> {
157158
}
158159
};
159160

161+
// this one must be last to avoid conflicts with the other lookups, think of this
162+
// almost as a lax fallback
160163
if let Some(expected_py_primitives) = &self.expected_py_primitives {
161164
let py_input = py_input.get_or_insert_with(|| input.to_object(py));
162-
let py_input_bound = py_input.bind(py);
163-
164-
for (k, id) in expected_py_primitives {
165-
let bound_k = k.bind(py);
166-
if bound_k.eq(&*py_input).unwrap_or(false) {
167-
match (bound_k.hash(), py_input_bound.hash()) {
168-
(Ok(k_hash), Ok(input_hash)) if k_hash == input_hash => {
169-
return Ok(Some((input, &self.values[*id])));
170-
}
171-
_ => continue, // Skip to the next item on hash failure or mismatch
172-
}
173-
}
165+
// We don't use ? to unpack the result of `get_item` in the next line because unhashable
166+
// inputs will produce a TypeError, which in this case we just want to treat equivalently
167+
// to a failed lookup
168+
if let Ok(Some(v)) = expected_py_primitives.bind(py).get_item(&*py_input) {
169+
let id: usize = v.extract().unwrap();
170+
return Ok(Some((input, &self.values[id])));
174171
}
175172
};
176173
Ok(None)

0 commit comments

Comments
 (0)