Skip to content

Commit 4c44225

Browse files
committed
Implement validation based on ser_json_bytes to support round trip
1 parent fd26293 commit 4c44225

File tree

11 files changed

+85
-16
lines changed

11 files changed

+85
-16
lines changed

src/input/input_abstract.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use pyo3::{intern, prelude::*};
66

77
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
88
use crate::lookup_key::{LookupKey, LookupPath};
9+
use crate::serializers::config::BytesMode;
910
use crate::tools::py_err;
1011

1112
use super::datetime::{EitherDate, EitherDateTime, EitherTime, EitherTimedelta};
@@ -71,7 +72,7 @@ pub trait Input<'py>: fmt::Debug + ToPyObject {
7172

7273
fn validate_str(&self, strict: bool, coerce_numbers_to_str: bool) -> ValMatch<EitherString<'_>>;
7374

74-
fn validate_bytes<'a>(&'a self, strict: bool) -> ValMatch<EitherBytes<'a, 'py>>;
75+
fn validate_bytes<'a>(&'a self, strict: bool, mode: BytesMode) -> ValMatch<EitherBytes<'a, 'py>>;
7576

7677
fn validate_bool(&self, strict: bool) -> ValMatch<bool>;
7778

src/input/input_json.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use strum::EnumMessage;
99

1010
use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
1111
use crate::lookup_key::{LookupKey, LookupPath};
12+
use crate::serializers::config::BytesMode;
1213
use crate::validators::decimal::create_decimal;
1314

1415
use super::datetime::{
@@ -106,9 +107,16 @@ impl<'py, 'data> Input<'py> for JsonValue<'data> {
106107
}
107108
}
108109

109-
fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
110+
fn validate_bytes<'a>(
111+
&'a self,
112+
_strict: bool,
113+
mode: BytesMode,
114+
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
110115
match self {
111-
JsonValue::Str(s) => Ok(ValidationMatch::strict(s.as_bytes().into())),
116+
JsonValue::Str(s) => match mode.deserialize_string(s) {
117+
Ok(b) => Ok(ValidationMatch::strict(b)),
118+
Err(e) => Err(ValError::from(e)),
119+
},
112120
_ => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
113121
}
114122
}
@@ -342,8 +350,15 @@ impl<'py> Input<'py> for str {
342350
Ok(ValidationMatch::strict(self.into()))
343351
}
344352

345-
fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
346-
Ok(ValidationMatch::strict(self.as_bytes().into()))
353+
fn validate_bytes<'a>(
354+
&'a self,
355+
_strict: bool,
356+
mode: BytesMode,
357+
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
358+
match mode.deserialize_string(self) {
359+
Ok(b) => Ok(ValidationMatch::strict(b)),
360+
Err(e) => Err(ValError::from(e)),
361+
}
347362
}
348363

349364
fn validate_bool(&self, _strict: bool) -> ValResult<ValidationMatch<bool>> {

src/input/input_python.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use pyo3::PyTypeCheck;
1414
use speedate::MicrosecondsPrecisionOverflowBehavior;
1515

1616
use crate::errors::{ErrorType, ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
17+
use crate::serializers::config::BytesMode;
1718
use crate::tools::{extract_i64, safe_repr};
1819
use crate::validators::decimal::{create_decimal, get_decimal_type};
1920
use crate::validators::Exactness;
@@ -174,7 +175,7 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
174175
Err(ValError::new(ErrorTypeDefaults::StringType, self))
175176
}
176177

177-
fn validate_bytes<'a>(&'a self, strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
178+
fn validate_bytes<'a>(&'a self, strict: bool, mode: BytesMode) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
178179
if let Ok(py_bytes) = self.downcast_exact::<PyBytes>() {
179180
return Ok(ValidationMatch::exact(py_bytes.into()));
180181
} else if let Ok(py_bytes) = self.downcast::<PyBytes>() {
@@ -185,7 +186,10 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
185186
if !strict {
186187
return if let Ok(py_str) = self.downcast::<PyString>() {
187188
let str = py_string_str(py_str)?;
188-
Ok(str.as_bytes().into())
189+
match mode.deserialize_string(str) {
190+
Ok(b) => Ok(b),
191+
Err(e) => Err(ValError::from(e)),
192+
}
189193
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
190194
Ok(py_byte_array.to_vec().into())
191195
} else {

src/input/input_string.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use speedate::MicrosecondsPrecisionOverflowBehavior;
66
use crate::errors::{ErrorTypeDefaults, InputValue, LocItem, ValError, ValResult};
77
use crate::input::py_string_str;
88
use crate::lookup_key::{LookupKey, LookupPath};
9+
use crate::serializers::config::BytesMode;
910
use crate::tools::safe_repr;
1011
use crate::validators::decimal::create_decimal;
1112

@@ -105,9 +106,16 @@ impl<'py> Input<'py> for StringMapping<'py> {
105106
}
106107
}
107108

108-
fn validate_bytes<'a>(&'a self, _strict: bool) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
109+
fn validate_bytes<'a>(
110+
&'a self,
111+
_strict: bool,
112+
mode: BytesMode,
113+
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
109114
match self {
110-
Self::String(s) => py_string_str(s).map(|b| ValidationMatch::strict(b.as_bytes().into())),
115+
Self::String(s) => py_string_str(s).and_then(|b| match mode.deserialize_string(b) {
116+
Ok(b) => Ok(ValidationMatch::strict(b)),
117+
Err(e) => Err(ValError::from(e)),
118+
}),
111119
Self::Mapping(_) => Err(ValError::new(ErrorTypeDefaults::BytesType, self)),
112120
}
113121
}

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::sync::OnceLock;
77
use jiter::StringCacheMode;
88
use pyo3::exceptions::PyTypeError;
99
use pyo3::{prelude::*, sync::GILOnceCell};
10+
use serializers::config::BytesMode;
1011

