Skip to content

Commit bd82468

Browse files
committed
using PyInt for int inequality checks
1 parent 6220455 commit bd82468

File tree

2 files changed

+69
-23
lines changed

2 files changed

+69
-23
lines changed

src/validators/int.rs

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use pyo3::intern;
22
use pyo3::prelude::*;
3-
use pyo3::types::PyDict;
3+
use pyo3::types::{PyDict, PyInt};
44

55
use crate::build_tools::is_strict;
66
use crate::errors::{ErrorType, ValError, ValResult};
@@ -72,11 +72,11 @@ impl Validator for IntValidator {
7272
#[derive(Debug, Clone)]
7373
pub struct ConstrainedIntValidator {
7474
strict: bool,
75-
multiple_of: Option<i64>,
76-
le: Option<i64>,
77-
lt: Option<i64>,
78-
ge: Option<i64>,
79-
gt: Option<i64>,
75+
multiple_of: Option<Py<PyInt>>,
76+
le: Option<Py<PyInt>>,
77+
lt: Option<Py<PyInt>>,
78+
ge: Option<Py<PyInt>>,
79+
gt: Option<Py<PyInt>>,
8080
}
8181

8282
impl Validator for ConstrainedIntValidator {
@@ -89,38 +89,62 @@ impl Validator for ConstrainedIntValidator {
8989
_recursion_guard: &'s mut RecursionGuard,
9090
) -> ValResult<'data, PyObject> {
9191
let either_int = input.validate_int(extra.strict.unwrap_or(self.strict))?;
92-
let int: i64 = either_int.try_into()?;
93-
if let Some(multiple_of) = self.multiple_of {
94-
if int % multiple_of != 0 {
92+
let int_obj = either_int.into_py(py);
93+
let int = int_obj.as_ref(py);
94+
95+
if let Some(ref multiple_of) = self.multiple_of {
96+
let rem: i64 = int.call_method1(intern!(py, "__mod__"), (multiple_of,))?.extract()?;
97+
if rem != 0 {
9598
return Err(ValError::new(
9699
ErrorType::MultipleOf {
97-
multiple_of: multiple_of.into(),
100+
multiple_of: multiple_of.extract::<i64>(py)?.into(),
98101
},
99102
input,
100103
));
101104
}
102105
}
103-
if let Some(le) = self.le {
104-
if int > le {
105-
return Err(ValError::new(ErrorType::LessThanEqual { le: le.into() }, input));
106+
107+
if let Some(ref le) = self.le {
108+
if !int.le(le)? {
109+
return Err(ValError::new(
110+
ErrorType::LessThanEqual {
111+
le: le.extract::<i64>(py)?.into(),
112+
},
113+
input,
114+
));
106115
}
107116
}
108-
if let Some(lt) = self.lt {
109-
if int >= lt {
110-
return Err(ValError::new(ErrorType::LessThan { lt: lt.into() }, input));
117+
if let Some(ref lt) = self.lt {
118+
if !int.lt(lt)? {
119+
return Err(ValError::new(
120+
ErrorType::LessThan {
121+
lt: lt.extract::<i64>(py)?.into(),
122+
},
123+
input,
124+
));
111125
}
112126
}
113-
if let Some(ge) = self.ge {
114-
if int < ge {
115-
return Err(ValError::new(ErrorType::GreaterThanEqual { ge: ge.into() }, input));
127+
if let Some(ref ge) = self.ge {
128+
if !int.ge(ge)? {
129+
return Err(ValError::new(
130+
ErrorType::GreaterThanEqual {
131+
ge: ge.extract::<i64>(py)?.into(),
132+
},
133+
input,
134+
));
116135
}
117136
}
118-
if let Some(gt) = self.gt {
119-
if int <= gt {
120-
return Err(ValError::new(ErrorType::GreaterThan { gt: gt.into() }, input));
137+
if let Some(ref gt) = self.gt {
138+
if !int.gt(gt)? {
139+
return Err(ValError::new(
140+
ErrorType::GreaterThan {
141+
gt: gt.extract::<i64>(py)?.into(),
142+
},
143+
input,
144+
));
121145
}
122146
}
123-
Ok(int.into_py(py))
147+
Ok(int_obj)
124148
}
125149

126150
fn different_strict_behavior(

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,3 +1388,25 @@ def test_strict_int(benchmark):
13881388
v = SchemaValidator(core_schema.int_schema(strict=True))
13891389

13901390
benchmark(v.validate_python, 42)
1391+
1392+
1393+
@pytest.mark.benchmark(group='int_range')
1394+
def test_int_range(benchmark):
1395+
v = SchemaValidator(core_schema.int_schema(gt=0, lt=100))
1396+
1397+
assert v.validate_python(42) == 42
1398+
with pytest.raises(ValidationError, match='Input should be greater than 0'):
1399+
v.validate_python(0)
1400+
1401+
benchmark(v.validate_python, 42)
1402+
1403+
1404+
@pytest.mark.benchmark(group='int_range')
1405+
def test_int_range_json(benchmark):
1406+
v = SchemaValidator(core_schema.int_schema(gt=0, lt=100))
1407+
1408+
assert v.validate_json('42') == 42
1409+
with pytest.raises(ValidationError, match='Input should be greater than 0'):
1410+
v.validate_python('0')
1411+
1412+
benchmark(v.validate_json, '42')

0 commit comments

Comments
 (0)