Skip to content

Commit 90dfb74

Browse files
committed
clean up some string handling cases
1 parent 40b8a94 commit 90dfb74

File tree

3 files changed

+31
-25
lines changed

3 files changed

+31
-25
lines changed

src/input/input_python.rs

Lines changed: 23 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
use std::borrow::Cow;
21
use std::str::from_utf8;
32

43
use pyo3::intern;
@@ -144,12 +143,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
144143
Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)),
145144
}
146145
} else if let Ok(py_byte_array) = self.downcast::<PyByteArray>() {
147-
// Safety: the gil is held while from_utf8 is running so py_byte_array is not mutated,
148-
// and we immediately copy the bytes into a new Python string
149-
match from_utf8(unsafe { py_byte_array.as_bytes() }) {
150-
// Why Python not Rust? to avoid an unnecessary allocation on the Rust side, the
151-
// final output needs to be Python anyway.
152-
Ok(s) => Ok(PyString::new_bound(self.py(), s).into()),
146+
match bytearray_to_str(py_byte_array) {
147+
Ok(py_str) => Ok(py_str.into()),
153148
Err(_) => Err(ValError::new(ErrorTypeDefaults::StringUnicode, self)),
154149
}
155150
} else if coerce_numbers_to_str && !self.is_exact_instance_of::<PyBool>() && {
@@ -204,8 +199,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
204199
}
205200

206201
if !strict {
207-
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
208-
return str_as_bool(self, &cow_str).map(ValidationMatch::lax);
202+
if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::BoolParsing)? {
203+
return str_as_bool(self, s).map(ValidationMatch::lax);
209204
} else if let Some(int) = extract_i64(self) {
210205
return int_as_bool(self, int).map(ValidationMatch::lax);
211206
} else if let Ok(float) = self.extract::<f64>() {
@@ -241,8 +236,8 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
241236

242237
'lax: {
243238
if !strict {
244-
return if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? {
245-
str_as_int(self, &cow_str)
239+
return if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::IntParsing)? {
240+
str_as_int(self, s)
246241
} else if self.is_exact_instance_of::<PyFloat>() {
247242
float_as_int(self, self.extract::<f64>()?)
248243
} else if let Ok(decimal) = self.strict_decimal(self.py()) {
@@ -283,9 +278,9 @@ impl<'py> Input<'py> for Bound<'py, PyAny> {
283278
}
284279

285280
if !strict {
286-
if let Some(cow_str) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? {
281+
if let Some(s) = maybe_as_string(self, ErrorTypeDefaults::FloatParsing)? {
287282
// checking for bytes and string is fast, so do this before isinstance(float)
288-
return str_as_float(self, &cow_str).map(ValidationMatch::lax);
283+
return str_as_float(self, s).map(ValidationMatch::lax);
289284
}
290285
}
291286

@@ -630,20 +625,31 @@ fn from_attributes_applicable(obj: &Bound<'_, PyAny>) -> bool {
630625
}
631626

632627
/// Utility for extracting a string from a PyAny, if possible.
633-
fn maybe_as_string<'a>(v: &'a Bound<'_, PyAny>, unicode_error: ErrorType) -> ValResult<Option<Cow<'a, str>>> {
628+
fn maybe_as_string<'a>(v: &'a Bound<'_, PyAny>, unicode_error: ErrorType) -> ValResult<Option<&'a str>> {
634629
if let Ok(py_string) = v.downcast::<PyString>() {
635-
let str = py_string_str(py_string)?;
636-
Ok(Some(Cow::Borrowed(str)))
630+
py_string_str(py_string).map(Some)
637631
} else if let Ok(bytes) = v.downcast::<PyBytes>() {
638632
match from_utf8(bytes.as_bytes()) {
639-
Ok(s) => Ok(Some(Cow::Owned(s.to_string()))),
633+
Ok(s) => Ok(Some(s)),
640634
Err(_) => Err(ValError::new(unicode_error, v)),
641635
}
642636
} else {
643637
Ok(None)
644638
}
645639
}
646640

641+
/// Decode a Python bytearray to a Python string.
642+
///
643+
/// Using Python's built-in machinery for this should be efficient and avoids questions around
644+
/// safety of concurrent mutation of the bytearray (by leaving that to the Python interpreter).
645+
fn bytearray_to_str<'py>(bytearray: &Bound<'py, PyByteArray>) -> PyResult<Bound<'py, PyString>> {
646+
let py = bytearray.py();
647+
let py_string = bytearray
648+
.call_method1(intern!(py, "decode"), (intern!(py, "utf-8"),))?
649+
.downcast_into()?;
650+
Ok(py_string)
651+
}
652+
647653
/// Utility for extracting an enum value, if possible.
648654
fn maybe_as_enum<'py>(v: &Bound<'py, PyAny>) -> Option<Bound<'py, PyAny>> {
649655
let py = v.py();

src/serializers/config.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@ macro_rules! serialization_mode {
7777
return Ok(Self::default());
7878
};
7979
let raw_mode = config_dict.get_as::<Bound<'_, PyString>>(intern!(config_dict.py(), $config_key))?;
80-
raw_mode.map_or_else(|| Ok(Self::default()), |raw| Self::from_str(&raw.to_cow()?))
80+
raw_mode.map_or_else(|| Ok(Self::default()), |raw| Self::from_str(raw.to_str()?))
8181
}
8282
}
8383

