1
1
use std:: borrow:: Cow ;
2
2
use std:: str:: FromStr ;
3
3
4
- use base64:: Engine ;
4
+ use base64:: engine:: general_purpose:: { STANDARD , URL_SAFE } ;
5
+ use base64:: { DecodeError , Engine } ;
5
6
use pyo3:: types:: { PyDict , PyString } ;
6
7
use pyo3:: { intern, prelude:: * } ;
7
8
@@ -28,31 +29,18 @@ impl ValBytesMode {
28
29
pub fn deserialize_string < ' py > ( self , s : & str ) -> Result < EitherBytes < ' _ , ' py > , ErrorType > {
29
30
match self . ser {
30
31
BytesMode :: Utf8 => Ok ( EitherBytes :: Cow ( Cow :: Borrowed ( s. as_bytes ( ) ) ) ) ,
31
- BytesMode :: Base64 => {
32
- fn decode ( input : & str ) -> Result < Vec < u8 > , ErrorType > {
33
- base64:: engine:: general_purpose:: URL_SAFE . decode ( input) . map_err ( |err| {
34
- ErrorType :: BytesInvalidEncoding {
35
- encoding : "base64" . to_string ( ) ,
36
- encoding_error : err. to_string ( ) ,
37
- context : None ,
38
- }
39
- } )
40
- }
41
- let result = if s. contains ( |c| c == '+' || c == '/' ) {
42
- let replaced: String = s
43
- . chars ( )
44
- . map ( |c| match c {
45
- '+' => '-' ,
46
- '/' => '_' ,
47
- _ => c,
48
- } )
49
- . collect ( ) ;
50
- decode ( & replaced)
51
- } else {
52
- decode ( s)
53
- } ;
54
- result. map ( EitherBytes :: from)
55
- }
32
+ BytesMode :: Base64 => URL_SAFE
33
+ . decode ( s)
34
+ . or_else ( |err| match err {
35
+ DecodeError :: InvalidByte ( _, b'/' | b'+' ) => STANDARD . decode ( s) ,
36
+ _ => Err ( err) ,
37
+ } )
38
+ . map ( EitherBytes :: from)
39
+ . map_err ( |err| ErrorType :: BytesInvalidEncoding {
40
+ encoding : "base64" . to_string ( ) ,
41
+ encoding_error : err. to_string ( ) ,
42
+ context : None ,
43
+ } ) ,
56
44
BytesMode :: Hex => match hex:: decode ( s) {
57
45
Ok ( vec) => Ok ( EitherBytes :: from ( vec) ) ,
58
46
Err ( err) => Err ( ErrorType :: BytesInvalidEncoding {
0 commit comments