Skip to content

Commit 1f5a98b

Browse files
committed
Rework iterable validation to better enforce constraints across all iterables
1 parent 47c2df1 commit 1f5a98b

28 files changed

+1645
-1158
lines changed

benches/main.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ extern crate test;
55
use test::{black_box, Bencher};
66

77
use pyo3::prelude::*;
8-
use pyo3::types::{PyDict, PyString};
8+
use pyo3::types::{PyDict, PySet, PyString};
99

1010
use _pydantic_core::SchemaValidator;
1111

@@ -265,6 +265,7 @@ fn dict_python(bench: &mut Bencher) {
265265
.collect::<Vec<String>>()
266266
.join(", ")
267267
);
268+
dbg!(code.clone());
268269
let input = py.eval(&code, None, None).unwrap();
269270
let input = black_box(input);
270271
bench.iter(|| {
@@ -696,3 +697,58 @@ class Foo(Enum):
696697
}
697698
})
698699
}
700+
701+
/// Number of objects each `constructing_pyset_*` benchmark creates as its input;
/// also used as the reserved capacity in the `with_capacity` variant.
const COLLECTION_SIZE: usize = 100_000;
702+
703+
#[bench]
704+
fn constructing_pyset_from_vec_without_capacity(bench: &mut Bencher) {
705+
Python::with_gil(|py| {
706+
let input: Vec<PyObject> = (0..COLLECTION_SIZE).map(|v| v.to_object(py)).collect();
707+
708+
bench.iter(|| {
709+
black_box({
710+
let mut output = Vec::new();
711+
for x in &input {
712+
output.push(x);
713+
}
714+
let set = PySet::new(py, output.iter()).unwrap();
715+
set
716+
})
717+
})
718+
})
719+
}
720+
721+
#[bench]
722+
fn constructing_pyset_from_vec_with_capacity(bench: &mut Bencher) {
723+
Python::with_gil(|py| {
724+
let input: Vec<PyObject> = (0..COLLECTION_SIZE).map(|v| v.to_object(py)).collect();
725+
726+
bench.iter(|| {
727+
black_box({
728+
let mut output = Vec::with_capacity(COLLECTION_SIZE);
729+
for x in &input {
730+
output.push(x);
731+
}
732+
let set = PySet::new(py, output.iter()).unwrap();
733+
set
734+
})
735+
})
736+
})
737+
}
738+
739+
#[bench]
740+
fn constructing_pyset_from_vec_directly(bench: &mut Bencher) {
741+
Python::with_gil(|py| {
742+
let input: Vec<PyObject> = (0..COLLECTION_SIZE).map(|v| v.to_object(py)).collect();
743+
744+
bench.iter(|| {
745+
black_box({
746+
let output = PySet::new(py, &Vec::<i64>::new()).unwrap();
747+
for x in &input {
748+
output.add(x).unwrap();
749+
}
750+
output
751+
})
752+
})
753+
})
754+
}

