6
6
CRNL = b"\r \n "
7
7
8
8
9
- class VerbatimString ( bytes ):
9
+ class VerbatimStr ( str ):
10
10
"""
11
11
A string that is encoded as a resp3 verbatim string
12
12
"""
13
13
14
- def __new__ (cls , value : bytes , hint : str ) -> "VerbatimString " :
15
- return bytes .__new__ (cls , value )
14
+ def __new__ (cls , value : str , hint : str ) -> "VerbatimStr " :
15
+ return str .__new__ (cls , value )
16
16
17
- def __init__ (self , value : bytes , hint : str ) -> None :
17
+ def __init__ (self , value : str , hint : str ) -> None :
18
18
self .hint = hint
19
19
20
20
def __repr__ (self ) -> str :
21
- return f"VerbatimString({ super ().__repr__ ()} , { self .hint !r} )"
21
+ return f"VerbatimStr({ super ().__repr__ ()} , { self .hint !r} )"
22
+
23
+
24
+ class ErrorStr (str ):
25
+ """
26
+ A string to be encoded as a resp3 error
27
+ """
28
+
29
+ def __new__ (cls , code : str , value : str ) -> "ErrorStr" :
30
+ return str .__new__ (cls , value )
31
+
32
+ def __init__ (self , code : str , value : str ) -> None :
33
+ self .code = code .upper ()
34
+
35
+ def __repr__ (self ) -> str :
36
+ return f"ErrorString({ self .code !r} , { super ().__repr__ ()} )"
37
+
38
+ def __str__ (self ):
39
+ return f"{ self .code } { super ().__str__ ()} "
22
40
23
41
24
42
class PushData (list ):
@@ -30,19 +48,43 @@ def __repr__(self) -> str:
30
48
return f"PushData({ super ().__repr__ ()} )"
31
49
32
50
51
+ class Attribute (dict ):
52
+ """
53
+ A special type of map indicating data from a attribute response
54
+ """
55
+
56
+ def __repr__ (self ) -> str :
57
+ return f"Attribute({ super ().__repr__ ()} )"
58
+
59
+
33
60
class RespEncoder :
34
61
"""
35
- A class for simple RESP protocol encoding for unit tests
62
+ A class for simple RESP protocol encoder for unit tests
36
63
"""
37
64
38
- def __init__ (self , protocol : int = 2 , encoding : str = "utf-8" ) -> None :
65
+ def __init__ (
66
+ self , protocol : int = 2 , encoding : str = "utf-8" , errorhander = "strict"
67
+ ) -> None :
39
68
self .protocol = protocol
40
69
self .encoding = encoding
70
+ self .errorhandler = errorhander
71
+
72
+ def apply_encoding (self , value : str ) -> bytes :
73
+ return value .encode (self .encoding , errors = self .errorhandler )
74
+
75
+ def has_crnl (self , value : bytes ) -> bool :
76
+ """check if either cr or nl is in the value"""
77
+ return b"\r " in value or b"\n " in value
78
+
79
+ def escape_crln (self , value : bytes ) -> bytes :
80
+ """remove any cr or nl from the value"""
81
+ return value .replace (b"\r " , b"\\ r" ).replace (b"\n " , b"\\ n" )
41
82
42
83
def encode (self , data : Any , hint : Optional [str ] = None ) -> bytes :
43
84
if isinstance (data , dict ):
44
85
if self .protocol > 2 :
45
- result = f"%{ len (data )} \r \n " .encode ()
86
+ code = "|" if isinstance (data , Attribute ) else "%"
87
+ result = f"{ code } { len (data )} \r \n " .encode ()
46
88
for key , val in data .items ():
47
89
result += self .encode (key ) + self .encode (val )
48
90
return result
@@ -54,10 +96,8 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
54
96
return self .encode (mylist )
55
97
56
98
elif isinstance (data , list ):
57
- if isinstance (data , PushData ) and self .protocol > 2 :
58
- result = f">{ len (data )} \r \n " .encode ()
59
- else :
60
- result = f"*{ len (data )} \r \n " .encode ()
99
+ code = ">" if isinstance (data , PushData ) and self .protocol > 2 else "*"
100
+ result = f"{ code } { len (data )} \r \n " .encode ()
61
101
for val in data :
62
102
result += self .encode (val )
63
103
return result
@@ -71,11 +111,18 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
71
111
else :
72
112
return self .encode (list (data ))
73
113
114
+ elif isinstance (data , ErrorStr ):
115
+ enc = self .apply_encoding (str (data ))
116
+ if self .protocol > 2 :
117
+ if len (enc ) > 80 or self .has_crnl (enc ):
118
+ return f"!{ len (enc )} \r \n " .encode () + enc + b"\r \n "
119
+ return b"-" + self .escape_crln (enc ) + b"\r \n "
120
+
74
121
elif isinstance (data , str ):
75
- enc = data . encode ( self .encoding )
122
+ enc = self .apply_encoding ( data )
76
123
# long strings or strings with control characters must be encoded as bulk
77
124
# strings
78
- if hint or len (enc ) > 20 or b" \r " in enc or b" \n " in enc :
125
+ if hint or len (enc ) > 80 or self . has_crnl ( enc ) :
79
126
return self .encode_bulkstr (enc , hint )
80
127
return b"+" + enc + b"\r \n "
81
128
@@ -158,7 +205,7 @@ def resp_parse(
158
205
159
206
elif code == b"+" : # simple string
160
207
# we decode them automatically
161
- yield arg .decode (), rest
208
+ yield arg .decode (errors = "surrogateescape" ), rest
162
209
163
210
elif code == b"$" : # bulk string
164
211
count = int (arg )
@@ -168,8 +215,9 @@ def resp_parse(
168
215
assert incoming is not None
169
216
rest += incoming
170
217
bulkstr = rest [:count ]
171
- # bulk strings are not decoded, could contain binary data
172
- yield bulkstr , rest [expect :]
218
+ # we decode them automatically. Can be encoded
219
+ # back to binary if necessary with "surrogatescape"
220
+ yield bulkstr .decode (errors = "surrogateescape" ), rest [expect :]
173
221
174
222
elif code == b"=" : # verbatim strings
175
223
count = int (arg )
@@ -179,9 +227,9 @@ def resp_parse(
179
227
assert incoming is not None
180
228
rest += incoming
181
229
hint = rest [:3 ]
182
- result = rest [4 : (count + 4 )]
183
- # verbatim strings are not decoded, could contain binary data
184
- yield VerbatimString ( result , hint .decode ()), rest [expect :]
230
+ result = rest [4 : (count + 4 )]
231
+ yield VerbatimStr ( result . decode ( errors = "surrogateescape" ),
232
+ hint .decode ()), rest [expect :]
185
233
186
234
elif code in b"*>" : # array or push data
187
235
count = int (arg )
@@ -214,7 +262,7 @@ def resp_parse(
214
262
result_set .add (value )
215
263
yield result_set , rest
216
264
217
- elif code == b"%" : # map
265
+ elif code in b"%| " : # map or attribute
218
266
count = int (arg )
219
267
result_map = {}
220
268
for _ in range (count ):
@@ -232,10 +280,29 @@ def resp_parse(
232
280
parsed = parser .send (incoming )
233
281
value , rest = parsed
234
282
result_map [key ] = value
283
+ if code == b"|" :
284
+ yield Attribute (result_map ), rest
235
285
yield result_map , rest
286
+
287
+ elif code == b"-" : # error
288
+ # we decode them automatically
289
+ decoded = arg .decode (errors = "surrogateescape" )
290
+ code , value = decoded .split (" " , 1 )
291
+ yield ErrorStr (code , value ), rest
292
+
293
+ elif code == b"!" : # resp3 error
294
+ count = int (arg )
295
+ expect = count + 2 # +2 for the trailing CRNL
296
+ while len (rest ) < expect :
297
+ incoming = yield (None )
298
+ assert incoming is not None
299
+ rest += incoming
300
+ bulkstr = rest [:count ]
301
+ decoded = bulkstr .decode (errors = "surrogateescape" )
302
+ code , value = decoded .split (" " , 1 )
303
+ yield ErrorStr (code , value ), rest [expect :]
304
+
236
305
else :
237
- if code in b"-!" :
238
- raise NotImplementedError (f"resp opcode '{ code .decode ()} ' not implemented" )
239
306
raise ValueError (f"Unknown opcode '{ code .decode ()} '" )
240
307
241
308
0 commit comments