Skip to content

Commit c1e50d2

Browse files
committed
rework to handle sets
1 parent f4763b9 commit c1e50d2

File tree

15 files changed

+660
-623
lines changed

15 files changed

+660
-623
lines changed

benches/main.rs

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ extern crate test;
55
use test::{black_box, Bencher};
66

77
use pyo3::prelude::*;
8-
use pyo3::types::{PyDict, PyString};
8+
use pyo3::types::{PyDict, PySet, PyString};
99

1010
use _pydantic_core::SchemaValidator;
1111

@@ -696,3 +696,58 @@ class Foo(Enum):
696696
}
697697
})
698698
}
699+
700+
const COLLECTION_SIZE: usize = 100_000;
701+
702+
#[bench]
703+
fn constructing_pyset_from_vec_without_capacity(bench: &mut Bencher) {
704+
Python::with_gil(|py| {
705+
let input: Vec<PyObject> = (0..COLLECTION_SIZE).map(|v| v.to_object(py)).collect();
706+
707+
bench.iter(|| {
708+
black_box({
709+
let mut output = Vec::new();
710+
for x in &input {
711+
output.push(x);
712+
}
713+
let set = PySet::new(py, output.iter()).unwrap();
714+
set
715+
})
716+
})
717+
})
718+
}
719+
720+
#[bench]
721+
fn constructing_pyset_from_vec_with_capacity(bench: &mut Bencher) {
722+
Python::with_gil(|py| {
723+
let input: Vec<PyObject> = (0..COLLECTION_SIZE).map(|v| v.to_object(py)).collect();
724+
725+
bench.iter(|| {
726+
black_box({
727+
let mut output = Vec::with_capacity(COLLECTION_SIZE);
728+
for x in &input {
729+
output.push(x);
730+
}
731+
let set = PySet::new(py, output.iter()).unwrap();
732+
set
733+
})
734+
})
735+
})
736+
}
737+
738+
#[bench]
739+
fn constructing_pyset_from_vec_directly(bench: &mut Bencher) {
740+
Python::with_gil(|py| {
741+
let input: Vec<PyObject> = (0..COLLECTION_SIZE).map(|v| v.to_object(py)).collect();
742+
743+
bench.iter(|| {
744+
black_box({
745+
let output = PySet::new(py, &Vec::<i64>::new()).unwrap();
746+
for x in &input {
747+
output.add(x).unwrap();
748+
}
749+
output
750+
})
751+
})
752+
})
753+
}

src/input/any_iterable.rs

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
use super::parse_json::{JsonInput, JsonObject};
2-
use pyo3::types::{PyDict, PyFrozenSet, PyIterator, PyList, PyMapping, PySequence, PySet, PyTuple};
2+
use pyo3::{
3+
exceptions::PyTypeError,
4+
types::{
5+
PyByteArray, PyBytes, PyDict, PyFrozenSet, PyIterator, PyList, PyMapping, PySequence, PySet, PyString, PyTuple,
6+
},
7+
PyAny, PyResult, Python, ToPyObject,
8+
};
39

