Skip to content

Commit 60ed51f

Browse files
authored
Use PyInt for inequality and multiple of checks (#634)
1 parent 6220455 commit 60ed51f

File tree

15 files changed

+361
-87
lines changed

15 files changed

+361
-87
lines changed

Cargo.lock

Lines changed: 32 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ url = "2.3.1"
4040
# idna is already required by url, added here to be explicit
4141
idna = "0.3.0"
4242
base64 = "0.13.1"
43+
num-bigint = "0.4.3"
4344

4445
[lib]
4546
name = "_pydantic_core"
@@ -50,7 +51,7 @@ crate-type = ["cdylib", "rlib"]
5051
extension-module = ["pyo3/extension-module"]
5152
# required for cargo bench
5253
auto-initialize = ["pyo3/auto-initialize"]
53-
default = ["mimalloc", "mimalloc/local_dynamic_tls", "pyo3/generate-import-lib"]
54+
default = ["mimalloc", "mimalloc/local_dynamic_tls", "pyo3/generate-import-lib", "pyo3/num-bigint"]
5455

5556
[profile.release]
5657
lto = "fat"

pydantic_core/core_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3772,8 +3772,8 @@ def definition_reference_schema(
37723772
'bool_parsing',
37733773
'int_type',
37743774
'int_parsing',
3775+
'int_parsing_size',
37753776
'int_from_float',
3776-
'int_overflow',
37773777
'float_type',
37783778
'float_parsing',
37793779
'bytes_type',

src/errors/types.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::Cow;
22
use std::fmt;
33

44
use ahash::AHashMap;
5+
use num_bigint::BigInt;
56
use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError};
67
use pyo3::once_cell::GILOnceCell;
78
use pyo3::prelude::*;
@@ -174,8 +175,8 @@ pub enum ErrorType {
174175
// int errors
175176
IntType,
176177
IntParsing,
178+
IntParsingSize,
177179
IntFromFloat,
178-
IntOverflow,
179180
// ---------------------
180181
// float errors
181182
FloatType,
@@ -489,7 +490,7 @@ impl ErrorType {
489490
Self::IntType => "Input should be a valid integer",
490491
Self::IntParsing => "Input should be a valid integer, unable to parse string as an integer",
491492
Self::IntFromFloat => "Input should be a valid integer, got a number with a fractional part",
492-
Self::IntOverflow => "Input integer too large to convert to 64-bit integer",
493+
Self::IntParsingSize => "Unable to parse input string as an integer, exceed maximum size",
493494
Self::FloatType => "Input should be a valid number",
494495
Self::FloatParsing => "Input should be a valid number, unable to parse string as an number",
495496
Self::BytesType => "Input should be a valid bytes",
@@ -699,6 +700,7 @@ impl ErrorType {
699700
#[derive(Clone, Debug)]
700701
pub enum Number {
701702
Int(i64),
703+
BigInt(BigInt),
702704
Float(f64),
703705
String(String),
704706
}
@@ -715,6 +717,12 @@ impl From<i64> for Number {
715717
}
716718
}
717719

720+
impl From<BigInt> for Number {
721+
fn from(i: BigInt) -> Self {
722+
Self::BigInt(i)
723+
}
724+
}
725+
718726
impl From<f64> for Number {
719727
fn from(f: f64) -> Self {
720728
Self::Float(f)
@@ -746,6 +754,7 @@ impl fmt::Display for Number {
746754
match self {
747755
Self::Float(s) => write!(f, "{s}"),
748756
Self::Int(i) => write!(f, "{i}"),
757+
Self::BigInt(i) => write!(f, "{i}"),
749758
Self::String(s) => write!(f, "{s}"),
750759
}
751760
}
@@ -754,6 +763,7 @@ impl ToPyObject for Number {
754763
fn to_object(&self, py: Python<'_>) -> PyObject {
755764
match self {
756765
Self::Int(i) => i.into_py(py),
766+
Self::BigInt(i) => i.clone().into_py(py),
757767
Self::Float(f) => f.into_py(py),
758768
Self::String(s) => s.into_py(py),
759769
}

src/input/input_json.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ impl<'a> Input<'a> for JsonInput {
107107
JsonInput::String(s) => str_as_bool(self, s),
108108
JsonInput::Int(int) => int_as_bool(self, *int),
109109
JsonInput::Float(float) => match float_as_int(self, *float) {
110-
Ok(int) => int_as_bool(self, int),
110+
Ok(int) => int.as_bool().ok_or_else(|| ValError::new(ErrorType::BoolParsing, self)),
111111
_ => Err(ValError::new(ErrorType::BoolType, self)),
112112
},
113113
_ => Err(ValError::new(ErrorType::BoolType, self)),
@@ -116,22 +116,23 @@ impl<'a> Input<'a> for JsonInput {
116116

117117
fn strict_int(&'a self) -> ValResult<EitherInt<'a>> {
118118
match self {
119-
JsonInput::Int(i) => Ok(EitherInt::Rust(*i)),
119+
JsonInput::Int(i) => Ok(EitherInt::I64(*i)),
120+
JsonInput::Uint(u) => Ok(EitherInt::U64(*u)),
120121
_ => Err(ValError::new(ErrorType::IntType, self)),
121122
}
122123
}
123124
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
124-
let int_result = match self {
125+
match self {
125126
JsonInput::Bool(b) => match *b {
126-
true => Ok(1),
127-
false => Ok(0),
127+
true => Ok(EitherInt::I64(1)),
128+
false => Ok(EitherInt::I64(0)),
128129
},
129-
JsonInput::Int(i) => Ok(*i),
130+
JsonInput::Int(i) => Ok(EitherInt::I64(*i)),
131+
JsonInput::Uint(u) => Ok(EitherInt::U64(*u)),
130132
JsonInput::Float(f) => float_as_int(self, *f),
131133
JsonInput::String(str) => str_as_int(self, str),
132134
_ => Err(ValError::new(ErrorType::IntType, self)),
133-
};
134-
int_result.map(EitherInt::Rust)
135+
}
135136
}
136137

137138
fn ultra_strict_float(&self) -> ValResult<f64> {
@@ -144,6 +145,7 @@ impl<'a> Input<'a> for JsonInput {
144145
match self {
145146
JsonInput::Float(f) => Ok(*f),
146147
JsonInput::Int(i) => Ok(*i as f64),
148+
JsonInput::Uint(u) => Ok(*u as f64),
147149
_ => Err(ValError::new(ErrorType::FloatType, self)),
148150
}
149151
}
@@ -155,6 +157,7 @@ impl<'a> Input<'a> for JsonInput {
155157
},
156158
JsonInput::Float(f) => Ok(*f),
157159
JsonInput::Int(i) => Ok(*i as f64),
160+
JsonInput::Uint(u) => Ok(*u as f64),
158161
JsonInput::String(str) => match str.parse::<f64>() {
159162
Ok(i) => Ok(i),
160163
Err(_) => Err(ValError::new(ErrorType::FloatParsing, self)),
@@ -363,7 +366,7 @@ impl<'a> Input<'a> for String {
363366
}
364367
fn lax_int(&'a self) -> ValResult<EitherInt<'a>> {
365368
match self.parse() {
366-
Ok(i) => Ok(EitherInt::Rust(i)),
369+
Ok(i) => Ok(EitherInt::I64(i)),
367370
Err(_) => Err(ValError::new(ErrorType::IntParsing, self)),
368371
}
369372
}

src/input/input_python.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ impl<'a> Input<'a> for PyAny {
260260
int_as_bool(self, int)
261261
} else if let Ok(float) = self.extract::<f64>() {
262262
match float_as_int(self, float) {
263-
Ok(int) => int_as_bool(self, int),
263+
Ok(int) => int.as_bool().ok_or_else(|| ValError::new(ErrorType::BoolParsing, self)),
264264
_ => Err(ValError::new(ErrorType::BoolType, self)),
265265
}
266266
} else {
@@ -287,10 +287,9 @@ impl<'a> Input<'a> for PyAny {
287287
if PyInt::is_exact_type_of(self) {
288288
Ok(EitherInt::Py(self))
289289
} else if let Some(cow_str) = maybe_as_string(self, ErrorType::IntParsing)? {
290-
let int = str_as_int(self, &cow_str)?;
291-
Ok(EitherInt::Rust(int))
290+
str_as_int(self, &cow_str)
292291
} else if let Ok(float) = self.extract::<f64>() {
293-
Ok(EitherInt::Rust(float_as_int(self, float)?))
292+
float_as_int(self, float)
294293
} else {
295294
Err(ValError::new(ErrorType::IntType, self))
296295
}

src/input/parse_json.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pub enum JsonInput {
1212
Null,
1313
Bool(bool),
1414
Int(i64),
15+
Uint(u64),
1516
Float(f64),
1617
String(String),
1718
Array(JsonArray),
@@ -26,6 +27,7 @@ impl ToPyObject for JsonInput {
2627
Self::Null => py.None(),
2728
Self::Bool(b) => b.into_py(py),
2829
Self::Int(i) => i.into_py(py),
30+
Self::Uint(i) => i.into_py(py),
2931
Self::Float(f) => f.into_py(py),
3032
Self::String(s) => s.into_py(py),
3133
Self::Array(v) => PyList::new(py, v.iter().map(|v| v.to_object(py))).into_py(py),
@@ -64,7 +66,10 @@ impl<'de> Deserialize<'de> for JsonInput {
6466
}
6567

6668
fn visit_u64<E>(self, value: u64) -> Result<JsonInput, E> {
67-
Ok(JsonInput::Int(value as i64))
69+
match i64::try_from(value) {
70+
Ok(i) => Ok(JsonInput::Int(i)),
71+
Err(_) => Ok(JsonInput::Uint(value)),
72+
}
6873
}
6974

7075
fn visit_f64<E>(self, value: f64) -> Result<JsonInput, E> {

src/input/return_enums.rs

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::borrow::Cow;
22
use std::slice::Iter as SliceIter;
33

4+
use num_bigint::BigInt;
5+
46
use pyo3::prelude::*;
57
use pyo3::types::iter::PyDictIterator;
68
use pyo3::types::{
@@ -820,25 +822,72 @@ impl<'a> IntoPy<PyObject> for EitherBytes<'a> {
820822

821823
#[cfg_attr(debug_assertions, derive(Debug))]
822824
pub enum EitherInt<'a> {
823-
Rust(i64),
825+
I64(i64),
826+
U64(u64),
827+
BigInt(BigInt),
824828
Py(&'a PyAny),
825829
}
826830

827-
impl<'a> TryInto<i64> for EitherInt<'a> {
828-
type Error = ValError<'a>;
831+
impl<'a> EitherInt<'a> {
832+
pub fn into_i64(self, py: Python<'a>) -> ValResult<'a, i64> {
833+
match self {
834+
EitherInt::I64(i) => Ok(i),
835+
EitherInt::U64(u) => match i64::try_from(u) {
836+
Ok(u) => Ok(u),
837+
Err(_) => Err(ValError::new(ErrorType::IntParsingSize, u.into_py(py).into_ref(py))),
838+
},
839+
EitherInt::BigInt(u) => match i64::try_from(u) {
840+
Ok(u) => Ok(u),
841+
Err(e) => Err(ValError::new(
842+
ErrorType::IntParsingSize,
843+
e.into_original().into_py(py).into_ref(py),
844+
)),
845+
},
846+
EitherInt::Py(i) => i.extract().map_err(|_| ValError::new(ErrorType::IntParsingSize, i)),
847+
}
848+
}
849+
850+
pub fn as_bool(&self) -> Option<bool> {
851+
match self {
852+
EitherInt::I64(i) => match i {
853+
0 => Some(false),
854+
1 => Some(true),
855+
_ => None,
856+
},
857+
EitherInt::U64(u) => match u {
858+
0 => Some(false),
859+
1 => Some(true),
860+
_ => None,
861+
},
862+
EitherInt::BigInt(i) => match u8::try_from(i) {
863+
Ok(0) => Some(false),
864+
Ok(1) => Some(true),
865+
_ => None,
866+
},
867+
EitherInt::Py(i) => match i.extract::<u8>() {
868+
Ok(0) => Some(false),
869+
Ok(1) => Some(true),
870+
_ => None,
871+
},
872+
}
873+
}
829874

830-
fn try_into(self) -> ValResult<'a, i64> {
875+
pub fn as_bigint(&self) -> PyResult<BigInt> {
831876
match self {
832-
EitherInt::Rust(i) => Ok(i),
833-
EitherInt::Py(i) => i.extract().map_err(|_| ValError::new(ErrorType::IntOverflow, i)),
877+
EitherInt::I64(i) => Ok(BigInt::from(*i)),
878+
EitherInt::U64(u) => Ok(BigInt::from(*u)),
879+
EitherInt::BigInt(i) => Ok(i.clone()),
880+
EitherInt::Py(i) => i.extract(),
834881
}
835882
}
836883
}
837884

838885
impl<'a> IntoPy<PyObject> for EitherInt<'a> {
839886
fn into_py(self, py: Python<'_>) -> PyObject {
840887
match self {
841-
Self::Rust(int) => int.into_py(py),
888+
Self::I64(int) => int.into_py(py),
889+
Self::U64(int) => int.into_py(py),
890+
Self::BigInt(int) => int.into_py(py),
842891
Self::Py(int) => int.into_py(py),
843892
}
844893
}

0 commit comments

Comments
 (0)