Skip to content

Commit 89e89f4

Browse files
committed
Re-use literal validator logic for tagged unions
1 parent 093f6c5 commit 89e89f4

File tree

3 files changed

+112
-88
lines changed

3 files changed

+112
-88
lines changed

src/input/input_python.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,6 @@ impl<'a> Input<'a> for PyAny {
271271
fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
272272
if PyInt::is_exact_type_of(self) {
273273
Ok(EitherInt::Py(self))
274-
} else if PyInt::is_type_of(self) {
275-
// bools are a subclass of int, so check for bool type in this specific case
276-
if PyBool::is_exact_type_of(self) {
277-
Err(ValError::new(ErrorType::IntType, self))
278-
} else {
279-
Ok(EitherInt::Py(self))
280-
}
281274
} else {
282275
Err(ValError::new(ErrorType::IntType, self))
283276
}

src/validators/literal.rs

Lines changed: 96 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// Validator for things inside of a typing.Literal[]
22
// which can be an int, a string, bytes or an Enum value (including `class Foo(str, Enum)` type enums)
3+
use core::fmt::Debug;
34

4-
use ahash::AHashSet;
5+
use ahash::AHashMap;
56
use pyo3::intern;
67
use pyo3::prelude::*;
78
use pyo3::types::{PyDict, PyList};
@@ -15,15 +16,96 @@ use crate::tools::SchemaDict;
1516
use super::{BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
1617

1718
#[derive(Debug, Clone)]
18-
pub struct LiteralValidator {
19+
/// Lookup table from literal keys (ints, strings, or arbitrary hashable Python
/// objects) to a caller-supplied payload of type `T`.
///
/// The key tables do not hold payloads directly: each stores a `usize` index
/// into `values`, where the payloads live.
pub struct LiteralLookup<T: Clone + Debug> {
1920
// Specialized lookups for ints and strings because they
2021
// (1) are easy to convert between Rust and Python
2122
// (2) hashing them in Rust is very fast
2223
// (3) are the most commonly used things in Literal[...]
23-
expected_int: Option<AHashSet<i64>>,
24-
expected_str: Option<AHashSet<String>>,
24+
// Maps an int key to the index of its payload in `values`.
expected_int: Option<AHashMap<i64, usize>>,
25+
// Maps a string key to the index of its payload in `values`.
expected_str: Option<AHashMap<String, usize>>,
2526
// Catch all for Enum and bytes (the latter only because it is seldom used)
2627
expected_py: Option<Py<PyDict>>,
28+
// Payloads, in the order they were supplied to `new`; the tables above index into this.
pub values: Vec<T>,
29+
}
30+
31+
impl<T: Clone + Debug> LiteralLookup<T> {
32+
/// Build a lookup from `(key, value)` pairs.
///
/// Each value is appended to `values`, and its index is recorded under the
/// key in exactly one table: the int map if `k.strict_int()` succeeds, the
/// string map if `k.strict_str()` succeeds, otherwise the Python dict.
/// Tables that end up empty are stored as `None`.
///
/// # Errors
///
/// Returns a schema error if a key accepted by `strict_int`/`strict_str`
/// cannot be converted to `i64`/a string, and propagates any error raised
/// by `PyDict::set_item` for keys stored in the Python dict.
pub fn new<'py>(py: Python<'py>, expected: impl Iterator<Item = (&'py PyAny, T)>) -> PyResult<Self> {
33+
let mut expected_int = AHashMap::new();
34+
let mut expected_str = AHashMap::new();
35+
let expected_py = PyDict::new(py);
36+
let mut values = Vec::new();
37+
for (k, v) in expected {
38+
// The payload's position in `values` is what gets stored in the key tables.
let id = values.len();
39+
values.push(v);
40+
if let Ok(either_int) = k.strict_int() {
41+
let int = either_int
42+
.into_i64(py)
43+
.map_err(|_| py_schema_error_type!("error extracting int {:?}", k))?;
44+
expected_int.insert(int, id);
45+
} else if let Ok(either_str) = k.strict_str() {
46+
let str = either_str
47+
.as_cow()
48+
.map_err(|_| py_schema_error_type!("error extracting str {:?}", k))?;
49+
expected_str.insert(str.to_string(), id);
50+
} else {
51+
expected_py.set_item(k, id)?;
52+
}
53+
}
54+
55+
Ok(Self {
56+
expected_int: match expected_int.is_empty() {
57+
true => None,
58+
false => Some(expected_int),
59+
},
60+
expected_str: match expected_str.is_empty() {
61+
true => None,
62+
false => Some(expected_str),
63+
},
64+
expected_py: match expected_py.is_empty() {
65+
true => None,
66+
false => Some(expected_py.into()),
67+
},
68+
values,
69+
})
70+
}
71+
72+
/// Look `input` up in the tables and return the payload registered for it.
///
/// The int table is tried first, then the string table, then the Python
/// dict; a table that is `None` is skipped. On a match the result is
/// `Ok(Some((input, value)))`, where `input` is the reference passed in and
/// `value` borrows from `self.values`. If nothing matches, the result is
/// `Ok(None)`.
///
/// # Errors
///
/// Propagates errors from `into_i64` / `as_cow` when `input` passes
/// `strict_int` / `strict_str` but the conversion fails.
pub fn validate<'data, I: Input<'data>>(
73+
&self,
74+
py: Python<'data>,
75+
input: &'data I,
76+
) -> ValResult<'data, Option<(&'data I, &T)>> {
77+
// Rust-side int table: only consulted when `input.strict_int()` succeeds.
78+
if let Some(expected_ints) = &self.expected_int {
79+
if let Ok(either_int) = input.strict_int() {
80+
let int = either_int.into_i64(py)?;
81+
if let Some(id) = expected_ints.get(&int) {
82+
return Ok(Some((input, &self.values[*id])));
83+
}
84+
}
85+
}
86+
if let Some(expected_strings) = &self.expected_str {
87+
// Rust-side string table: only consulted when `input.strict_str()` succeeds.
88+
if let Ok(either_str) = input.strict_str() {
89+
let cow = either_str.as_cow()?;
90+
if let Some(id) = expected_strings.get(cow.as_ref()) {
91+
return Ok(Some((input, &self.values[*id])));
92+
}
93+
}
94+
}
95+
// must be an enum or bytes
96+
if let Some(expected_py) = &self.expected_py {
97+
if let Some(v) = expected_py.as_ref(py).get_item(input) {
98+
// Dict values are the `usize` indices written by `new` via `set_item(k, id)`,
// so extracting back to `usize` is expected to succeed.
let id: usize = v.extract().unwrap();
99+
return Ok(Some((input, &self.values[id])));
100+
}
101+
};
102+
Ok(None)
103+
}
104+
}
105+
106+
#[derive(Debug, Clone)]
107+
/// Validator for `typing.Literal[...]` values, backed by a `LiteralLookup`.
pub struct LiteralValidator {
108+
// Lookup whose payload type is `PyObject`.
lookup: LiteralLookup<PyObject>,
27109
// Representation of the expected values, reported in `ErrorType::LiteralError`.
expected_repr: String,
28110
name: String,
29111
}
@@ -41,32 +123,14 @@ impl BuildValidator for LiteralValidator {
41123
return py_schema_err!("`expected` should have length > 0");
42124
}
43125
let py = expected.py();
44-
// Literal[...] only supports int, str, bytes or enums, all of which can be hashed
45-
let mut expected_int = AHashSet::new();
46-
let mut expected_str = AHashSet::new();
47-
let expected_py = PyDict::new(py);
48126
let mut repr_args: Vec<String> = Vec::new();
49127
for item in expected.iter() {
50128
repr_args.push(item.repr()?.extract()?);
51-
if let Ok(either_int) = item.strict_int() {
52-
let int = either_int
53-
.into_i64(py)
54-
.map_err(|_| py_schema_error_type!("error extracting int {:?}", item))?;
55-
expected_int.insert(int);
56-
} else if let Ok(either_str) = item.strict_str() {
57-
let str = either_str
58-
.as_cow()
59-
.map_err(|_| py_schema_error_type!("error extracting str {:?}", item))?;
60-
expected_str.insert(str.to_string());
61-
} else {
62-
expected_py.set_item(item, item)?;
63-
}
64129
}
65130
let (expected_repr, name) = expected_repr_name(repr_args, "literal");
131+
let lookup = LiteralLookup::new(py, expected.iter().map(|v| (v, v.to_object(py))))?;
66132
Ok(CombinedValidator::Literal(Self {
67-
expected_int: (!expected_int.is_empty()).then_some(expected_int),
68-
expected_str: (!expected_str.is_empty()).then_some(expected_str),
69-
expected_py: (!expected_py.is_empty()).then_some(expected_py.into()),
133+
lookup,
70134
expected_repr,
71135
name,
72136
}))
@@ -82,34 +146,15 @@ impl Validator for LiteralValidator {
82146
_definitions: &'data Definitions<CombinedValidator>,
83147
_recursion_guard: &'s mut RecursionGuard,
84148
) -> ValResult<'data, PyObject> {
85-
if let Some(expected_ints) = &self.expected_int {
86-
if let Ok(either_int) = input.strict_int() {
87-
let int = either_int.into_i64(py)?;
88-
if expected_ints.contains(&int) {
89-
return Ok(input.to_object(py));
90-
}
91-
}
92-
}
93-
if let Some(expected_strings) = &self.expected_str {
94-
if let Ok(either_str) = input.strict_str() {
95-
let cow = either_str.as_cow()?;
96-
if expected_strings.contains(cow.as_ref()) {
97-
return Ok(input.to_object(py));
98-
}
99-
}
149+
match self.lookup.validate(py, input)? {
150+
Some((_, v)) => Ok(v.clone()),
151+
None => Err(ValError::new(
152+
ErrorType::LiteralError {
153+
expected: self.expected_repr.clone(),
154+
},
155+
input,
156+
)),
100157
}
101-
// must be an enum or bytes
102-
if let Some(expected_py) = &self.expected_py {
103-
if let Some(v) = expected_py.as_ref(py).get_item(input) {
104-
return Ok(v.into());
105-
}
106-
};
107-
Err(ValError::new(
108-
ErrorType::LiteralError {
109-
expected: self.expected_repr.clone(),
110-
},
111-
input,
112-
))
113158
}
114159

115160
fn different_strict_behavior(

src/validators/union.rs

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::recursion_guard::RecursionGuard;
1313
use crate::tools::SchemaDict;
1414

1515
use super::custom_error::CustomError;
16-
use super::literal::LiteralValidator;
16+
use super::literal::LiteralLookup;
1717
use super::{build_validator, BuildValidator, CombinedValidator, Definitions, DefinitionsBuilder, Extra, Validator};
1818

1919
#[derive(Debug, Clone)]
@@ -218,10 +218,8 @@ impl Discriminator {
218218

219219
#[derive(Debug, Clone)]
220220
pub struct TaggedUnionValidator {
221-
choices: Py<PyDict>,
222-
validators: Vec<CombinedValidator>,
223221
discriminator: Discriminator,
224-
discriminator_validator: Box<CombinedValidator>,
222+
lookup: LiteralLookup<CombinedValidator>,
225223
from_attributes: bool,
226224
strict: bool,
227225
custom_error: Option<CustomError>,
@@ -243,21 +241,15 @@ impl BuildValidator for TaggedUnionValidator {
243241
let discriminator_repr = discriminator.to_string_py(py)?;
244242

245243
let choices = PyDict::new(py);
246-
let mut validators = Vec::with_capacity(choices.len());
247244
let mut tags_repr = String::with_capacity(50);
248245
let mut descr = String::with_capacity(50);
249246
let mut first = true;
250247
let mut discriminators = Vec::with_capacity(choices.len());
251248
let schema_choices: &PyDict = schema.get_as_req(intern!(py, "choices"))?;
252-
let schema_choice_keys = schema_choices.keys();
253-
let discriminator_validator_schema = PyDict::new(py);
254-
discriminator_validator_schema.set_item(intern!(py, "type"), intern!(py, "literal").as_ref())?;
255-
discriminator_validator_schema.set_item(intern!(py, "expected"), schema_choice_keys.as_ref())?;
256-
let discriminator_validator = LiteralValidator::build(discriminator_validator_schema, config, definitions)?;
249+
let mut lookup_map = Vec::with_capacity(choices.len());
257250
for (choice_key, choice_schema) in schema_choices.iter() {
258251
discriminators.push(choice_key);
259252
let validator = build_validator(choice_schema, config, definitions)?;
260-
choices.set_item(choice_key, validators.len())?;
261253
let tag_repr = choice_key.repr()?.to_string();
262254
if first {
263255
first = false;
@@ -268,9 +260,11 @@ impl BuildValidator for TaggedUnionValidator {
268260
// no spaces in get_name() output to make loc easy to read
269261
write!(descr, ",{}", validator.get_name()).unwrap();
270262
}
271-
validators.push(validator);
263+
lookup_map.push((choice_key, validator));
272264
}
273265

266+
let lookup = LiteralLookup::new(py, lookup_map.into_iter())?;
267+
274268
let key = intern!(py, "from_attributes");
275269
let from_attributes = schema_or_config(schema, config, key, key)?.unwrap_or(true);
276270

@@ -280,10 +274,8 @@ impl BuildValidator for TaggedUnionValidator {
280274
};
281275

282276
Ok(Self {
283-
choices: choices.into(),
284-
validators,
285277
discriminator,
286-
discriminator_validator: Box::new(discriminator_validator),
278+
lookup,
287279
from_attributes,
288280
strict: is_strict(schema, config)?,
289281
custom_error: CustomError::build(schema, config, definitions)?,
@@ -352,7 +344,8 @@ impl Validator for TaggedUnionValidator {
352344
definitions: Option<&DefinitionsBuilder<CombinedValidator>>,
353345
ultra_strict: bool,
354346
) -> bool {
355-
self.validators
347+
self.lookup
348+
.values
356349
.iter()
357350
.any(|v| v.different_strict_behavior(definitions, ultra_strict))
358351
}
@@ -362,7 +355,8 @@ impl Validator for TaggedUnionValidator {
362355
}
363356

364357
fn complete(&mut self, definitions: &DefinitionsBuilder<CombinedValidator>) -> PyResult<()> {
365-
self.validators
358+
self.lookup
359+
.values
366360
.iter_mut()
367361
.try_for_each(|validator| validator.complete(definitions))
368362
}
@@ -423,19 +417,11 @@ impl TaggedUnionValidator {
423417
definitions: &'data Definitions<CombinedValidator>,
424418
recursion_guard: &'s mut RecursionGuard,
425419
) -> ValResult<'data, PyObject> {
426-
if let Ok(tag) = self
427-
.discriminator_validator
428-
.validate(py, tag, extra, definitions, recursion_guard)
429-
{
430-
let tag = tag.as_ref(py);
431-
if let Some(validator_idx) = self.choices.as_ref(py).get_item(tag) {
432-
// We know this will always be a usize because we put it there ourselves
433-
let validator = &self.validators[usize::extract(validator_idx).unwrap()];
434-
return match validator.validate(py, input, extra, definitions, recursion_guard) {
435-
Ok(res) => Ok(res),
436-
Err(err) => Err(err.with_outer_location(LocItem::try_from(tag)?)),
437-
};
438-
}
420+
if let Ok(Some((tag, validator))) = self.lookup.validate(py, tag) {
421+
return match validator.validate(py, input, extra, definitions, recursion_guard) {
422+
Ok(res) => Ok(res),
423+
Err(err) => Err(err.with_outer_location(LocItem::try_from(tag.to_object(py).into_ref(py))?)),
424+
};
439425
}
440426
match self.custom_error {
441427
Some(ref custom_error) => Err(custom_error.as_val_error(input)),

0 commit comments

Comments
 (0)