src/errors/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ impl ErrorType {
469469
Self::MultipleOf {..} => "Input should be a multiple of {multiple_of}",
470470
Self::FiniteNumber => "Input should be a finite number",
471471
Self::TooShort {..} => "{field_type} should have at least {min_length} item{expected_plural} after validation, not {actual_length}",
472-
Self::TooLong {..} => "{field_type} should have at most {max_length} item{expected_plural} after validation, not {actual_length}",
472+
Self::TooLong {..} => "{field_type} should have at most {max_length} item{expected_plural} after validation, not >= {actual_length}",
473473
Self::IterableType => "Input should be iterable",
474474
Self::IterationError {..} => "Error iterating over object, error: {error}",
475475
Self::StringType => "Input should be a valid string",

src/input/generic_iterable.rs

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
use crate::errors::{py_err_string, ErrorType, ValError, ValResult};
2+
3+
use super::parse_json::{JsonInput, JsonObject};
4+
use pyo3::{
5+
exceptions::PyTypeError,
6+
types::{
7+
PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PyMapping, PySequence, PySet, PyString, PyTuple,
8+
},
9+
PyAny, PyErr, PyResult, Python, ToPyObject,
10+
};
11+
12+
/// Borrowed handle to one of the input shapes this module knows how to iterate:
/// concrete Python containers, Python iterators / sequences / mappings, Python
/// string and bytes types, and parsed JSON arrays and objects.
///
/// The `len`, `into_sequence_iterator` and `into_mapping_items_iterator` methods
/// defined in this file dispatch on the variant.
#[derive(Debug)]
13+
pub enum GenericIterable<'a> {
14+
List(&'a PyList),
15+
Tuple(&'a PyTuple),
16+
Set(&'a PySet),
17+
FrozenSet(&'a PyFrozenSet),
18+
Dict(&'a PyDict),
19+
// Treat dict values / keys / items as generic iterators
20+
// since PyPy doesn't export the concrete types
21+
DictKeys(&'a PyIterator),
22+
DictValues(&'a PyIterator),
23+
DictItems(&'a PyIterator),
24+
Mapping(&'a PyMapping),
25+
String(&'a PyString),
26+
Bytes(&'a PyBytes),
27+
PyByteArray(&'a PyByteArray),
28+
Sequence(&'a PySequence),
29+
Iterator(&'a PyIterator),
30+
JsonArray(&'a [JsonInput]),
31+
JsonObject(&'a JsonObject),
32+
}
33+
34+
/// A `(key, value)` pair of borrowed Python objects, as produced by
/// `extract_items` and yielded by `into_mapping_items_iterator`.
type PyMappingItems<'a> = (&'a PyAny, &'a PyAny);
35+
36+
#[inline(always)]
37+
fn extract_items(item: PyResult<&PyAny>) -> PyResult<PyMappingItems<'_>> {
38+
match item {
39+
Ok(v) => v.extract::<PyMappingItems>(),
40+
Err(e) => Err(e),
41+
}
42+
}
43+
44+
#[inline(always)]
45+
fn map_err<'data>(py: Python<'data>, err: PyErr, input: &'data PyAny) -> ValError<'data> {
46+
ValError::new(
47+
ErrorType::IterationError {
48+
error: py_err_string(py, err),
49+
},
50+
input,
51+
)
52+
}
53+
54+
impl<'a, 'py: 'a> GenericIterable<'a> {
55+
pub fn len(&self) -> Option<usize> {
56+
match &self {
57+
GenericIterable::List(iter) => Some(iter.len()),
58+
GenericIterable::Tuple(iter) => Some(iter.len()),
59+
GenericIterable::Set(iter) => Some(iter.len()),
60+
GenericIterable::FrozenSet(iter) => Some(iter.len()),
61+
GenericIterable::Dict(iter) => Some(iter.len()),
62+
GenericIterable::DictKeys(iter) => iter.len().ok(),
63+
GenericIterable::DictValues(iter) => iter.len().ok(),
64+
GenericIterable::DictItems(iter) => iter.len().ok(),
65+
GenericIterable::Mapping(iter) => iter.len().ok(),
66+
GenericIterable::String(iter) => iter.len().ok(),
67+
GenericIterable::Bytes(iter) => iter.len().ok(),
68+
GenericIterable::PyByteArray(iter) => Some(iter.len()),
69+
GenericIterable::Sequence(iter) => iter.len().ok(),
70+
GenericIterable::Iterator(iter) => iter.len().ok(),
71+
GenericIterable::JsonArray(iter) => Some(iter.len()),
72+
GenericIterable::JsonObject(iter) => Some(iter.len()),
73+
}
74+
}
75+
pub fn into_sequence_iterator(
76+
self,
77+
py: Python<'py>,
78+
) -> PyResult<Box<dyn Iterator<Item = PyResult<&'a PyAny>> + 'a>> {
79+
match self {
80+
GenericIterable::List(iter) => Ok(Box::new(iter.iter().map(Ok))),
81+
GenericIterable::Tuple(iter) => Ok(Box::new(iter.iter().map(Ok))),
82+
GenericIterable::Set(iter) => Ok(Box::new(iter.iter().map(Ok))),
83+
GenericIterable::FrozenSet(iter) => Ok(Box::new(iter.iter().map(Ok))),
84+
// Note that this iterates over only the keys, just like doing iter({}) in Python
85+
GenericIterable::Dict(iter) => Ok(Box::new(iter.iter().map(|(k, _)| Ok(k)))),
86+
GenericIterable::DictKeys(iter) => Ok(Box::new(iter)),
87+
GenericIterable::DictValues(iter) => Ok(Box::new(iter)),
88+
GenericIterable::DictItems(iter) => Ok(Box::new(iter)),
89+
// Note that this iterates over only the keys, just like doing iter({}) in Python
90+
GenericIterable::Mapping(iter) => Ok(Box::new(iter.keys()?.iter()?)),
91+
GenericIterable::String(iter) => Ok(Box::new(iter.iter()?)),
92+
GenericIterable::Bytes(iter) => Ok(Box::new(iter.iter()?)),
93+
GenericIterable::PyByteArray(iter) => Ok(Box::new(iter.iter()?)),
94+
GenericIterable::Sequence(iter) => Ok(Box::new(iter.iter()?)),
95+
GenericIterable::Iterator(iter) => Ok(Box::new(iter)),
96+
GenericIterable::JsonArray(iter) => Ok(Box::new(iter.iter().map(move |v| {
97+
let v = v.to_object(py);
98+
Ok(v.into_ref(py))
99+
}))),
100+
// Note that this iterates over only the keys, just like doing iter({}) in Python, just for consistency
101+
GenericIterable::JsonObject(iter) => Ok(Box::new(
102+
iter.iter().map(move |(k, _)| Ok(k.to_object(py).into_ref(py))),
103+
)),
104+
}
105+
}
106+
107+
pub fn into_mapping_items_iterator(
108+
self,
109+
py: Python<'a>,
110+
) -> PyResult<Box<dyn Iterator<Item = ValResult<'a, PyMappingItems<'a>>> + 'a>> {
111+
let py2 = py;
112+
match self {
113+
GenericIterable::List(iter) => {
114+
Ok(Box::new(iter.iter().map(move |v| {
115+
extract_items(Ok(v)).map_err(|e| map_err(py2, e, iter.as_ref()))
116+
})))
117+
}
118+
GenericIterable::Tuple(iter) => {
119+
Ok(Box::new(iter.iter().map(move |v| {
120+
extract_items(Ok(v)).map_err(|e| map_err(py2, e, iter.as_ref()))
121+
})))
122+
}
123+
GenericIterable::Set(iter) => {
124+
Ok(Box::new(iter.iter().map(move |v| {
125+
extract_items(Ok(v)).map_err(|e| map_err(py2, e, iter.as_ref()))
126+
})))
127+
}
128+
GenericIterable::FrozenSet(iter) => {
129+
Ok(Box::new(iter.iter().map(move |v| {
130+
extract_items(Ok(v)).map_err(|e| map_err(py2, e, iter.as_ref()))
131+
})))
132+
}
133+
// Note that we iterate over (key, value), unlike doing iter({}) in Python
134+
GenericIterable::Dict(iter) => Ok(Box::new(iter.iter().map(Ok))),
135+
// Keys or values can be tuples
136+
GenericIterable::DictKeys(iter) => Ok(Box::new(
137+
iter.map(extract_items)
138+
.map(move |r| r.map_err(|e| map_err(py2, e, iter.as_ref()))),
139+
)),
140+
GenericIterable::DictValues(iter) => Ok(Box::new(
141+
iter.map(extract_items)
142+
.map(move |r| r.map_err(|e| map_err(py2, e, iter.as_ref()))),
143+
)),
144+
GenericIterable::DictItems(iter) => Ok(Box::new(
145+
iter.map(extract_items)
146+
.map(move |r| r.map_err(|e| map_err(py2, e, iter.as_ref()))),
147+
)),
148+
// Note that we iterate over (key, value), unlike doing iter({}) in Python
149+
GenericIterable::Mapping(iter) => Ok(Box::new(
150+
iter.items()?
151+
.iter()?
152+
.map(extract_items)
153+
.map(move |r| r.map_err(|e| map_err(py2, e, iter.as_ref()))),
154+
)),
155+
// In Python if you do dict("foobar") you get "dictionary update sequence element #0 has length 1; 2 is required"
156+
// This is similar but arguably a better error message
157+
GenericIterable::String(_) => Err(PyTypeError::new_err(
158+
"Expected an iterable of (key, value) pairs, got a string",
159+
)),
160+
GenericIterable::Bytes(_) => Err(PyTypeError::new_err(
161+
"Expected an iterable of (key, value) pairs, got a bytes",
162+
)),
163+
GenericIterable::PyByteArray(_) => Err(PyTypeError::new_err(
164+
"Expected an iterable of (key, value) pairs, got a bytearray",
165+
)),
166+
// Obviously these may be things that are not convertible to a tuple of (Hashable, Any)
167+
// Python fails with a similar error message to above, ours will be slightly different (PyO3 will fail to extract) but similar enough
168+
GenericIterable::Sequence(iter) => Ok(Box::new(
169+
iter.iter()?
170+
.map(extract_items)
171+
.map(move |r| r.map_err(|e| map_err(py2, e, iter.as_ref()))),
172+
)),
173+
GenericIterable::Iterator(iter) => Ok(Box::new(
174+
iter.iter()?
175+
.map(extract_items)
176+
.map(move |r| r.map_err(|e| map_err(py2, e, iter.as_ref()))),
177+
)),
178+
GenericIterable::JsonArray(iter) => Ok(Box::new(
179+
iter.iter()
180+
.map(move |v| extract_items(Ok(v.to_object(py).into_ref(py))))
181+
.map(move |r| r.map_err(|e| map_err(py2, e, iter.to_object(py).into_ref(py)))),
182+
)),
183+
// Note that we iterate over (key, value), unlike doing iter({}) in Python
184+
GenericIterable::JsonObject(iter) => Ok(Box::new(iter.iter().map(move |(k, v)| {
185+
let k = PyString::new(py, k).as_ref();
186+
let v = v.to_object(py).into_ref(py);
187+
Ok((k, v))
188+
}))),
189+
}
190+
}
191+
}

src/input/input_abstract.rs

Lines changed: 3 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ use crate::errors::{InputValue, LocItem, ValResult};
77
use crate::{PyMultiHostUrl, PyUrl};
88

99
use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
10+
use super::generic_iterable::GenericIterable;
1011
use super::return_enums::{EitherBytes, EitherString};
11-
use super::{GenericArguments, GenericCollection, GenericIterator, GenericMapping, JsonInput};
12+
use super::{GenericArguments, GenericIterator, GenericMapping, JsonInput};
1213

1314
#[derive(Debug, Clone, Copy)]
1415
pub enum InputType {
@@ -166,57 +167,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
166167
self.validate_dict(strict)
167168
}
168169

169-
fn validate_list(&'a self, strict: bool, allow_any_iter: bool) -> ValResult<GenericCollection<'a>> {
170-
if strict && !allow_any_iter {
171-
self.strict_list()
172-
} else {
173-
self.lax_list(allow_any_iter)
174-
}
175-
}
176-
fn strict_list(&'a self) -> ValResult<GenericCollection<'a>>;
177-
#[cfg_attr(has_no_coverage, no_coverage)]
178-
fn lax_list(&'a self, _allow_any_iter: bool) -> ValResult<GenericCollection<'a>> {
179-
self.strict_list()
180-
}
181-
182-
fn validate_tuple(&'a self, strict: bool) -> ValResult<GenericCollection<'a>> {
183-
if strict {
184-
self.strict_tuple()
185-
} else {
186-
self.lax_tuple()
187-
}
188-
}
189-
fn strict_tuple(&'a self) -> ValResult<GenericCollection<'a>>;
190-
#[cfg_attr(has_no_coverage, no_coverage)]
191-
fn lax_tuple(&'a self) -> ValResult<GenericCollection<'a>> {
192-
self.strict_tuple()
193-
}
194-
195-
fn validate_set(&'a self, strict: bool) -> ValResult<GenericCollection<'a>> {
196-
if strict {
197-
self.strict_set()
198-
} else {
199-
self.lax_set()
200-
}
201-
}
202-
fn strict_set(&'a self) -> ValResult<GenericCollection<'a>>;
203-
#[cfg_attr(has_no_coverage, no_coverage)]
204-
fn lax_set(&'a self) -> ValResult<GenericCollection<'a>> {
205-
self.strict_set()
206-
}
207-
208-
fn validate_frozenset(&'a self, strict: bool) -> ValResult<GenericCollection<'a>> {
209-
if strict {
210-
self.strict_frozenset()
211-
} else {
212-
self.lax_frozenset()
213-
}
214-
}
215-
fn strict_frozenset(&'a self) -> ValResult<GenericCollection<'a>>;
216-
#[cfg_attr(has_no_coverage, no_coverage)]
217-
fn lax_frozenset(&'a self) -> ValResult<GenericCollection<'a>> {
218-
self.strict_frozenset()
219-
}
170+
fn extract_iterable(&'a self) -> ValResult<GenericIterable<'a>>;
220171

221172
fn validate_iter(&self) -> ValResult<GenericIterator>;
222173

0 commit comments

Comments
 (0)