1112
// parse this first to get access to the contained macro
1213
#[macro_use]
@@ -55,7 +56,7 @@ pub fn from_json<'py>(
5556
allow_partial: bool,
5657
) -> PyResult<Bound<'py, PyAny>> {
5758
let v_match = data
58-
.validate_bytes(false)
59+
.validate_bytes(false, BytesMode::Utf8)
5960
.map_err(|_| PyTypeError::new_err("Expected bytes, bytearray or str"))?;
6061
let json_either_bytes = v_match.into_inner();
6162
let json_bytes = json_either_bytes.as_slice();

src/serializers/config.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@ use std::borrow::Cow;
22
use std::str::{from_utf8, FromStr, Utf8Error};
33

44
use base64::Engine;
5+
use pyo3::exceptions::PyValueError;
56
use pyo3::intern;
67
use pyo3::prelude::*;
78
use pyo3::types::{PyDelta, PyDict, PyString};
89

910
use serde::ser::Error;
1011

1112
use crate::build_tools::py_schema_err;
12-
use crate::input::EitherTimedelta;
13+
use crate::input::{EitherBytes, EitherTimedelta};
1314
use crate::tools::SchemaDict;
1415

1516
use super::errors::py_err_se_err;
@@ -187,6 +188,17 @@ impl BytesMode {
187188
}
188189
}
189190
}
191+
192+
pub fn deserialize_string<'a, 'py>(&self, s: &'a str) -> PyResult<EitherBytes<'a, 'py>> {
193+
match self {
194+
Self::Utf8 => Ok(EitherBytes::Cow(Cow::Borrowed(s.as_bytes()))),
195+
Self::Base64 => match base64::engine::general_purpose::URL_SAFE.decode(s) {
196+
Ok(bytes) => Ok(EitherBytes::from(bytes)),
197+
Err(err) => Err(PyValueError::new_err(format!("Base64 decode error: {}", err))),
198+
},
199+
Self::Hex => Err(PyValueError::new_err("Hex deserialization is not supported")),
200+
}
201+
}
190202
}
191203

