Skip to content

Commit 14df075

Browse files
committed
tests + fix
1 parent bb67044 commit 14df075

File tree

6 files changed

+48
-24
lines changed

6 files changed

+48
-24
lines changed

src/input/input_abstract.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,16 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
9898

9999
fn validate_float(&self, strict: bool) -> ValMatch<EitherFloat<'_>>;
100100

101-
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
101+
fn validate_decimal(&self, strict: bool, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
102102
if strict {
103103
self.strict_decimal(py)
104104
} else {
105105
self.lax_decimal(py)
106106
}
107107
}
108-
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>>;
108+
fn strict_decimal(&self, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>>;
109109
#[cfg_attr(has_coverage_attribute, coverage(off))]
110-
fn lax_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
110+
fn lax_decimal(&self, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
111111
self.strict_decimal(py)
112112
}
113113

src/input/input_json.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,12 +165,13 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
165165
}
166166
}
167167

168-
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
168+
fn strict_decimal(&self, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
169169
match self {
170-
JsonValue::Float(f) => create_decimal(&PyString::new_bound(py, &f.to_string()), self),
171-
170+
JsonValue::Float(f) => {
171+
create_decimal(&PyString::new_bound(py, &f.to_string()), self).map(ValidationMatch::strict)
172+
}
172173
JsonValue::Str(..) | JsonValue::Int(..) | JsonValue::BigInt(..) => {
173-
create_decimal(self.to_object(py).bind(py), self)
174+
create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::lax)
174175
}
175176
_ => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
176177
}
@@ -373,8 +374,8 @@ impl<'py> Input<'py> for str {
373374
str_as_float(self, self).map(ValidationMatch::lax)
374375
}
375376

376-
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
377-
create_decimal(self.to_object(py).bind(py), self)
377+
fn strict_decimal(&self, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
378+
create_decimal(self.to_object(py).bind(py), self).map(ValidationMatch::strict)
378379
}
379380

380381
type Dict<'a> = Never;

src/input/input_python.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
249249
} else if self.is_exact_instance_of::<PyFloat>() {
250250
float_as_int(self, self.extract::<f64>()?)
251251
} else if let Ok(decimal) = self.strict_decimal(self.py()) {
252-
decimal_as_int(self, &decimal)
252+
decimal_as_int(self, &decimal.into_inner())
253253
} else if let Ok(float) = self.extract::<f64>() {
254254
float_as_int(self, float)
255255
} else if let Some(enum_val) = maybe_as_enum(self) {
@@ -307,16 +307,16 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
307307
Err(ValError::new(ErrorTypeDefaults::FloatType, self))
308308
}
309309

310-
fn strict_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
310+
fn strict_decimal(&self, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
311311
let decimal_type = get_decimal_type(py);
312312
// Fast path for existing decimal objects
313313
if self.is_exact_instance(decimal_type) {
314-
return Ok(self.to_owned());
314+
return Ok(ValidationMatch::exact(self.to_owned()));
315315
}
316316

317317
// Try subclasses of decimals, they will be upcast to Decimal
318318
if self.is_instance(decimal_type)? {
319-
return create_decimal(self, self);
319+
return create_decimal(self, self).map(ValidationMatch::strict);
320320
}
321321

322322
Err(ValError::new(
@@ -331,24 +331,27 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
331331
))
332332
}
333333

334-
fn lax_decimal(&self, py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
334+
fn lax_decimal(&self, py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
335335
let decimal_type = get_decimal_type(py);
336336
// Fast path for existing decimal objects
337337
if self.is_exact_instance(decimal_type) {
338-
return Ok(self.to_owned().clone());
338+
return Ok(ValidationMatch::exact(self.to_owned().clone()));
339339
}
340340

341-
if self.is_instance_of::<PyString>() || (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>()) {
341+
// TODO: I can see the case for int and float being strict - wdyt @davidhewitt?
342+
return if self.is_instance_of::<PyString>()
343+
|| (self.is_instance_of::<PyInt>() && !self.is_instance_of::<PyBool>())
344+
{
342345
// checking isinstance for str / int / bool is fast compared to decimal / float
343-
create_decimal(self, self)
346+
create_decimal(self, self).map(ValidationMatch::lax)
344347
} else if self.is_instance(decimal_type)? {
345348
// upcast subclasses to decimal
346-
return create_decimal(self, self);
349+
create_decimal(self, self).map(ValidationMatch::strict)
347350
} else if self.is_instance_of::<PyFloat>() {
348-
create_decimal(self.str()?.as_any(), self)
351+
create_decimal(self.str()?.as_any(), self).map(ValidationMatch::lax)
349352
} else {
350353
Err(ValError::new(ErrorTypeDefaults::DecimalType, self))
351-
}
354+
};
352355
}
353356

354357
type Dict<'a> = GenericPyMapping<'a, 'py> where Self: 'a;

src/input/input_string.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ impl<'py> Input<'py> for StringMapping<'py> {
141141
}
142142
}
143143

144-
fn strict_decimal(&self, _py: Python<'py>) -> ValResult<Bound<'py, PyAny>> {
144+
fn strict_decimal(&self, _py: Python<'py>) -> ValMatch<Bound<'py, PyAny>> {
145145
match self {
146-
Self::String(s) => create_decimal(s, self),
146+
Self::String(s) => create_decimal(s, self).map(ValidationMatch::strict),
147147
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::DecimalType, self)),
148148
}
149149
}

src/validators/decimal.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ impl Validator for DecimalValidator {
122122
input: &(impl Input<'py> + ?Sized),
123123
state: &mut ValidationState<'_, 'py>,
124124
) -> ValResult<PyObject> {
125-
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?;
125+
let decimal = input.validate_decimal(state.strict_or(self.strict), py)?.unpack(state);
126126

127127
if !self.allow_inf_nan || self.check_digits {
128128
if !decimal.call_method0(intern!(py, "is_finite"))?.extract()? {

tests/validators/test_decimal.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010
from dirty_equals import FunctionCheck, IsStr
1111

12-
from pydantic_core import SchemaValidator, ValidationError
12+
from pydantic_core import SchemaValidator, ValidationError, core_schema
1313

1414
from ..conftest import Err, PyAndJson, plain_repr
1515

@@ -467,3 +467,23 @@ def test_validate_max_digits_and_decimal_places_edge_case() -> None:
467467
assert v.validate_python(Decimal('9999999999999999.999999999999999999')) == Decimal(
468468
'9999999999999999.999999999999999999'
469469
)
470+
471+
472+
def test_str_validation_w_strict() -> None:
473+
s = SchemaValidator(core_schema.decimal_schema(strict=True))
474+
475+
with pytest.raises(ValidationError):
476+
assert s.validate_python('1.23')
477+
478+
479+
def test_str_validation_w_lax() -> None:
480+
s = SchemaValidator(core_schema.decimal_schema(strict=False))
481+
482+
assert s.validate_python('1.23') == Decimal('1.23')
483+
484+
485+
def test_union_with_str_prefers_str() -> None:
486+
s = SchemaValidator(core_schema.union_schema([core_schema.decimal_schema(), core_schema.str_schema()]))
487+
488+
assert s.validate_python('1.23') == '1.23'
489+
assert s.validate_python(1.23) == Decimal('1.23')

0 commit comments

Comments
 (0)