Skip to content

Commit 8d12f96

Browse files
authored
Int validation performance improvement (#620)
1 parent 29f6895 commit 8d12f96

File tree

13 files changed

+155
-106
lines changed

13 files changed

+155
-106
lines changed

pydantic_core/core_schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3761,6 +3761,7 @@ def definition_reference_schema(
37613761
'int_type',
37623762
'int_parsing',
37633763
'int_from_float',
3764+
'int_overflow',
37643765
'float_type',
37653766
'float_parsing',
37663767
'bytes_type',

src/errors/types.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ pub enum ErrorType {
175175
IntType,
176176
IntParsing,
177177
IntFromFloat,
178+
IntOverflow,
178179
// ---------------------
179180
// float errors
180181
FloatType,
@@ -488,6 +489,7 @@ impl ErrorType {
488489
Self::IntType => "Input should be a valid integer",
489490
Self::IntParsing => "Input should be a valid integer, unable to parse string as an integer",
490491
Self::IntFromFloat => "Input should be a valid integer, got a number with a fractional part",
492+
Self::IntOverflow => "Input integer too large to convert to 64-bit integer",
491493
Self::FloatType => "Input should be a valid number",
492494
Self::FloatParsing => "Input should be a valid number, unable to parse string as an number",
493495
Self::BytesType => "Input should be a valid bytes",

src/input/input_abstract.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::errors::{InputValue, LocItem, ValResult};
77
use crate::{PyMultiHostUrl, PyUrl};
88

99
use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
10-
use super::return_enums::{EitherBytes, EitherString};
10+
use super::return_enums::{EitherBytes, EitherInt, EitherString};
1111
use super::{GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonInput};
1212

1313
#[derive(Debug, Clone, Copy)]
@@ -90,8 +90,6 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
9090
self.strict_str()
9191
}
9292

93-
fn as_str_strict(&self) -> Option<&str>;
94-
9593
fn validate_bytes(&'a self, strict: bool) -> ValResult<EitherBytes<'a>> {
9694
if strict {
9795
self.strict_bytes()
@@ -118,21 +116,19 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
118116
self.strict_bool()
119117
}
120118

121-
fn validate_int(&self, strict: bool) -> ValResult<i64> {
119+
fn validate_int(&'a self, strict: bool) -> ValResult<EitherInt<'a>> {
122120
if strict {
123121
self.strict_int()
124122
} else {
125123
self.lax_int()
126124
}
127125
}
128-
fn strict_int(&self) -> ValResult<i64>;
126+
fn strict_int(&'a self) -> ValResult<EitherInt<'a>>;
129127
#[cfg_attr(has_no_coverage, no_coverage)]
130-
fn lax_int(&self) -> ValResult<i64> {
128+
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
131129
self.strict_int()
132130
}
133131

134-
fn as_int_strict(&self) -> Option<i64>;
135-
136132
fn validate_float(&self, strict: bool, ultra_strict: bool) -> ValResult<f64> {
137133
if ultra_strict {
138134
self.ultra_strict_float()

src/input/input_json.rs

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use super::datetime::{
1010
use super::parse_json::JsonArray;
1111
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_int};
1212
use super::{
13-
EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator, GenericMapping,
14-
Input, JsonArgs, JsonInput,
13+
EitherBytes, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator,
14+
GenericMapping, Input, JsonArgs, JsonInput,
1515
};
1616

1717
impl<'a> Input<'a> for JsonInput {
@@ -84,13 +84,6 @@ impl<'a> Input<'a> for JsonInput {
8484
}
8585
}
8686

87-
fn as_str_strict(&self) -> Option<&str> {
88-
match self {
89-
JsonInput::String(s) => Some(s.as_str()),
90-
_ => None,
91-
}
92-
}
93-
9487
fn validate_bytes(&'a self, _strict: bool) -> ValResult<EitherBytes<'a>> {
9588
match self {
9689
JsonInput::String(s) => Ok(s.as_bytes().into()),
@@ -121,14 +114,14 @@ impl<'a> Input<'a> for JsonInput {
121114
}
122115
}
123116

124-
fn strict_int(&self) -> ValResult<i64> {
117+
fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
125118
match self {
126-
JsonInput::Int(i) => Ok(*i),
119+
JsonInput::Int(i) => Ok(EitherInt::Rust(*i)),
127120
_ => Err(ValError::new(ErrorType::IntType, self)),
128121
}
129122
}
130-
fn lax_int(&self) -> ValResult<i64> {
131-
match self {
123+
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
124+
let int_result = match self {
132125
JsonInput::Bool(b) => match *b {
133126
true => Ok(1),
134127
false => Ok(0),
@@ -137,14 +130,8 @@ impl<'a> Input<'a> for JsonInput {
137130
JsonInput::Float(f) => float_as_int(self, *f),
138131
JsonInput::String(str) => str_as_int(self, str),
139132
_ => Err(ValError::new(ErrorType::IntType, self)),
140-
}
141-
}
142-
143-
fn as_int_strict(&self) -> Option<i64> {
144-
match self {
145-
JsonInput::Int(i) => Some(*i),
146-
_ => None,
147-
}
133+
};
134+
int_result.map(EitherInt::Rust)
148135
}
149136

150137
fn ultra_strict_float(&self) -> ValResult<f64> {
@@ -356,10 +343,6 @@ impl<'a> Input<'a> for String {
356343
self.validate_str(false)
357344
}
358345

359-
fn as_str_strict(&self) -> Option<&str> {
360-
Some(self.as_str())
361-
}
362-
363346
fn validate_bytes(&'a self, _strict: bool) -> ValResult<EitherBytes<'a>> {
364347
Ok(self.as_bytes().into())
365348
}
@@ -375,20 +358,16 @@ impl<'a> Input<'a> for String {
375358
str_as_bool(self, self)
376359
}
377360

378-
fn strict_int(&self) -> ValResult<i64> {
361+
fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
379362
Err(ValError::new(ErrorType::IntType, self))
380363
}
381-
fn lax_int(&self) -> ValResult<i64> {
364+
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
382365
match self.parse() {
383-
Ok(i) => Ok(i),
366+
Ok(i) => Ok(EitherInt::Rust(i)),
384367
Err(_) => Err(ValError::new(ErrorType::IntParsing, self)),
385368
}
386369
}
387370

388-
fn as_int_strict(&self) -> Option<i64> {
389-
None
390-
}
391-
392371
#[cfg_attr(has_no_coverage, no_coverage)]
393372
fn ultra_strict_float(&self) -> ValResult<f64> {
394373
self.strict_float()

src/input/input_python.rs

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
use std::borrow::Cow;
22
use std::str::from_utf8;
33

4-
use pyo3::once_cell::GILOnceCell;
54
use pyo3::prelude::*;
65
use pyo3::types::{
76
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyInt, PyIterator, PyList,
@@ -22,8 +21,8 @@ use super::datetime::{
2221
};
2322
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_int};
2423
use super::{
25-
py_string_str, EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator,
26-
GenericMapping, Input, JsonInput, PyArgs,
24+
py_string_str, EitherBytes, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
25+
GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
2726
};
2827

2928
#[cfg(not(PyPy))]
@@ -227,14 +226,6 @@ impl<'a> Input<'a> for PyAny {
227226
}
228227
}
229228

230-
fn as_str_strict(&self) -> Option<&str> {
231-
if self.get_type().is(get_py_str_type(self.py())) {
232-
self.extract().ok()
233-
} else {
234-
None
235-
}
236-
}
237-
238229
fn strict_bytes(&'a self) -> ValResult<EitherBytes<'a>> {
239230
if let Ok(py_bytes) = self.downcast::<PyBytes>() {
240231
Ok(py_bytes.into())
@@ -281,37 +272,34 @@ impl<'a> Input<'a> for PyAny {
281272
}
282273
}
283274

284-
fn strict_int(&self) -> ValResult<i64> {
285-
// bool check has to come before int check as bools would be cast to ints below
286-
if self.extract::<bool>().is_ok() {
287-
Err(ValError::new(ErrorType::IntType, self))
288-
} else if let Ok(int) = self.extract::<i64>() {
289-
Ok(int)
275+
fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
276+
if PyInt::is_exact_type_of(self) {
277+
Ok(EitherInt::Py(self))
278+
} else if PyInt::is_type_of(self) {
279+
// bools are a subclass of int, so check for bool type in this specific case
280+
if PyBool::is_exact_type_of(self) {
281+
Err(ValError::new(ErrorType::IntType, self))
282+
} else {
283+
Ok(EitherInt::Py(self))
284+
}
290285
} else {
291286
Err(ValError::new(ErrorType::IntType, self))
292287
}
293288
}
294289

295-
fn lax_int(&self) -> ValResult<i64> {
296-
if let Ok(int) = self.extract::<i64>() {
297-
Ok(int)
290+
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
291+
if PyInt::is_exact_type_of(self) {
292+
Ok(EitherInt::Py(self))
298293
} else if let Some(cow_str) = maybe_as_string(self, ErrorType::IntParsing)? {
299-
str_as_int(self, &cow_str)
294+
let int = str_as_int(self, &cow_str)?;
295+
Ok(EitherInt::Rust(int))
300296
} else if let Ok(float) = self.extract::<f64>() {
301-
float_as_int(self, float)
297+
Ok(EitherInt::Rust(float_as_int(self, float)?))
302298
} else {
303299
Err(ValError::new(ErrorType::IntType, self))
304300
}
305301
}
306302

307-
fn as_int_strict(&self) -> Option<i64> {
308-
if self.get_type().is(get_py_int_type(self.py())) {
309-
self.extract().ok()
310-
} else {
311-
None
312-
}
313-
}
314-
315303
fn ultra_strict_float(&self) -> ValResult<f64> {
316304
if matches!(self.is_instance_of::<PyInt>(), Ok(true)) {
317305
Err(ValError::new(ErrorType::FloatType, self))
@@ -322,10 +310,13 @@ impl<'a> Input<'a> for PyAny {
322310
}
323311
}
324312
fn strict_float(&self) -> ValResult<f64> {
325-
if self.extract::<bool>().is_ok() {
326-
Err(ValError::new(ErrorType::FloatType, self))
327-
} else if let Ok(float) = self.extract::<f64>() {
328-
Ok(float)
313+
if let Ok(float) = self.extract::<f64>() {
314+
// bools are cast to floats as either 0.0 or 1.0, so check for bool type in this specific case
315+
if (float == 0.0 || float == 1.0) && PyBool::is_exact_type_of(self) {
316+
Err(ValError::new(ErrorType::FloatType, self))
317+
} else {
318+
Ok(float)
319+
}
329320
} else {
330321
Err(ValError::new(ErrorType::FloatType, self))
331322
}
@@ -515,7 +506,7 @@ impl<'a> Input<'a> for PyAny {
515506
}
516507

517508
fn strict_date(&self) -> ValResult<EitherDate> {
518-
if self.downcast::<PyDateTime>().is_ok() {
509+
if PyDateTime::is_type_of(self) {
519510
// have to check if it's a datetime first, otherwise the line below converts to a date
520511
Err(ValError::new(ErrorType::DateType, self))
521512
} else if let Ok(date) = self.downcast::<PyDate>() {
@@ -526,7 +517,7 @@ impl<'a> Input<'a> for PyAny {
526517
}
527518

528519
fn lax_date(&self) -> ValResult<EitherDate> {
529-
if self.downcast::<PyDateTime>().is_ok() {
520+
if PyDateTime::is_type_of(self) {
530521
// have to check if it's a datetime first, otherwise the line below converts to a date
531522
// even if we later try coercion from a datetime, we don't want to return a datetime now
532523
Err(ValError::new(ErrorType::DateType, self))
@@ -558,7 +549,7 @@ impl<'a> Input<'a> for PyAny {
558549
bytes_as_time(self, str.as_bytes())
559550
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
560551
bytes_as_time(self, py_bytes.as_bytes())
561-
} else if self.downcast::<PyBool>().is_ok() {
552+
} else if PyBool::is_exact_type_of(self) {
562553
Err(ValError::new(ErrorType::TimeType, self))
563554
} else if let Ok(int) = self.extract::<i64>() {
564555
int_as_time(self, int, 0)
@@ -585,7 +576,7 @@ impl<'a> Input<'a> for PyAny {
585576
bytes_as_datetime(self, str.as_bytes())
586577
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
587578
bytes_as_datetime(self, py_bytes.as_bytes())
588-
} else if self.downcast::<PyBool>().is_ok() {
579+
} else if PyBool::is_exact_type_of(self) {
589580
Err(ValError::new(ErrorType::DatetimeType, self))
590581
} else if let Ok(int) = self.extract::<i64>() {
591582
int_as_datetime(self, int, 0)
@@ -661,7 +652,7 @@ fn is_builtin_str(py_str: &PyString) -> bool {
661652
}
662653

663654
#[cfg(PyPy)]
664-
static DICT_KEYS_TYPE: GILOnceCell<Py<PyType>> = GILOnceCell::new();
655+
static DICT_KEYS_TYPE: pyo3::once_cell::GILOnceCell<Py<PyType>> = pyo3::once_cell::GILOnceCell::new();
665656

666657
#[cfg(PyPy)]
667658
fn is_dict_keys_type(v: &PyAny) -> bool {
@@ -679,7 +670,7 @@ fn is_dict_keys_type(v: &PyAny) -> bool {
679670
}
680671

681672
#[cfg(PyPy)]
682-
static DICT_VALUES_TYPE: GILOnceCell<Py<PyType>> = GILOnceCell::new();
673+
static DICT_VALUES_TYPE: pyo3::once_cell::GILOnceCell<Py<PyType>> = pyo3::once_cell::GILOnceCell::new();
683674

684675
#[cfg(PyPy)]
685676
fn is_dict_values_type(v: &PyAny) -> bool {
@@ -697,7 +688,7 @@ fn is_dict_values_type(v: &PyAny) -> bool {
697688
}
698689

699690
#[cfg(PyPy)]
700-
static DICT_ITEMS_TYPE: GILOnceCell<Py<PyType>> = GILOnceCell::new();
691+
static DICT_ITEMS_TYPE: pyo3::once_cell::GILOnceCell<Py<PyType>> = pyo3::once_cell::GILOnceCell::new();
701692

702693
#[cfg(PyPy)]
703694
fn is_dict_items_type(v: &PyAny) -> bool {
@@ -722,15 +713,3 @@ pub fn list_as_tuple(list: &PyList) -> &PyTuple {
722713
};
723714
py_tuple.into_ref(list.py())
724715
}
725-
726-
static PY_INT_TYPE: GILOnceCell<PyObject> = GILOnceCell::new();
727-
728-
fn get_py_int_type(py: Python) -> &PyObject {
729-
PY_INT_TYPE.get_or_init(py, || PyInt::type_object(py).into())
730-
}
731-
732-
static PY_STR_TYPE: GILOnceCell<PyObject> = GILOnceCell::new();
733-
734-
fn get_py_str_type(py: Python) -> &PyObject {
735-
PY_STR_TYPE.get_or_init(py, || PyString::type_object(py).into())
736-
}

src/input/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ pub(crate) use datetime::{
1717
pub(crate) use input_abstract::{Input, InputType};
1818
pub(crate) use parse_json::{JsonInput, JsonObject};
1919
pub(crate) use return_enums::{
20-
py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherString, GenericArguments,
21-
GenericIterable, GenericIterator, GenericMapping, JsonArgs, JsonObjectGenericIterator, MappingGenericIterator,
22-
PyArgs,
20+
py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherInt, EitherString,
21+
GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonArgs, JsonObjectGenericIterator,
22+
MappingGenericIterator, PyArgs,
2323
};
2424

2525
// Defined here as it's not exported by pyo3

src/input/return_enums.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,3 +817,29 @@ impl<'a> IntoPy<PyObject> for EitherBytes<'a> {
817817
}
818818
}
819819
}
820+
821+
#[cfg_attr(debug_assertions, derive(Debug))]
822+
pub enum EitherInt<'a> {
823+
Rust(i64),
824+
Py(&'a PyAny),
825+
}
826+
827+
impl<'a> TryInto<i64> for EitherInt<'a> {
828+
type Error = ValError<'a>;
829+
830+
fn try_into(self) -> ValResult<'a, i64> {
831+
match self {
832+
EitherInt::Rust(i) => Ok(i),
833+
EitherInt::Py(i) => i.extract().map_err(|_| ValError::new(ErrorType::IntOverflow, i)),
834+
}
835+
}
836+
}
837+
838+
impl<'a> IntoPy<PyObject> for EitherInt<'a> {
839+
fn into_py(self, py: Python<'_>) -> PyObject {
840+
match self {
841+
Self::Rust(int) => int.into_py(py),
842+
Self::Py(int) => int.into_py(py),
843+
}
844+
}
845+
}

0 commit comments

Comments
 (0)