12
12
#[ pymodule]
13
13
pub ( crate ) mod _struct {
14
14
use crate :: {
15
- builtins:: { float, PyBaseExceptionRef , PyBytesRef , PyStr , PyStrRef , PyTupleRef , PyTypeRef } ,
15
+ builtins:: {
16
+ float, PyBaseExceptionRef , PyBytes , PyBytesRef , PyStr , PyStrRef , PyTupleRef , PyTypeRef ,
17
+ } ,
16
18
common:: str:: wchar_t,
17
19
function:: { ArgBytesLike , ArgIntoBool , ArgMemoryBuffer , IntoPyObject , PosArgs } ,
18
20
protocol:: PyIterReturn ,
19
21
slots:: { IteratorIterable , SlotConstructor , SlotIterator } ,
20
- utils:: Either ,
21
- PyObjectRef , PyRef , PyResult , PyValue , TryFromObject , VirtualMachine ,
22
+ PyObjectRef , PyRef , PyResult , PyValue , TryFromObject , TypeProtocol , VirtualMachine ,
22
23
} ;
23
24
use crossbeam_utils:: atomic:: AtomicCell ;
24
25
use half:: f16;
@@ -202,6 +203,39 @@ pub(crate) mod _struct {
202
203
203
204
const OVERFLOW_MSG : & str = "total struct size too long" ;
204
205
206
+ struct IntoStructFormatBytes ( PyStrRef ) ;
207
+
208
+ impl TryFromObject for IntoStructFormatBytes {
209
+ fn try_from_object ( vm : & VirtualMachine , obj : PyObjectRef ) -> PyResult < Self > {
210
+ // CPython turns str to bytes but we do reversed way here
211
+ // The only performance difference is this transition cost
212
+ let fmt = match_class ! {
213
+ match obj {
214
+ s @ PyStr => if s. is_ascii( ) {
215
+ Some ( s)
216
+ } else {
217
+ None
218
+ } ,
219
+ b @ PyBytes => if b. is_ascii( ) {
220
+ Some ( unsafe {
221
+ PyStr :: new_ascii_unchecked( b. as_bytes( ) . to_vec( ) )
222
+ } . into_ref( vm) )
223
+ } else {
224
+ None
225
+ } ,
226
+ other => return Err ( vm. new_type_error( format!( "Struct() argument 1 must be a str or bytes object, not {}" , other. class( ) . name( ) ) ) ) ,
227
+ }
228
+ } . ok_or_else ( || vm. new_unicode_decode_error ( "Struct format must be a ascii string" . to_owned ( ) ) ) ?;
229
+ Ok ( IntoStructFormatBytes ( fmt) )
230
+ }
231
+ }
232
+
233
+ impl IntoStructFormatBytes {
234
+ fn format_spec ( & self , vm : & VirtualMachine ) -> PyResult < FormatSpec > {
235
+ FormatSpec :: parse ( self . 0 . as_str ( ) . as_bytes ( ) , vm)
236
+ }
237
+ }
238
+
205
239
#[ derive( Debug , Clone ) ]
206
240
pub ( crate ) struct FormatSpec {
207
241
endianness : Endianness ,
@@ -211,24 +245,8 @@ pub(crate) mod _struct {
211
245
}
212
246
213
247
impl FormatSpec {
214
- fn decode_and_parse (
215
- vm : & VirtualMachine ,
216
- fmt : & Either < PyStrRef , PyBytesRef > ,
217
- ) -> PyResult < FormatSpec > {
218
- let decoded_fmt = match fmt {
219
- Either :: A ( string) => string. as_str ( ) ,
220
- Either :: B ( bytes) if bytes. is_ascii ( ) => std:: str:: from_utf8 ( bytes) . unwrap ( ) ,
221
- _ => {
222
- return Err ( vm. new_unicode_decode_error (
223
- "Struct format must be a ascii string" . to_owned ( ) ,
224
- ) )
225
- }
226
- } ;
227
- FormatSpec :: parse ( decoded_fmt, vm)
228
- }
229
-
230
- pub fn parse ( fmt : & str , vm : & VirtualMachine ) -> PyResult < FormatSpec > {
231
- let mut chars = fmt. bytes ( ) . peekable ( ) ;
248
+ pub fn parse ( fmt : & [ u8 ] , vm : & VirtualMachine ) -> PyResult < FormatSpec > {
249
+ let mut chars = fmt. iter ( ) . copied ( ) . peekable ( ) ;
232
250
233
251
// First determine "@", "<", ">","!" or "="
234
252
let endianness = parse_endianness ( & mut chars) ;
@@ -399,10 +417,10 @@ pub(crate) mod _struct {
399
417
let mut repeat = 0isize ;
400
418
while let Some ( b'0' ..=b'9' ) = chars. peek ( ) {
401
419
if let Some ( c) = chars. next ( ) {
402
- let current_digit = ( c as char ) . to_digit ( 10 ) . unwrap ( ) as isize ;
420
+ let current_digit = c - b'0' ;
403
421
repeat = repeat
404
422
. checked_mul ( 10 )
405
- . and_then ( |r| r. checked_add ( current_digit) )
423
+ . and_then ( |r| r. checked_add ( current_digit as _ ) )
406
424
. ok_or_else ( || OVERFLOW_MSG . to_owned ( ) ) ?;
407
425
}
408
426
}
@@ -486,20 +504,26 @@ pub(crate) mod _struct {
486
504
}
487
505
buffer_len - ( -offset as usize )
488
506
} else {
489
- if offset as usize >= buffer_len {
507
+ let offset = offset as usize ;
508
+ let ( op, op_action) = if is_pack {
509
+ ( "pack_into" , "packing" )
510
+ } else {
511
+ ( "unpack_from" , "unpacking" )
512
+ } ;
513
+ if offset >= buffer_len {
490
514
let msg = format ! (
491
515
"{op} requires a buffer of at least {required} bytes for {op_action} {needed} \
492
516
bytes at offset {offset} (actual buffer size is {buffer_len})",
493
- op = if is_pack { "pack_into" } else { "unpack_from" } ,
494
- op_action = if is_pack { "packing" } else { "unpacking" } ,
517
+ op = op ,
518
+ op_action = op_action ,
495
519
required = needed + offset as usize ,
496
520
needed = needed,
497
521
offset = offset,
498
522
buffer_len = buffer_len
499
523
) ;
500
524
return Err ( new_struct_error ( vm, msg) ) ;
501
525
}
502
- offset as usize
526
+ offset
503
527
} ;
504
528
505
529
if ( buffer_len - offset_from_start) < needed {
@@ -717,24 +741,19 @@ pub(crate) mod _struct {
717
741
}
718
742
719
743
#[ pyfunction]
720
- fn pack (
721
- fmt : Either < PyStrRef , PyBytesRef > ,
722
- args : PosArgs ,
723
- vm : & VirtualMachine ,
724
- ) -> PyResult < Vec < u8 > > {
725
- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt) ?;
726
- format_spec. pack ( args. into_vec ( ) , vm)
744
+ fn pack ( fmt : IntoStructFormatBytes , args : PosArgs , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
745
+ fmt. format_spec ( vm) ?. pack ( args. into_vec ( ) , vm)
727
746
}
728
747
729
748
#[ pyfunction]
730
749
fn pack_into (
731
- fmt : Either < PyStrRef , PyBytesRef > ,
750
+ fmt : IntoStructFormatBytes ,
732
751
buffer : ArgMemoryBuffer ,
733
752
offset : isize ,
734
753
args : PosArgs ,
735
754
vm : & VirtualMachine ,
736
755
) -> PyResult < ( ) > {
737
- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt ) ?;
756
+ let format_spec = fmt . format_spec ( vm) ?;
738
757
let offset = get_buffer_offset ( buffer. len ( ) , offset, format_spec. size , true , vm) ?;
739
758
buffer. with_ref ( |data| format_spec. pack_into ( & mut data[ offset..] , args. into_vec ( ) , vm) )
740
759
}
@@ -757,11 +776,11 @@ pub(crate) mod _struct {
757
776
758
777
#[ pyfunction]
759
778
fn unpack (
760
- fmt : Either < PyStrRef , PyBytesRef > ,
779
+ fmt : IntoStructFormatBytes ,
761
780
buffer : ArgBytesLike ,
762
781
vm : & VirtualMachine ,
763
782
) -> PyResult < PyTupleRef > {
764
- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt ) ?;
783
+ let format_spec = fmt . format_spec ( vm) ?;
765
784
buffer. with_ref ( |buf| format_spec. unpack ( buf, vm) )
766
785
}
767
786
@@ -774,11 +793,11 @@ pub(crate) mod _struct {
774
793
775
794
#[ pyfunction]
776
795
fn unpack_from (
777
- fmt : Either < PyStrRef , PyBytesRef > ,
796
+ fmt : IntoStructFormatBytes ,
778
797
args : UpdateFromArgs ,
779
798
vm : & VirtualMachine ,
780
799
) -> PyResult < PyTupleRef > {
781
- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt ) ?;
800
+ let format_spec = fmt . format_spec ( vm) ?;
782
801
let offset =
783
802
get_buffer_offset ( args. buffer . len ( ) , args. offset , format_spec. size , false , vm) ?;
784
803
args. buffer
@@ -849,47 +868,42 @@ pub(crate) mod _struct {
849
868
850
869
#[ pyfunction]
851
870
fn iter_unpack (
852
- fmt : Either < PyStrRef , PyBytesRef > ,
871
+ fmt : IntoStructFormatBytes ,
853
872
buffer : ArgBytesLike ,
854
873
vm : & VirtualMachine ,
855
874
) -> PyResult < UnpackIterator > {
856
- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt ) ?;
875
+ let format_spec = fmt . format_spec ( vm) ?;
857
876
UnpackIterator :: new ( vm, format_spec, buffer)
858
877
}
859
878
860
879
#[ pyfunction]
861
- fn calcsize ( fmt : Either < PyStrRef , PyBytesRef > , vm : & VirtualMachine ) -> PyResult < usize > {
862
- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt) ?;
863
- Ok ( format_spec. size )
880
+ fn calcsize ( fmt : IntoStructFormatBytes , vm : & VirtualMachine ) -> PyResult < usize > {
881
+ Ok ( fmt. format_spec ( vm) ?. size )
864
882
}
865
883
866
884
#[ pyattr]
867
885
#[ pyclass( name = "Struct" ) ]
868
886
#[ derive( Debug , PyValue ) ]
869
887
struct PyStruct {
870
888
spec : FormatSpec ,
871
- fmt_str : PyStrRef ,
889
+ format : PyStrRef ,
872
890
}
873
891
874
892
impl SlotConstructor for PyStruct {
875
- type Args = Either < PyStrRef , PyBytesRef > ;
893
+ type Args = IntoStructFormatBytes ;
876
894
877
895
fn py_new ( cls : PyTypeRef , fmt : Self :: Args , vm : & VirtualMachine ) -> PyResult {
878
- let spec = FormatSpec :: decode_and_parse ( vm, & fmt) ?;
879
- let fmt_str = match fmt {
880
- Either :: A ( s) => s,
881
- Either :: B ( b) => PyStr :: from ( std:: str:: from_utf8 ( b. as_bytes ( ) ) . unwrap ( ) )
882
- . into_ref_with_type ( vm, vm. ctx . types . str_type . clone ( ) ) ?,
883
- } ;
884
- PyStruct { spec, fmt_str } . into_pyresult_with_type ( vm, cls)
896
+ let spec = fmt. format_spec ( vm) ?;
897
+ let format = fmt. 0 ;
898
+ PyStruct { spec, format } . into_pyresult_with_type ( vm, cls)
885
899
}
886
900
}
887
901
888
902
#[ pyimpl( with( SlotConstructor ) ) ]
889
903
impl PyStruct {
890
904
#[ pyproperty]
891
905
fn format ( & self ) -> PyStrRef {
892
- self . fmt_str . clone ( )
906
+ self . format . clone ( )
893
907
}
894
908
895
909
#[ pyproperty]
0 commit comments