6
6
CRNL = b"\r \n "
7
7
8
8
9
- class VerbatimString ( bytes ):
9
+ class VerbatimStr ( str ):
10
10
"""
11
11
A string that is encoded as a resp3 verbatim string
12
12
"""
13
13
14
- def __new__ (cls , value : bytes , hint : str ) -> "VerbatimString " :
15
- return bytes .__new__ (cls , value )
14
+ def __new__ (cls , value : str , hint : str ) -> "VerbatimStr " :
15
+ return str .__new__ (cls , value )
16
16
17
- def __init__ (self , value : bytes , hint : str ) -> None :
17
+ def __init__ (self , value : str , hint : str ) -> None :
18
18
self .hint = hint
19
19
20
20
def __repr__ (self ) -> str :
21
- return f"VerbatimString({ super ().__repr__ ()} , { self .hint !r} )"
21
+ return f"VerbatimStr({ super ().__repr__ ()} , { self .hint !r} )"
22
+
23
+
24
+ class ErrorStr (str ):
25
+ """
26
+ A string to be encoded as a resp3 error
27
+ """
28
+
29
+ def __new__ (cls , code : str , value : str ) -> "ErrorStr" :
30
+ return str .__new__ (cls , value )
31
+
32
+ def __init__ (self , code : str , value : str ) -> None :
33
+ self .code = code .upper ()
34
+
35
+ def __repr__ (self ) -> str :
36
+ return f"ErrorString({ self .code !r} , { super ().__repr__ ()} )"
37
+
38
+ def __str__ (self ):
39
+ return f"{ self .code } { super ().__str__ ()} "
22
40
23
41
24
42
class PushData (list ):
@@ -30,19 +48,39 @@ def __repr__(self) -> str:
30
48
return f"PushData({ super ().__repr__ ()} )"
31
49
32
50
51
+ class Attribute (dict ):
52
+ """
53
+ A special type of map indicating data from a attribute response
54
+ """
55
+
56
+ def __repr__ (self ) -> str :
57
+ return f"Attribute({ super ().__repr__ ()} )"
58
+
59
+
33
60
class RespEncoder :
34
61
"""
35
- A class for simple RESP protocol encoding for unit tests
62
+ A class for simple RESP protocol encoder for unit tests
36
63
"""
37
64
38
- def __init__ (self , protocol : int = 2 , encoding : str = "utf-8" ) -> None :
65
+ def __init__ (self , protocol : int = 2 ) -> None :
39
66
self .protocol = protocol
40
- self .encoding = encoding
67
+
68
+ def has_crnl (self , value : bytes ) -> bool :
69
+ """check if either cr or nl is in the value"""
70
+ return b"\r " in value or b"\n " in value
71
+
72
+ def strip_crnl (self , value : bytes ) -> bytes :
73
+ """remove any cr or nl from the value"""
74
+ return value .replace (b"\r " , b"" ).replace (b"\n " , b"" )
75
+
76
+ def encodestrip (self , value : str ) -> bytes :
77
+ return self .strip_crnl (value .encode ())
41
78
42
79
def encode (self , data : Any , hint : Optional [str ] = None ) -> bytes :
43
80
if isinstance (data , dict ):
44
81
if self .protocol > 2 :
45
- result = f"%{ len (data )} \r \n " .encode ()
82
+ code = "|" if isinstance (data , Attribute ) else "%"
83
+ result = f"{ code } { len (data )} \r \n " .encode ()
46
84
for key , val in data .items ():
47
85
result += self .encode (key ) + self .encode (val )
48
86
return result
@@ -54,10 +92,8 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
54
92
return self .encode (mylist )
55
93
56
94
elif isinstance (data , list ):
57
- if isinstance (data , PushData ) and self .protocol > 2 :
58
- result = f">{ len (data )} \r \n " .encode ()
59
- else :
60
- result = f"*{ len (data )} \r \n " .encode ()
95
+ code = ">" if isinstance (data , PushData ) and self .protocol > 2 else "*"
96
+ result = f"{ code } { len (data )} \r \n " .encode ()
61
97
for val in data :
62
98
result += self .encode (val )
63
99
return result
@@ -71,11 +107,18 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
71
107
else :
72
108
return self .encode (list (data ))
73
109
110
+ elif isinstance (data , ErrorStr ):
111
+ enc = str (data ).encode ()
112
+ if self .protocol > 2 :
113
+ if len (enc ) > 80 or self .has_crnl (enc ):
114
+ return f"!{ len (enc )} \r \n " .encode () + enc + b"\r \n "
115
+ return b"-" + self .strip_crnl (enc ) + b"\r \n "
116
+
74
117
elif isinstance (data , str ):
75
- enc = data .encode (self . encoding )
118
+ enc = data .encode ()
76
119
# long strings or strings with control characters must be encoded as bulk
77
120
# strings
78
- if hint or len (enc ) > 20 or b" \r " in enc or b" \n " in enc :
121
+ if hint or len (enc ) > 80 or self . has_crnl ( enc ) :
79
122
return self .encode_bulkstr (enc , hint )
80
123
return b"+" + enc + b"\r \n "
81
124
@@ -158,7 +201,7 @@ def resp_parse(
158
201
159
202
elif code == b"+" : # simple string
160
203
# we decode them automatically
161
- yield arg .decode (), rest
204
+ yield arg .decode (errors = "surrogateescape" ), rest
162
205
163
206
elif code == b"$" : # bulk string
164
207
count = int (arg )
@@ -168,8 +211,9 @@ def resp_parse(
168
211
assert incoming is not None
169
212
rest += incoming
170
213
bulkstr = rest [:count ]
171
- # bulk strings are not decoded, could contain binary data
172
- yield bulkstr , rest [expect :]
214
+ # we decode them automatically. Can be encoded
215
+ # back to binary if necessary with "surrogatescape"
216
+ yield bulkstr .decode (errors = "surrogateescape" ), rest [expect :]
173
217
174
218
elif code == b"=" : # verbatim strings
175
219
count = int (arg )
@@ -179,9 +223,9 @@ def resp_parse(
179
223
assert incoming is not None
180
224
rest += incoming
181
225
hint = rest [:3 ]
182
- result = rest [4 : (count + 4 )]
183
- # verbatim strings are not decoded, could contain binary data
184
- yield VerbatimString ( result , hint .decode ()), rest [expect :]
226
+ result = rest [4 : (count + 4 )]
227
+ yield VerbatimStr ( result . decode ( errors = "surrogateescape" ),
228
+ hint .decode ()), rest [expect :]
185
229
186
230
elif code in b"*>" : # array or push data
187
231
count = int (arg )
@@ -214,7 +258,7 @@ def resp_parse(
214
258
result_set .add (value )
215
259
yield result_set , rest
216
260
217
- elif code == b"%" : # map
261
+ elif code in b"%| " : # map or attribute
218
262
count = int (arg )
219
263
result_map = {}
220
264
for _ in range (count ):
@@ -232,10 +276,29 @@ def resp_parse(
232
276
parsed = parser .send (incoming )
233
277
value , rest = parsed
234
278
result_map [key ] = value
279
+ if code == b"|" :
280
+ yield Attribute (result_map ), rest
235
281
yield result_map , rest
282
+
283
+ elif code == b"-" : # error
284
+ # we decode them automatically
285
+ decoded = arg .decode (errors = "surrogateescape" )
286
+ code , value = decoded .split (" " , 1 )
287
+ yield ErrorStr (code , value ), rest
288
+
289
+ elif code == b"!" : # resp3 error
290
+ count = int (arg )
291
+ expect = count + 2 # +2 for the trailing CRNL
292
+ while len (rest ) < expect :
293
+ incoming = yield (None )
294
+ assert incoming is not None
295
+ rest += incoming
296
+ bulkstr = rest [:count ]
297
+ decoded = bulkstr .decode (errors = "surrogateescape" )
298
+ code , value = decoded .split (" " , 1 )
299
+ yield ErrorStr (code , value ), rest [expect :]
300
+
236
301
else :
237
- if code in b"-!" :
238
- raise NotImplementedError (f"resp opcode '{ code .decode ()} ' not implemented" )
239
302
raise ValueError (f"Unknown opcode '{ code .decode ()} '" )
240
303
241
304
0 commit comments