src/serializers/fields.rs

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,7 @@ impl GeneralFieldsSerializer {
159159
for result in main_iter {
160160
let (key, value) = result?;
161161
let key_str = key_str(&key)?;
162-
let op_field = self.fields.get(key_str.as_ref());
162+
let op_field = self.fields.get(key_str);
163163
if extra.exclude_none && value.is_none() {
164164
if let Some(field) = op_field {
165165
if field.required {
@@ -169,7 +169,7 @@ impl GeneralFieldsSerializer {
169169
continue;
170170
}
171171
let field_extra = Extra {
172-
field_name: Some(&key_str),
172+
field_name: Some(key_str),
173173
..extra
174174
};
175175
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
@@ -236,13 +236,13 @@ impl GeneralFieldsSerializer {
236236
}
237237
let key_str = key_str(&key).map_err(py_err_se_err)?;
238238
let field_extra = Extra {
239-
field_name: Some(&key_str),
239+
field_name: Some(key_str),
240240
..extra
241241
};
242242

243243
let filter = self.filter.key_filter(&key, include, exclude).map_err(py_err_se_err)?;
244244
if let Some((next_include, next_exclude)) = filter {
245-
if let Some(field) = self.fields.get(key_str.as_ref()) {
245+
if let Some(field) = self.fields.get(key_str) {
246246
if let Some(ref serializer) = field.serializer {
247247
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
248248
let s = PydanticSerializer::new(
@@ -252,7 +252,7 @@ impl GeneralFieldsSerializer {
252252
next_exclude.as_ref(),
253253
&field_extra,
254254
);
255-
let output_key = field.get_key_json(&key_str, &field_extra);
255+
let output_key = field.get_key_json(key_str, &field_extra);
256256
map.serialize_entry(&output_key, &s)?;
257257
}
258258
}
@@ -446,8 +446,8 @@ impl TypeSerializer for GeneralFieldsSerializer {
446446
}
447447
}
448448

449-
fn key_str<'a>(key: &'a Bound<'_, PyAny>) -> PyResult<Cow<'a, str>> {
450-
key.downcast::<PyString>()?.to_cow()
449+
fn key_str<'a>(key: &'a Bound<'_, PyAny>) -> PyResult<&'a str> {
450+
key.downcast::<PyString>()?.to_str()
451451
}
452452

453453
fn dict_items<'py>(

0 commit comments

Comments
 (0)