Skip to content

Commit 08a99b5

Browse files
Introduce exactness into Decimal validation logic (#1405)
1 parent fdd1e85 commit 08a99b5

File tree

6 files changed

+57
-51
lines changed

6 files changed

+57
-51
lines changed

src/input/input_abstract.rs

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,18 +98,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
9898

9999
fn validate_float(&self, strict: bool) -> ValMatch<EitherFloat<'_>>;
100100

101-
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
102-
if strict {
103-
self.strict_decimal(py)
104-
} else {
105-
self.lax_decimal(py)
106-
}
107-
}
108-
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>>;
109-
#[cfg_attr(has_coverage_attribute, coverage(off))]
110-
fn lax_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
111-
self.strict_decimal(py)
112-
}
101+
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;
113102

114103
type Dict<'a>: ValidatedDict<'py>
115104
where

src/input/input_json.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,13 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
167167
}
168168
}
169169

170-
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
170+
fn validate_decimal(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
171171
match self {
172-
JsonValue::Float(f) => create_decimal(&PyString::new_bound(py, &f.to_string()), self),
173-
172+
JsonValue::Float(f) => {
173+
create_decimal(&PyString::new_bound(py, &f.to_string()), self).map(ValidationMatch::strict)
174+
}
174175
JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => {
175-
create_decimal(self.to_object(py).bind(py), self)
176+
create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::strict)
176177
}
177178
_ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
178179
}
@@ -399,8 +400,8 @@ impl<'py> Input<'py> for str {
399400
str_as_float(self, self).map(ValidationMatch::lax)
400401
}
401402

402-
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
403-
create_decimal(self.to_object(py).bind(py), self)
403+
fn validate_decimal(&self, _strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
404+
create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::lax)
404405
}
405406

406407
type Dict<'a> = Never;

src/input/input_python.rs

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
251251
str_as_int(self, s)
252252
} else if self.is_exact_instance_of::<PyFloat>() {
253253
float_as_int(self, self.extract::<f64>()?)
254-
} else if let Ok(decimal) = self.strict_decimal(self.py()) {
255-
decimal_as_int(self, &decimal)
254+
} else if let Ok(decimal) = self.validate_decimal(true, self.py()) {
255+
decimal_as_int(self, &decimal.into_inner())
256256
} else if let Ok(float) = self.extract::<f64>() {
257257
float_as_int(self, float)
258258
} else if let Some(enum_val) = maybe_as_enum(self) {
@@ -310,48 +310,44 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
310310
Err(ValError::new(ErrorTypeDefaults::FloatType, self))
311311
}
312312

313-
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
313+
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
314314
let decimal_type = get_decimal_type(py);
315+
315316
// Fast path for existing decimal objects
316317
if self.is_exact_instance(decimal_type) {
317-
return Ok(self.to_owned());
318+
return Ok(ValidationMatch::exact(self.to_owned().clone()));
319+
}
320+
321+
if !strict {
322+
if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>())
323+
{
324+
// Checking isinstance for str / int / bool is fast compared to decimal / float
325+
return create_decimal(self, self).map(ValidationMatch::lax);
326+
}
327+
328+
if self.is_instance_of::<PyFloat>() {
329+
return create_decimal(self.str()?.as_any(), self).map(ValidationMatch::lax);
330+
}
318331
}
319332

320-
// Try subclasses of decimals, they will be upcast to Decimal
321333
if self.is_instance(decimal_type)? {
322-
return create_decimal(self, self);
334+
// Upcast subclasses to decimal
335+
return create_decimal(self, self).map(ValidationMatch::strict);
323336
}
324337

325-
Err(ValError::new(
338+
let error_type = if strict {
326339
ErrorType::IsInstanceOf {
327340
class: decimal_type
328341
.qualname()
329342
.and_then(|name| name.extract())
330343
.unwrap_or_else(|_| "Decimal".to_owned()),
331344
context: None,
332-
},
333-
self,
334-
))
335-
}
336-
337-
fn lax_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
338-
let decimal_type = get_decimal_type(py);
339-
// Fast path for existing decimal objects
340-
if self.is_exact_instance(decimal_type) {
341-
return Ok(self.to_owned().clone());
342-
}
343-
344-
if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>()) {
345-
// checking isinstance for str / int / bool is fast compared to decimal / float
346-
create_decimal(self, self)
347-
} else if self.is_instance(decimal_type)? {
348-
// upcast subclasses to decimal
349-
return create_decimal(self, self);
350-
} else if self.is_instance_of::<PyFloat>() {
351-
create_decimal(self.str()?.as_any(), self)
345+
}
352346
} else {
353-
Err(ValError::new(ErrorTypeDefaults::DecimalType, self))
354-
}
347+
ErrorTypeDefaults::DecimalType
348+
};
349+
350+
Err(ValError::new(error_type, self))
355351
}
356352

357353
type Dict<'a> = GenericPyMapping<'a, 'py> where Self: 'a;

src/input/input_string.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,9 @@ impl<'py> Input<'py> for StringMapping<'py> {
143143
}
144144
}
145145

146-
fn strict_decimal(&self, _py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
146+
fn validate_decimal(&self, _strict: bool, _py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
147147
match self {
148-
Self::String(s) => create_decimal(s, self),
148+
Self::String(s) => create_decimal(s, self).map(ValidationMatch::strict),
149149
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
150150
}
151151
}

src/validators/decimal.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl Validator for DecimalValidator {
122122
input: &(impl Input<'py> + ?Sized),
123123
state: &mut ValidationState<'_, 'py>,
124124
) -> ValResult<PyObject> {
125-
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?;
125+
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?.unpack(state);
126126

127127
if !self.allow_inf_nan || self.check_digits {
128128
if !decimal.call_method0(intern!(py, "is_finite"))?.extract()? {

tests/validators/test_decimal.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from dirty_equals import FunctionCheck, IsStr
1111

12-
from pydantic_core import SchemaValidator, ValidationError
12+
from pydantic_core import SchemaValidator, ValidationError, core_schema
1313

1414
from ..conftest import Err, PyAndJson, plain_repr
1515

@@ -467,3 +467,23 @@ def test_validate_max_digits_and_decimal_places_edge_case() -> None:
467467
assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal(
468468
'9999999999999999.999999999999999999'
469469
)
470+
471+
472+
def test_str_validation_w_strict() -> None:
473+
s = SchemaValidator(core_schema.decimal_schema(strict=True))
474+
475+
with pytest.raises(ValidationError):
476+
assert s.validate_python('1.23')
477+
478+
479+
def test_str_validation_w_lax() -> None:
480+
s = SchemaValidator(core_schema.decimal_schema(strict=False))
481+
482+
assert s.validate_python('1.23') == Decimal('1.23')
483+
484+
485+
def test_union_with_str_prefers_str() -> None:
486+
s = SchemaValidator(core_schema.union_schema([core_schema.decimal_schema(), core_schema.str_schema()]))
487+
488+
assert s.validate_python('1.23') == '1.23'
489+
assert s.validate_python(1.23) == Decimal('1.23')

0 commit comments

Comments
 (0)