Skip to content

Commit a55bd2e

Browse files
Add eitherfloat (#691)
Co-authored-by: David Montague <[email protected]>
1 parent 703b7b2 commit a55bd2e

File tree

11 files changed

+86
-50
lines changed

11 files changed

+86
-50
lines changed

src/errors/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ impl ErrorType {
497497
Self::IntFromFloat => "Input should be a valid integer, got a number with a fractional part",
498498
Self::IntParsingSize => "Unable to parse input string as an integer, exceeded maximum size",
499499
Self::FloatType => "Input should be a valid number",
500-
Self::FloatParsing => "Input should be a valid number, unable to parse string as an number",
500+
Self::FloatParsing => "Input should be a valid number, unable to parse string as a number",
501501
Self::BytesType => "Input should be a valid bytes",
502502
Self::BytesTooShort {..} => "Data should have at least {min_length} bytes",
503503
Self::BytesTooLong {..} => "Data should have at most {max_length} bytes",

src/input/input_abstract.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{PyMultiHostUrl, PyUrl};
88

99
use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
1010
use super::return_enums::{EitherBytes, EitherInt, EitherString};
11-
use super::{GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonInput};
11+
use super::{EitherFloat, GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonInput};
1212

1313
#[derive(Debug, Clone, Copy)]
1414
pub enum InputType {
@@ -136,7 +136,7 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
136136
self.strict_str()
137137
}
138138

139-
fn validate_float(&self, strict: bool, ultra_strict: bool) -> ValResult<f64> {
139+
fn validate_float(&'a self, strict: bool, ultra_strict: bool) -> ValResult<EitherFloat<'a>> {
140140
if ultra_strict {
141141
self.ultra_strict_float()
142142
} else if strict {
@@ -145,10 +145,10 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
145145
self.lax_float()
146146
}
147147
}
148-
fn ultra_strict_float(&self) -> ValResult<f64>;
149-
fn strict_float(&self) -> ValResult<f64>;
148+
fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
149+
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>>;
150150
#[cfg_attr(has_no_coverage, no_coverage)]
151-
fn lax_float(&self) -> ValResult<f64> {
151+
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
152152
self.strict_float()
153153
}
154154

src/input/input_json.rs

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use super::datetime::{
1010
use super::parse_json::JsonArray;
1111
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_int};
1212
use super::{
13-
EitherBytes, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable, GenericIterator,
14-
GenericMapping, Input, JsonArgs, JsonInput,
13+
EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
14+
GenericIterator, GenericMapping, Input, JsonArgs, JsonInput,
1515
};
1616

1717
impl<'a> Input<'a> for JsonInput {
@@ -135,31 +135,31 @@ impl<'a> Input<'a> for JsonInput {
135135
}
136136
}
137137

138-
fn ultra_strict_float(&self) -> ValResult<f64> {
138+
fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
139139
match self {
140-
JsonInput::Float(f) => Ok(*f),
140+
JsonInput::Float(f) => Ok(EitherFloat::F64(*f)),
141141
_ => Err(ValError::new(ErrorType::FloatType, self)),
142142
}
143143
}
144-
fn strict_float(&self) -> ValResult<f64> {
144+
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
145145
match self {
146-
JsonInput::Float(f) => Ok(*f),
147-
JsonInput::Int(i) => Ok(*i as f64),
148-
JsonInput::Uint(u) => Ok(*u as f64),
146+
JsonInput::Float(f) => Ok(EitherFloat::F64(*f)),
147+
JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)),
148+
JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)),
149149
_ => Err(ValError::new(ErrorType::FloatType, self)),
150150
}
151151
}
152-
fn lax_float(&self) -> ValResult<f64> {
152+
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
153153
match self {
154154
JsonInput::Bool(b) => match *b {
155-
true => Ok(1.0),
156-
false => Ok(0.0),
155+
true => Ok(EitherFloat::F64(1.0)),
156+
false => Ok(EitherFloat::F64(0.0)),
157157
},
158-
JsonInput::Float(f) => Ok(*f),
159-
JsonInput::Int(i) => Ok(*i as f64),
160-
JsonInput::Uint(u) => Ok(*u as f64),
158+
JsonInput::Float(f) => Ok(EitherFloat::F64(*f)),
159+
JsonInput::Int(i) => Ok(EitherFloat::F64(*i as f64)),
160+
JsonInput::Uint(u) => Ok(EitherFloat::F64(*u as f64)),
161161
JsonInput::String(str) => match str.parse::<f64>() {
162-
Ok(i) => Ok(i),
162+
Ok(i) => Ok(EitherFloat::F64(i)),
163163
Err(_) => Err(ValError::new(ErrorType::FloatParsing, self)),
164164
},
165165
_ => Err(ValError::new(ErrorType::FloatType, self)),
@@ -372,16 +372,16 @@ impl<'a> Input<'a> for String {
372372
}
373373

374374
#[cfg_attr(has_no_coverage, no_coverage)]
375-
fn ultra_strict_float(&self) -> ValResult<f64> {
375+
fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
376376
self.strict_float()
377377
}
378378
#[cfg_attr(has_no_coverage, no_coverage)]
379-
fn strict_float(&self) -> ValResult<f64> {
379+
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
380380
Err(ValError::new(ErrorType::FloatType, self))
381381
}
382-
fn lax_float(&self) -> ValResult<f64> {
382+
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
383383
match self.parse() {
384-
Ok(i) => Ok(i),
384+
Ok(f) => Ok(EitherFloat::F64(f)),
385385
Err(_) => Err(ValError::new(ErrorType::FloatParsing, self)),
386386
}
387387
}

src/input/input_python.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::str::from_utf8;
33

44
use pyo3::prelude::*;
55
use pyo3::types::{
6-
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFrozenSet, PyInt, PyIterator, PyList,
6+
PyBool, PyByteArray, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyFrozenSet, PyInt, PyIterator, PyList,
77
PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, PyType,
88
};
99
#[cfg(not(PyPy))]
@@ -21,8 +21,8 @@ use super::datetime::{
2121
};
2222
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_int};
2323
use super::{
24-
py_string_str, EitherBytes, EitherInt, EitherString, EitherTimedelta, GenericArguments, GenericIterable,
25-
GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
24+
py_string_str, EitherBytes, EitherFloat, EitherInt, EitherString, EitherTimedelta, GenericArguments,
25+
GenericIterable, GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
2626
};
2727

2828
#[cfg(not(PyPy))]
@@ -308,35 +308,40 @@ impl<'a> Input<'a> for PyAny {
308308
}
309309
}
310310