192204
pub fn utf8_py_error(py: Python, err: Utf8Error, data: &[u8]) -> PyErr {

src/serializers/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub use shared::CombinedSerializer;
1616
use shared::{to_json_bytes, BuildSerializer, TypeSerializer};
1717

1818
mod computed_fields;
19-
mod config;
19+
pub(crate) mod config;
2020
mod errors;
2121
mod extra;
2222
mod fields;

src/validators/bytes.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@ use crate::build_tools::is_strict;
66
use crate::errors::{ErrorType, ValError, ValResult};
77
use crate::input::Input;
88

9+
use crate::serializers::config::{BytesMode, FromConfig};
910
use crate::tools::SchemaDict;
1011

1112
use super::{BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
1213

1314
#[derive(Debug, Clone)]
1415
pub struct BytesValidator {
1516
strict: bool,
17+
bytes_mode: BytesMode,
1618
}
1719

1820
impl BuildValidator for BytesValidator {
@@ -31,6 +33,7 @@ impl BuildValidator for BytesValidator {
3133
} else {
3234
Ok(Self {
3335
strict: is_strict(schema, config)?,
36+
bytes_mode: BytesMode::from_config(config)?,
3437
}
3538
.into())
3639
}
@@ -47,7 +50,7 @@ impl Validator for BytesValidator {
4750
state: &mut ValidationState<'_, 'py>,
4851
) -> ValResult<PyObject> {
4952
input
50-
.validate_bytes(state.strict_or(self.strict))
53+
.validate_bytes(state.strict_or(self.strict), self.bytes_mode.clone())
5154
.map(|m| m.unpack(state).into_py(py))
5255
}
5356

@@ -59,6 +62,7 @@ impl Validator for BytesValidator {
5962
#[derive(Debug, Clone)]
6063
pub struct BytesConstrainedValidator {
6164
strict: bool,
65+
bytes_mode: BytesMode,
6266
max_length: Option<usize>,
6367
min_length: Option<usize>,
6468
}
@@ -72,7 +76,9 @@ impl Validator for BytesConstrainedValidator {
7276
input: &(impl Input<'py> + ?Sized),
7377
state: &mut ValidationState<'_, 'py>,
7478
) -> ValResult<PyObject> {
75-
let either_bytes = input.validate_bytes(state.strict_or(self.strict))?.unpack(state);
79+
let either_bytes = input
80+
.validate_bytes(state.strict_or(self.strict), self.bytes_mode.clone())?
81+
.unpack(state);
7682
let len = either_bytes.len()?;
7783

7884
if let Some(min_length) = self.min_length {
@@ -110,6 +116,7 @@ impl BytesConstrainedValidator {
110116
let py = schema.py();
111117
Ok(Self {
112118
strict: is_strict(schema, config)?,
119+
bytes_mode: BytesMode::from_config(config)?,
113120
min_length: schema.get_as(intern!(py, "min_length"))?,
114121
max_length: schema.get_as(intern!(py, "max_length"))?,
115122
}

src/validators/json.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use jiter::JsonValue;
66

77
use crate::errors::{ErrorType, ErrorTypeDefaults, ValError, ValLineError, ValResult};
88
use crate::input::{EitherBytes, Input, InputType, ValidationMatch};
9+
use crate::serializers::config::BytesMode;
910
use crate::tools::SchemaDict;
1011

1112
use super::{build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, ValidationState, Validator};
@@ -79,7 +80,7 @@ impl Validator for JsonValidator {
7980
pub fn validate_json_bytes<'a, 'py>(
8081
input: &'a (impl Input<'py> + ?Sized),
8182
) -> ValResult<ValidationMatch<EitherBytes<'a, 'py>>> {
82-
match input.validate_bytes(false) {
83+
match input.validate_bytes(false, BytesMode::Utf8) {
8384
Ok(v_match) => Ok(v_match),
8485
Err(ValError::LineErrors(e)) => Err(ValError::LineErrors(
8586
e.into_iter().map(map_bytes_error).collect::<Vec<_>>(),

src/validators/uuid.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::input::input_as_python_instance;
1313
use crate::input::Input;
1414
use crate::input::InputType;
1515
use crate::input::ValidationMatch;
16+
use crate::serializers::config::BytesMode;
1617
use crate::tools::SchemaDict;
1718

1819
use super::model::create_class;
@@ -169,7 +170,7 @@ impl UuidValidator {
169170
}
170171
None => {
171172
let either_bytes = input
172-
.validate_bytes(true)
173+
.validate_bytes(true, BytesMode::Utf8)
173174
.map_err(|_| ValError::new(ErrorTypeDefaults::UuidType, input))?
174175
.into_inner();
175176
let bytes_slice = either_bytes.as_slice();

tests/test_json.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,3 +376,22 @@ def test_partial_parse():
376376
with pytest.raises(ValueError, match='EOF while parsing a string at line 1 column 15'):
377377
from_json(b'["aa", "bb", "c')
378378
assert from_json(b'["aa", "bb", "c', allow_partial=True) == ['aa', 'bb']
379+
380+
381+
def test_json_bytes_base64_round_trip():
382+
data = b'hello'
383+
encoded = b'"aGVsbG8="'
384+
assert to_json(data, bytes_mode='base64') == encoded
385+
386+
v = SchemaValidator({'type': 'bytes'}, {'ser_json_bytes': 'base64'})
387+
assert v.validate_json(encoded) == data
388+
389+
with pytest.raises(ValueError):
390+
v.validate_json('"wrong!"')
391+
392+
assert to_json({'key': data}, bytes_mode='base64') == b'{"key":"aGVsbG8="}'
393+
v = SchemaValidator(
394+
{'type': 'dict', 'keys_schema': {'type': 'str'}, 'values_schema': {'type': 'bytes'}},
395+
{'ser_json_bytes': 'base64'},
396+
)
397+
assert v.validate_json('{"key":"aGVsbG8="}') == {'key': data}

0 commit comments

Comments
 (0)