4-
pub enum AnyIterable<'a> {
10+
#[derive(Debug)]
11+
pub enum GenericIterable<'a> {
512
List(&'a PyList),
613
Tuple(&'a PyTuple),
714
Set(&'a PySet),
@@ -13,8 +20,100 @@ pub enum AnyIterable<'a> {
1320
DictValues(&'a PyIterator),
1421
DictItems(&'a PyIterator),
1522
Mapping(&'a PyMapping),
23+
String(&'a PyString),
24+
Bytes(&'a PyBytes),
25+
PyByteArray(&'a PyByteArray),
1626
Sequence(&'a PySequence),
1727
Iterator(&'a PyIterator),
1828
JsonArray(&'a [JsonInput]),
1929
JsonObject(&'a JsonObject),
2030
}
31+
32+
type PyMappingItems<'a> = (&'a PyAny, &'a PyAny);
33+
34+
#[inline(always)]
35+
fn extract_items(item: PyResult<&PyAny>) -> PyResult<PyMappingItems<'_>> {
36+
match item {
37+
Ok(v) => v.extract::<PyMappingItems>(),
38+
Err(e) => Err(e),
39+
}
40+
}
41+
42+
impl<'a, 'py: 'a> GenericIterable<'a> {
43+
pub fn into_sequence_iterator(
44+
self,
45+
py: Python<'py>,
46+
) -> PyResult<Box<dyn Iterator<Item = PyResult<&'a PyAny>> + 'a>> {
47+
match self {
48+
GenericIterable::List(iter) => Ok(Box::new(iter.iter().map(Ok))),
49+
GenericIterable::Tuple(iter) => Ok(Box::new(iter.iter().map(Ok))),
50+
GenericIterable::Set(iter) => Ok(Box::new(iter.iter().map(Ok))),
51+
GenericIterable::FrozenSet(iter) => Ok(Box::new(iter.iter().map(Ok))),
52+
// Note that this iterates over only the keys, just like doing iter({}) in Python
53+
GenericIterable::Dict(iter) => Ok(Box::new(iter.iter().map(|(k, _)| Ok(k)))),
54+
GenericIterable::DictKeys(iter) => Ok(Box::new(iter)),
55+
GenericIterable::DictValues(iter) => Ok(Box::new(iter)),
56+
GenericIterable::DictItems(iter) => Ok(Box::new(iter)),
57+
// Note that this iterates over only the keys, just like doing iter({}) in Python
58+
GenericIterable::Mapping(iter) => Ok(Box::new(iter.keys()?.iter()?)),
59+
GenericIterable::String(iter) => Ok(Box::new(iter.iter()?)),
60+
GenericIterable::Bytes(iter) => Ok(Box::new(iter.iter()?)),
61+
GenericIterable::PyByteArray(iter) => Ok(Box::new(iter.iter()?)),
62+
GenericIterable::Sequence(iter) => Ok(Box::new(iter.iter()?)),
63+
GenericIterable::Iterator(iter) => Ok(Box::new(iter)),
64+
GenericIterable::JsonArray(iter) => Ok(Box::new(iter.iter().map(move |v| {
65+
let v = v.to_object(py);
66+
Ok(v.into_ref(py))
67+
}))),
68+
// Note that this iterates over only the keys, just like doing iter({}) in Python, just for consistency
69+
GenericIterable::JsonObject(iter) => Ok(Box::new(
70+
iter.iter().map(move |(k, _)| Ok(k.to_object(py).into_ref(py))),
71+
)),
72+
}
73+
}
74+
75+
pub fn into_mapping_items_iterator(
76+
self,
77+
py: Python<'py>,
78+
) -> PyResult<Box<dyn Iterator<Item = PyResult<PyMappingItems<'a>>> + 'a>> {
79+
match self {
80+
GenericIterable::List(iter) => Ok(Box::new(iter.iter().map(|v| extract_items(Ok(v))))),
81+
GenericIterable::Tuple(iter) => Ok(Box::new(iter.iter().map(|v| extract_items(Ok(v))))),
82+
GenericIterable::Set(iter) => Ok(Box::new(iter.iter().map(|v| extract_items(Ok(v))))),
83+
GenericIterable::FrozenSet(iter) => Ok(Box::new(iter.iter().map(|v| extract_items(Ok(v))))),
84+
// Note that we iterate over (key, value), unlike doing iter({}) in Python
85+
GenericIterable::Dict(iter) => Ok(Box::new(iter.iter().map(Ok))),
86+
// Keys or values can be tuples
87+
GenericIterable::DictKeys(iter) => Ok(Box::new(iter.map(extract_items))),
88+
GenericIterable::DictValues(iter) => Ok(Box::new(iter.map(extract_items))),
89+
GenericIterable::DictItems(iter) => Ok(Box::new(iter.map(extract_items))),
90+
// Note that we iterate over (key, value), unlike doing iter({}) in Python
91+
GenericIterable::Mapping(iter) => Ok(Box::new(iter.items()?.iter()?.map(extract_items))),
92+
// In Python if you do dict("foobar") you get "dictionary update sequence element #0 has length 1; 2 is required"
93+
// This is similar but arguably a better error message
94+
GenericIterable::String(_) => Err(PyTypeError::new_err(
95+
"Expected an iterable of (key, value) pairs, got a string",
96+
)),
97+
GenericIterable::Bytes(_) => Err(PyTypeError::new_err(
98+
"Expected an iterable of (key, value) pairs, got a bytes",
99+
)),
100+
GenericIterable::PyByteArray(_) => Err(PyTypeError::new_err(
101+
"Expected an iterable of (key, value) pairs, got a bytearray",
102+
)),
103+
// Obviously these may be things that are not convertible to a tuple of (Hashable, Any)
104+
// Python fails with a similar error message to above, ours will be slightly different (PyO3 will fail to extract) but similar enough
105+
GenericIterable::Sequence(iter) => Ok(Box::new(iter.iter()?.map(extract_items))),
106+
GenericIterable::Iterator(iter) => Ok(Box::new(iter.iter()?.map(extract_items))),
107+
GenericIterable::JsonArray(iter) => Ok(Box::new(
108+
iter.iter()
109+
.map(move |v| extract_items(Ok(v.to_object(py).into_ref(py)))),
110+
)),
111+
// Note that we iterate over (key, value), unlike doing iter({}) in Python
112+
GenericIterable::JsonObject(iter) => Ok(Box::new(iter.iter().map(move |(k, v)| {
113+
let k = PyString::new(py, k).as_ref();
114+
let v = v.to_object(py).into_ref(py);
115+
Ok((k, v))
116+
}))),
117+
}
118+
}
119+
}

src/input/input_abstract.rs

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use pyo3::{intern, prelude::*};
66
use crate::errors::{InputValue, LocItem, ValResult};
77
use crate::{PyMultiHostUrl, PyUrl};
88

9-
use super::any_iterable::AnyIterable;
9+
use super::any_iterable::GenericIterable;
1010
use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
1111
use super::return_enums::{EitherBytes, EitherString};
1212
use super::{GenericArguments, GenericCollection, GenericIterator, GenericMapping, JsonInput};
@@ -167,21 +167,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
167167
self.validate_dict(strict)
168168
}
169169

170-
fn validate_list(&'a self, strict: bool, allow_any_iter: bool) -> ValResult<GenericCollection<'a>> {
171-
if strict && !allow_any_iter {
172-
self.strict_list()
173-
} else {
174-
self.lax_list(allow_any_iter)
175-
}
176-
}
177-
178-
fn strict_list(&'a self) -> ValResult<GenericCollection<'a>>;
179-
#[cfg_attr(has_no_coverage, no_coverage)]
180-
fn lax_list(&'a self, _allow_any_iter: bool) -> ValResult<GenericCollection<'a>> {
181-
self.strict_list()
182-
}
183-
184-
fn extract_iterable(&'a self) -> ValResult<AnyIterable<'a>>;
170+
fn extract_iterable(&'a self) -> ValResult<GenericIterable<'a>>;
185171

186172
fn validate_tuple(&'a self, strict: bool) -> ValResult<GenericCollection<'a>> {
187173
if strict {

src/input/input_json.rs

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,6 @@ impl<'a> Input<'a> for JsonInput {
187187
self.validate_dict(false)
188188
}
189189

190-
fn validate_list(&'a self, _strict: bool, _allow_any_iter: bool) -> ValResult<GenericCollection<'a>> {
191-
match self {
192-
JsonInput::Array(a) => Ok(a.into()),
193-
_ => Err(ValError::new(ErrorType::ListType, self)),
194-
}
195-
}
196-
#[cfg_attr(has_no_coverage, no_coverage)]
197-
fn strict_list(&'a self) -> ValResult<GenericCollection<'a>> {
198-
self.validate_list(false, false)
199-
}
200-
201190
fn validate_tuple(&'a self, _strict: bool) -> ValResult<GenericCollection<'a>> {
202191
// just as in set's case, List has to be allowed
203192
match self {
@@ -234,10 +223,10 @@ impl<'a> Input<'a> for JsonInput {
234223
self.validate_frozenset(false)
235224
}
236225

237-
fn extract_iterable(&'a self) -> ValResult<super::any_iterable::AnyIterable<'a>> {
226+
fn extract_iterable(&'a self) -> ValResult<super::any_iterable::GenericIterable<'a>> {
238227
match self {
239-
JsonInput::Array(a) => Ok(super::any_iterable::AnyIterable::JsonArray(a)),
240-
JsonInput::Object(o) => Ok(super::any_iterable::AnyIterable::JsonObject(o)),
228+
JsonInput::Array(a) => Ok(super::any_iterable::GenericIterable::JsonArray(a)),
229+
JsonInput::Object(o) => Ok(super::any_iterable::GenericIterable::JsonObject(o)),
241230
_ => Err(ValError::new(ErrorType::IterableType, self)),
242231
}
243232
}
@@ -412,16 +401,7 @@ impl<'a> Input<'a> for String {
412401
self.validate_dict(false)
413402
}
414403

415-
#[cfg_attr(has_no_coverage, no_coverage)]
416-
fn validate_list(&'a self, _strict: bool, _allow_any_iter: bool) -> ValResult<GenericCollection<'a>> {
417-
Err(ValError::new(ErrorType::ListType, self))
418-
}
419-
#[cfg_attr(has_no_coverage, no_coverage)]
420-
fn strict_list(&'a self) -> ValResult<GenericCollection<'a>> {
421-
self.validate_list(false, false)
422-
}
423-
424-
fn extract_iterable(&'a self) -> ValResult<super::any_iterable::AnyIterable<'a>> {
404+
fn extract_iterable(&'a self) -> ValResult<super::any_iterable::GenericIterable<'a>> {
425405
Err(ValError::new(ErrorType::IterableType, self))
426406
}
427407

src/input/input_python.rs

Lines changed: 21 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use pyo3::types::{
1111
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
1212
use pyo3::{ffi, intern, AsPyPointer, PyTypeInfo};
1313

14-
use super::any_iterable::AnyIterable;
14+
use super::any_iterable::GenericIterable;
1515
use crate::build_tools::safe_repr;
1616
use crate::errors::{ErrorType, InputValue, LocItem, ValError, ValResult};
1717
use crate::{ArgsKwargs, PyMultiHostUrl, PyUrl};
@@ -437,30 +437,6 @@ impl<'a> Input<'a> for PyAny {
437437
}
438438
}
439439

440-
fn strict_list(&'a self) -> ValResult<GenericCollection<'a>> {
441-
if let Ok(list) = self.downcast::<PyList>() {
442-
Ok(list.into())
443-
} else {
444-
Err(ValError::new(ErrorType::ListType, self))
445-
}
446-
}
447-
448-
fn lax_list(&'a self, allow_any_iter: bool) -> ValResult<GenericCollection<'a>> {
449-
if let Ok(list) = self.downcast::<PyList>() {
450-
Ok(list.into())
451-
} else if let Ok(tuple) = self.downcast::<PyTuple>() {
452-
Ok(tuple.into())
453-
} else if let Some(collection) = extract_dict_iter!(self) {
454-
Ok(collection)
455-
} else if allow_any_iter && self.iter().is_ok() {
456-
Ok(self.into())
457-
} else if let Some(collection) = extract_shared_iter!(PyList, self) {
458-
Ok(collection)
459-
} else {
460-
Err(ValError::new(ErrorType::ListType, self))
461-
}
462-
}
463-
464440
fn strict_tuple(&'a self) -> ValResult<GenericCollection<'a>> {
465441
if let Ok(tuple) = self.downcast::<PyTuple>() {
466442
Ok(tuple.into())
@@ -535,37 +511,36 @@ impl<'a> Input<'a> for PyAny {
535511
}
536512
}
537513

538-
fn extract_iterable(&'a self) -> ValResult<super::any_iterable::AnyIterable<'a>> {
514+
fn extract_iterable(&'a self) -> ValResult<super::any_iterable::GenericIterable<'a>> {
539515
// Handle concrete non-overlapping types first, then abstract types
540516
if let Ok(iterable) = self.downcast::<PyList>() {
541-
Ok(AnyIterable::List(iterable))
517+
Ok(GenericIterable::List(iterable))
542518
} else if let Ok(iterable) = self.downcast::<PyTuple>() {
543-
Ok(AnyIterable::Tuple(iterable))
519+
Ok(GenericIterable::Tuple(iterable))
544520
} else if let Ok(iterable) = self.downcast::<PySet>() {
545-
Ok(AnyIterable::Set(iterable))
521+
Ok(GenericIterable::Set(iterable))
546522
} else if let Ok(iterable) = self.downcast::<PyFrozenSet>() {
547-
Ok(AnyIterable::FrozenSet(iterable))
523+
Ok(GenericIterable::FrozenSet(iterable))
548524
} else if let Ok(iterable) = self.downcast::<PyDict>() {
549-
Ok(AnyIterable::Dict(iterable))
525+
Ok(GenericIterable::Dict(iterable))
550526
} else if let Some(iterable) = extract_dict_keys!(self.py(), self) {
551-
Ok(AnyIterable::DictKeys(iterable))
527+
Ok(GenericIterable::DictKeys(iterable))
552528
} else if let Some(iterable) = extract_dict_values!(self.py(), self) {
553-
Ok(AnyIterable::DictValues(iterable))
529+
Ok(GenericIterable::DictValues(iterable))
554530
} else if let Some(iterable) = extract_dict_items!(self.py(), self) {
555-
Ok(AnyIterable::DictItems(iterable))
531+
Ok(GenericIterable::DictItems(iterable))
556532
} else if let Ok(iterable) = self.downcast::<PyMapping>() {
557-
Ok(AnyIterable::Mapping(iterable))
558-
} else if let (Ok(iterable), Err(_), Err(_)) = (
559-
self.downcast::<PySequence>(),
560-
// Explicitly disallow strings and bytes since they are sequences
561-
// but you almost never want to treat it as one
562-
// This can be worked around by allowing arbitrary iterables
563-
self.downcast::<PyString>(),
564-
self.downcast::<PyBytes>(),
565-
) {
566-
Ok(AnyIterable::Sequence(iterable))
567-
} else if let Ok(iterable) = PyIterator::from_object(self.py(), self) {
568-
Ok(AnyIterable::Iterator(iterable))
533+
Ok(GenericIterable::Mapping(iterable))
534+
} else if let Ok(iterable) = self.downcast::<PyString>() {
535+
Ok(GenericIterable::String(iterable))
536+
} else if let Ok(iterable) = self.downcast::<PyBytes>() {
537+
Ok(GenericIterable::Bytes(iterable))
538+
} else if let Ok(iterable) = self.downcast::<PyByteArray>() {
539+
Ok(GenericIterable::PyByteArray(iterable))
540+
} else if let Ok(iterable) = self.downcast::<PySequence>() {
541+
Ok(GenericIterable::Sequence(iterable))
542+
} else if let Ok(iterable) = self.iter() {
543+
Ok(GenericIterable::Iterator(iterable))
569544
} else {
570545
Err(ValError::new(ErrorType::IterableType, self))
571546
}

0 commit comments

Comments
 (0)