311-
fn ultra_strict_float(&self) -> ValResult<f64> {
311+
fn ultra_strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
312312
if self.is_instance_of::<PyInt>() {
313313
Err(ValError::new(ErrorType::FloatType, self))
314-
} else if let Ok(float) = self.extract::<f64>() {
315-
Ok(float)
314+
} else if self.is_instance_of::<PyFloat>() {
315+
Ok(EitherFloat::Py(self))
316316
} else {
317317
Err(ValError::new(ErrorType::FloatType, self))
318318
}
319319
}
320-
fn strict_float(&self) -> ValResult<f64> {
321-
if let Ok(float) = self.extract::<f64>() {
320+
fn strict_float(&'a self) -> ValResult<EitherFloat<'a>> {
321+
if PyFloat::is_exact_type_of(self) {
322+
Ok(EitherFloat::Py(self))
323+
} else if let Ok(float) = self.extract::<f64>() {
322324
// bools are cast to floats as either 0.0 or 1.0, so check for bool type in this specific case
323325
if (float == 0.0 || float == 1.0) && PyBool::is_exact_type_of(self) {
324326
Err(ValError::new(ErrorType::FloatType, self))
325327
} else {
326-
Ok(float)
328+
Ok(EitherFloat::Py(self))
327329
}
328330
} else {
329331
Err(ValError::new(ErrorType::FloatType, self))
330332
}
331333
}
332-
fn lax_float(&self) -> ValResult<f64> {
333-
if let Ok(float) = self.extract::<f64>() {
334-
Ok(float)
334+
335+
fn lax_float(&'a self) -> ValResult<EitherFloat<'a>> {
336+
if PyFloat::is_exact_type_of(self) {
337+
Ok(EitherFloat::Py(self))
335338
} else if let Some(cow_str) = maybe_as_string(self, ErrorType::FloatParsing)? {
336339
match cow_str.as_ref().parse::<f64>() {
337-
Ok(i) => Ok(i),
340+
Ok(i) => Ok(EitherFloat::F64(i)),
338341
Err(_) => Err(ValError::new(ErrorType::FloatParsing, self)),
339342
}
343+
} else if let Ok(float) = self.extract::<f64>() {
344+
Ok(EitherFloat::F64(float))
340345
} else {
341346
Err(ValError::new(ErrorType::FloatType, self))
342347
}

src/input/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ pub(crate) use datetime::{
1717
pub(crate) use input_abstract::{Input, InputType};
1818
pub(crate) use parse_json::{JsonInput, JsonObject};
1919
pub(crate) use return_enums::{
20-
py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherInt, EitherString,
20+
py_string_str, AttributesGenericIterator, DictGenericIterator, EitherBytes, EitherFloat, EitherInt, EitherString,
2121
GenericArguments, GenericIterable, GenericIterator, GenericMapping, JsonArgs, JsonObjectGenericIterator,
2222
MappingGenericIterator, PyArgs,
2323
};

src/input/return_enums.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,3 +892,29 @@ impl<'a> IntoPy<PyObject> for EitherInt<'a> {
892892
}
893893
}
894894
}
895+
896+
#[cfg_attr(debug_assertions, derive(Debug))]
897+
pub enum EitherFloat<'a> {
898+
F64(f64),
899+
Py(&'a PyAny),
900+
}
901+
902+
impl<'a> TryInto<f64> for EitherFloat<'a> {
903+
type Error = ValError<'a>;
904+
905+
fn try_into(self) -> ValResult<'a, f64> {
906+
match self {
907+
EitherFloat::F64(f) => Ok(f),
908+
EitherFloat::Py(i) => i.extract().map_err(|_| ValError::new(ErrorType::FloatParsing, i)),
909+
}
910+
}
911+
}
912+
913+
impl<'a> IntoPy<PyObject> for EitherFloat<'a> {
914+
fn into_py(self, py: Python<'_>) -> PyObject {
915+
match self {
916+
Self::F64(float) => float.into_py(py),
917+
Self::Py(float) => float.into_py(py),
918+
}
919+
}
920+
}

src/validators/float.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,13 @@ impl Validator for FloatValidator {
6969
_definitions: &'data Definitions<CombinedValidator>,
7070
_recursion_guard: &'s mut RecursionGuard,
7171
) -> ValResult<'data, PyObject> {
72-
let float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?;
72+
let either_float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?;
73+
let float: f64 = either_float.try_into()?;
7374
if !self.allow_inf_nan && !float.is_finite() {
7475
return Err(ValError::new(ErrorType::FiniteNumber, input));
76+
} else {
77+
Ok(float.into_py(py))
7578
}
76-
Ok(float.into_py(py))
7779
}
7880

7981
fn different_strict_behavior(
@@ -113,7 +115,8 @@ impl Validator for ConstrainedFloatValidator {
113115
_definitions: &'data Definitions<CombinedValidator>,
114116
_recursion_guard: &'s mut RecursionGuard,
115117
) -> ValResult<'data, PyObject> {
116-
let float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?;
118+
let either_float = input.validate_float(extra.strict.unwrap_or(self.strict), extra.ultra_strict)?;
119+
let float: f64 = either_float.try_into()?;
117120
if !self.allow_inf_nan && !float.is_finite() {
118121
return Err(ValError::new(ErrorType::FiniteNumber, input));
119122
}

tests/test_errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def f(input_value, info):
231231
('less_than', 'Input should be less than 42.1', {'lt': 42.1}),
232232
('less_than_equal', 'Input should be less than or equal to 42.1', {'le': 42.1}),
233233
('float_type', 'Input should be a valid number', None),
234-
('float_parsing', 'Input should be a valid number, unable to parse string as an number', None),
234+
('float_parsing', 'Input should be a valid number, unable to parse string as a number', None),
235235
('bytes_type', 'Input should be a valid bytes', None),
236236
('bytes_too_short', 'Data should have at least 42 bytes', {'min_length': 42}),
237237
('bytes_too_long', 'Data should have at most 42 bytes', {'max_length': 42}),

tests/test_json.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def test_int(input_value, expected):
9090
('"123.4"', 123.4),
9191
('"123.0"', 123.0),
9292
('"123"', 123.0),
93-
('"string"', Err('Input should be a valid number, unable to parse string as an number [type=float_parsing,')),
93+
('"string"', Err('Input should be a valid number, unable to parse string as a number [type=float_parsing,')),
9494
],
9595
)
9696
def test_float(input_value, expected):

tests/test_misc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ class MyModel:
126126
{
127127
'type': 'float_parsing',
128128
'loc': ('x',),
129-
'msg': 'Input should be a valid number, unable to parse string as an number',
129+
'msg': 'Input should be a valid number, unable to parse string as a number',
130130
'input': 'x' * 60,
131131
},
132132
{
@@ -139,7 +139,7 @@ class MyModel:
139139
assert repr(exc_info.value) == (
140140
'2 validation errors for MyModel\n'
141141
'x\n'
142-
' Input should be a valid number, unable to parse string as an number '
142+
' Input should be a valid number, unable to parse string as a number '
143143
"[type=float_parsing, input_value='xxxxxxxxxxxxxxxxxxxxxxxx...xxxxxxxxxxxxxxxxxxxxxxx', input_type=str]\n"
144144
f' For further information visit https://errors.pydantic.dev/{__version__}/v/float_parsing\n'
145145
'y\n'

tests/validators/test_float.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from ..conftest import Err, PyAndJson, plain_repr
1212

13+
f64_max = 1.7976931348623157e308
14+
1315

1416
@pytest.mark.parametrize(
1517
'input_value,expected',
@@ -24,7 +26,7 @@
2426
(1e10, 1e10),
2527
(True, 1),
2628
(False, 0),
27-
('wrong', Err('Input should be a valid number, unable to parse string as an number [type=float_parsing')),
29+
('wrong', Err('Input should be a valid number, unable to parse string as a number [type=float_parsing')),
2830
([1, 2], Err('Input should be a valid number [type=float_type, input_value=[1, 2], input_type=list]')),
2931
],
3032
)
@@ -161,7 +163,7 @@ def test_union_float_simple(py_and_json: PyAndJson):
161163
{
162164
'type': 'float_parsing',
163165
'loc': ('float',),
164-
'msg': 'Input should be a valid number, unable to parse string as an number',
166+
'msg': 'Input should be a valid number, unable to parse string as a number',
165167
'input': 'xxx',
166168
},
167169
{
@@ -251,15 +253,15 @@ def test_float_key(py_and_json: PyAndJson):
251253
'pika',
252254
True,
253255
Err(
254-
'Input should be a valid number, unable to parse string as an number '
256+
'Input should be a valid number, unable to parse string as a number '
255257
"[type=float_parsing, input_value='pika', input_type=str]"
256258
),
257259
),
258260
(
259261
'pika',
260262
False,
261263
Err(
262-
'Input should be a valid number, unable to parse string as an number '
264+
'Input should be a valid number, unable to parse string as a number '
263265
"[type=float_parsing, input_value='pika', input_type=str]"
264266
),
265267
),

0 commit comments

Comments
 (0)