Skip to content

Commit a86769e

Browse files
authored
Merge pull request RustPython#2084 from skinny121/ascii_bytes_like
Refactor struct module
2 parents 7ff974e + fbbefed commit a86769e

File tree

3 files changed

+71
-57
lines changed

3 files changed

+71
-57
lines changed

vm/src/builtins/memory.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ impl PyMemoryView {
7474
}
7575

7676
fn parse_format(format: &str, vm: &VirtualMachine) -> PyResult<FormatSpec> {
77-
FormatSpec::parse(format, vm)
77+
FormatSpec::parse(format.as_bytes(), vm)
7878
}
7979

8080
pub fn from_buffer(buffer: PyBuffer, vm: &VirtualMachine) -> PyResult<Self> {

vm/src/builtins/pystr.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ impl PyStr {
283283
}
284284

285285
/// SAFETY: Given 'bytes' must be ascii
286-
unsafe fn new_ascii_unchecked(bytes: Vec<u8>) -> Self {
286+
pub(crate) unsafe fn new_ascii_unchecked(bytes: Vec<u8>) -> Self {
287287
Self::new_str_unchecked(bytes, PyStrKind::Ascii)
288288
}
289289

vm/src/stdlib/pystruct.rs

Lines changed: 69 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
#[pymodule]
1313
pub(crate) mod _struct {
1414
use crate::{
15-
builtins::{float, PyBaseExceptionRef, PyBytesRef, PyStr, PyStrRef, PyTupleRef, PyTypeRef},
15+
builtins::{
16+
float, PyBaseExceptionRef, PyBytes, PyBytesRef, PyStr, PyStrRef, PyTupleRef, PyTypeRef,
17+
},
1618
common::str::wchar_t,
1719
function::{ArgBytesLike, ArgIntoBool, ArgMemoryBuffer, IntoPyObject, PosArgs},
1820
protocol::PyIterReturn,
1921
slots::{IteratorIterable, SlotConstructor, SlotIterator},
20-
utils::Either,
21-
PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, VirtualMachine,
22+
PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol, VirtualMachine,
2223
};
2324
use crossbeam_utils::atomic::AtomicCell;
2425
use half::f16;
@@ -202,6 +203,39 @@ pub(crate) mod _struct {
202203

203204
const OVERFLOW_MSG: &str = "total struct size too long";
204205

206+
struct IntoStructFormatBytes(PyStrRef);
207+
208+
impl TryFromObject for IntoStructFormatBytes {
209+
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
210+
// CPython turns str to bytes but we do reversed way here
211+
// The only performance difference is this transition cost
212+
let fmt = match_class! {
213+
match obj {
214+
s @ PyStr => if s.is_ascii() {
215+
Some(s)
216+
} else {
217+
None
218+
},
219+
b @ PyBytes => if b.is_ascii() {
220+
Some(unsafe {
221+
PyStr::new_ascii_unchecked(b.as_bytes().to_vec())
222+
}.into_ref(vm))
223+
} else {
224+
None
225+
},
226+
other => return Err(vm.new_type_error(format!("Struct() argument 1 must be a str or bytes object, not {}", other.class().name()))),
227+
}
228+
}.ok_or_else(|| vm.new_unicode_decode_error("Struct format must be a ascii string".to_owned()))?;
229+
Ok(IntoStructFormatBytes(fmt))
230+
}
231+
}
232+
233+
impl IntoStructFormatBytes {
234+
fn format_spec(&self, vm: &VirtualMachine) -> PyResult<FormatSpec> {
235+
FormatSpec::parse(self.0.as_str().as_bytes(), vm)
236+
}
237+
}
238+
205239
#[derive(Debug, Clone)]
206240
pub(crate) struct FormatSpec {
207241
endianness: Endianness,
@@ -211,24 +245,8 @@ pub(crate) mod _struct {
211245
}
212246

213247
impl FormatSpec {
214-
fn decode_and_parse(
215-
vm: &VirtualMachine,
216-
fmt: &Either<PyStrRef, PyBytesRef>,
217-
) -> PyResult<FormatSpec> {
218-
let decoded_fmt = match fmt {
219-
Either::A(string) => string.as_str(),
220-
Either::B(bytes) if bytes.is_ascii() => std::str::from_utf8(bytes).unwrap(),
221-
_ => {
222-
return Err(vm.new_unicode_decode_error(
223-
"Struct format must be a ascii string".to_owned(),
224-
))
225-
}
226-
};
227-
FormatSpec::parse(decoded_fmt, vm)
228-
}
229-
230-
pub fn parse(fmt: &str, vm: &VirtualMachine) -> PyResult<FormatSpec> {
231-
let mut chars = fmt.bytes().peekable();
248+
pub fn parse(fmt: &[u8], vm: &VirtualMachine) -> PyResult<FormatSpec> {
249+
let mut chars = fmt.iter().copied().peekable();
232250

233251
// First determine "@", "<", ">","!" or "="
234252
let endianness = parse_endianness(&mut chars);
@@ -399,10 +417,10 @@ pub(crate) mod _struct {
399417
let mut repeat = 0isize;
400418
while let Some(b'0'..=b'9') = chars.peek() {
401419
if let Some(c) = chars.next() {
402-
let current_digit = (c as char).to_digit(10).unwrap() as isize;
420+
let current_digit = c - b'0';
403421
repeat = repeat
404422
.checked_mul(10)
405-
.and_then(|r| r.checked_add(current_digit))
423+
.and_then(|r| r.checked_add(current_digit as _))
406424
.ok_or_else(|| OVERFLOW_MSG.to_owned())?;
407425
}
408426
}
@@ -486,20 +504,26 @@ pub(crate) mod _struct {
486504
}
487505
buffer_len - (-offset as usize)
488506
} else {
489-
if offset as usize >= buffer_len {
507+
let offset = offset as usize;
508+
let (op, op_action) = if is_pack {
509+
("pack_into", "packing")
510+
} else {
511+
("unpack_from", "unpacking")
512+
};
513+
if offset >= buffer_len {
490514
let msg = format!(
491515
"{op} requires a buffer of at least {required} bytes for {op_action} {needed} \
492516
bytes at offset {offset} (actual buffer size is {buffer_len})",
493-
op = if is_pack { "pack_into" } else { "unpack_from" },
494-
op_action = if is_pack { "packing" } else { "unpacking" },
517+
op = op,
518+
op_action = op_action,
495519
required = needed + offset as usize,
496520
needed = needed,
497521
offset = offset,
498522
buffer_len = buffer_len
499523
);
500524
return Err(new_struct_error(vm, msg));
501525
}
502-
offset as usize
526+
offset
503527
};
504528

505529
if (buffer_len - offset_from_start) < needed {
@@ -717,24 +741,19 @@ pub(crate) mod _struct {
717741
}
718742

719743
#[pyfunction]
720-
fn pack(
721-
fmt: Either<PyStrRef, PyBytesRef>,
722-
args: PosArgs,
723-
vm: &VirtualMachine,
724-
) -> PyResult<Vec<u8>> {
725-
let format_spec = FormatSpec::decode_and_parse(vm, &fmt)?;
726-
format_spec.pack(args.into_vec(), vm)
744+
fn pack(fmt: IntoStructFormatBytes, args: PosArgs, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
745+
fmt.format_spec(vm)?.pack(args.into_vec(), vm)
727746
}
728747

729748
#[pyfunction]
730749
fn pack_into(
731-
fmt: Either<PyStrRef, PyBytesRef>,
750+
fmt: IntoStructFormatBytes,
732751
buffer: ArgMemoryBuffer,
733752
offset: isize,
734753
args: PosArgs,
735754
vm: &VirtualMachine,
736755
) -> PyResult<()> {
737-
let format_spec = FormatSpec::decode_and_parse(vm, &fmt)?;
756+
let format_spec = fmt.format_spec(vm)?;
738757
let offset = get_buffer_offset(buffer.len(), offset, format_spec.size, true, vm)?;
739758
buffer.with_ref(|data| format_spec.pack_into(&mut data[offset..], args.into_vec(), vm))
740759
}
@@ -757,11 +776,11 @@ pub(crate) mod _struct {
757776

758777
#[pyfunction]
759778
fn unpack(
760-
fmt: Either<PyStrRef, PyBytesRef>,
779+
fmt: IntoStructFormatBytes,
761780
buffer: ArgBytesLike,
762781
vm: &VirtualMachine,
763782
) -> PyResult<PyTupleRef> {
764-
let format_spec = FormatSpec::decode_and_parse(vm, &fmt)?;
783+
let format_spec = fmt.format_spec(vm)?;
765784
buffer.with_ref(|buf| format_spec.unpack(buf, vm))
766785
}
767786

@@ -774,11 +793,11 @@ pub(crate) mod _struct {
774793

775794
#[pyfunction]
776795
fn unpack_from(
777-
fmt: Either<PyStrRef, PyBytesRef>,
796+
fmt: IntoStructFormatBytes,
778797
args: UpdateFromArgs,
779798
vm: &VirtualMachine,
780799
) -> PyResult<PyTupleRef> {
781-
let format_spec = FormatSpec::decode_and_parse(vm, &fmt)?;
800+
let format_spec = fmt.format_spec(vm)?;
782801
let offset =
783802
get_buffer_offset(args.buffer.len(), args.offset, format_spec.size, false, vm)?;
784803
args.buffer
@@ -849,47 +868,42 @@ pub(crate) mod _struct {
849868

850869
#[pyfunction]
851870
fn iter_unpack(
852-
fmt: Either<PyStrRef, PyBytesRef>,
871+
fmt: IntoStructFormatBytes,
853872
buffer: ArgBytesLike,
854873
vm: &VirtualMachine,
855874
) -> PyResult<UnpackIterator> {
856-
let format_spec = FormatSpec::decode_and_parse(vm, &fmt)?;
875+
let format_spec = fmt.format_spec(vm)?;
857876
UnpackIterator::new(vm, format_spec, buffer)
858877
}
859878

860879
#[pyfunction]
861-
fn calcsize(fmt: Either<PyStrRef, PyBytesRef>, vm: &VirtualMachine) -> PyResult<usize> {
862-
let format_spec = FormatSpec::decode_and_parse(vm, &fmt)?;
863-
Ok(format_spec.size)
880+
fn calcsize(fmt: IntoStructFormatBytes, vm: &VirtualMachine) -> PyResult<usize> {
881+
Ok(fmt.format_spec(vm)?.size)
864882
}
865883

866884
#[pyattr]
867885
#[pyclass(name = "Struct")]
868886
#[derive(Debug, PyValue)]
869887
struct PyStruct {
870888
spec: FormatSpec,
871-
fmt_str: PyStrRef,
889+
format: PyStrRef,
872890
}
873891

874892
impl SlotConstructor for PyStruct {
875-
type Args = Either<PyStrRef, PyBytesRef>;
893+
type Args = IntoStructFormatBytes;
876894

877895
fn py_new(cls: PyTypeRef, fmt: Self::Args, vm: &VirtualMachine) -> PyResult {
878-
let spec = FormatSpec::decode_and_parse(vm, &fmt)?;
879-
let fmt_str = match fmt {
880-
Either::A(s) => s,
881-
Either::B(b) => PyStr::from(std::str::from_utf8(b.as_bytes()).unwrap())
882-
.into_ref_with_type(vm, vm.ctx.types.str_type.clone())?,
883-
};
884-
PyStruct { spec, fmt_str }.into_pyresult_with_type(vm, cls)
896+
let spec = fmt.format_spec(vm)?;
897+
let format = fmt.0;
898+
PyStruct { spec, format }.into_pyresult_with_type(vm, cls)
885899
}
886900
}
887901

888902
#[pyimpl(with(SlotConstructor))]
889903
impl PyStruct {
890904
#[pyproperty]
891905
fn format(&self) -> PyStrRef {
892-
self.fmt_str.clone()
906+
self.format.clone()
893907
}
894908

895909
#[pyproperty]

0 commit comments

Comments
 (0)