1
1
import itertools
2
+ from contextlib import closing
2
3
from types import NoneType
3
- from typing import Any , Optional
4
+ from typing import Any , Generator , List , Optional , Tuple , Union
5
+
6
+ CRNL = b"\r \n "
7
+
8
+
9
+ class VerbatimString (bytes ):
10
+ """
11
+ A string that is encoded as a resp3 verbatim string
12
+ """
13
+
14
+ def __new__ (cls , value : bytes , hint : str ) -> "VerbatimString" :
15
+ return bytes .__new__ (cls , value )
16
+
17
+ def __init__ (self , value : bytes , hint : str ) -> None :
18
+ self .hint = hint
19
+
20
+ def __repr__ (self ) -> str :
21
+ return f"VerbatimString({ super ().__repr__ ()} , { self .hint !r} )"
22
+
23
+
24
+ class PushData (list ):
25
+ """
26
+ A special type of list indicating data from a push response
27
+ """
28
+
29
+ def __repr__ (self ) -> str :
30
+ return f"PushData({ super ().__repr__ ()} )"
4
31
5
32
6
33
class RespEncoder :
7
34
"""
8
- A class for simple RESP protocol encodign for unit tests
35
+ A class for simple RESP protocol encoding for unit tests
9
36
"""
10
37
11
38
def __init__ (self , protocol : int = 2 , encoding : str = "utf-8" ) -> None :
@@ -27,7 +54,10 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
27
54
return self .encode (mylist )
28
55
29
56
elif isinstance (data , list ):
30
- result = f"*{ len (data )} \r \n " .encode ()
57
+ if isinstance (data , PushData ) and self .protocol > 2 :
58
+ result = f">{ len (data )} \r \n " .encode ()
59
+ else :
60
+ result = f"*{ len (data )} \r \n " .encode ()
31
61
for val in data :
32
62
result += self .encode (val )
33
63
return result
@@ -55,39 +85,243 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
55
85
elif isinstance (data , bool ):
56
86
if self .protocol == 2 :
57
87
return b":1\r \n " if data else b":0\r \n "
58
- else :
59
- return b"t\r \n " if data else b"f\r \n "
88
+ return b"t\r \n " if data else b"f\r \n "
60
89
61
90
elif isinstance (data , int ):
62
91
if (data > 2 ** 63 - 1 ) or (data < - (2 ** 63 )):
63
92
if self .protocol > 2 :
64
93
return f"({ data } \r \n " .encode () # resp3 big int
65
- else :
66
- return f"+{ data } \r \n " .encode () # force to simple string
94
+ return f"+{ data } \r \n " .encode () # force to simple string
67
95
return f":{ data } \r \n " .encode ()
68
96
elif isinstance (data , float ):
69
97
if self .protocol > 2 :
70
98
return f",{ data } \r \n " .encode () # resp3 double
71
- else :
72
- return f"+{ data } \r \n " .encode () # simple string
99
+ return f"+{ data } \r \n " .encode () # simple string
73
100
74
101
elif isinstance (data , NoneType ):
75
102
if self .protocol > 2 :
76
103
return b"_\r \n " # resp3 null
77
- else :
78
- return b"$-1\r \n " # Null bulk string
79
- # some commands return null array: b"*-1\r\n"
104
+ return b"$-1\r \n " # Null bulk string
105
+ # some commands return null array: b"*-1\r\n"
80
106
81
107
else :
82
- raise NotImplementedError
108
+ raise NotImplementedError ( f"encode not implemented for { type ( data ) } " )
83
109
84
110
def encode_bulkstr (self , bstr : bytes , hint : Optional [str ]) -> bytes :
85
111
if self .protocol > 2 and hint is not None :
86
112
# a resp3 verbatim string
87
113
return f"={ len (bstr )} \r \n { hint } :" .encode () + bstr + b"\r \n "
88
- else :
89
- return f"${ len (bstr )} \r \n " .encode () + bstr + b"\r \n "
114
+ # regular bulk string
115
+ return f"${ len (bstr )} \r \n " .encode () + bstr + b"\r \n "
90
116
91
117
92
118
def encode (value : Any , protocol : int = 2 , hint : Optional [str ] = None ) -> bytes :
119
+ """
120
+ Encode a value using the RESP protocol
121
+ """
93
122
return RespEncoder (protocol ).encode (value , hint )
123
+
124
+
125
+ # a stateful RESP parser implemented via a generator
126
+ def resp_parse (
127
+ buffer : bytes ,
128
+ ) -> Generator [Optional [Tuple [Any , bytes ]], Union [None , bytes ], None ]:
129
+ """
130
+ A stateful, generator based, RESP parser.
131
+ Returns a generator producing at most a single top-level primitive.
132
+ Yields tuple of (data_item, unparsed), or None if more data is needed.
133
+ It is fed more data with generator.send()
134
+ """
135
+ # Read the first line of resp or yield to get more data
136
+ while CRNL not in buffer :
137
+ incoming = yield None
138
+ assert incoming is not None
139
+ buffer += incoming
140
+ cmd , rest = buffer .split (CRNL , 1 )
141
+
142
+ code , arg = cmd [:1 ], cmd [1 :]
143
+
144
+ if code == b":" or code == b"(" : # integer, resp3 large int
145
+ yield int (arg ), rest
146
+
147
+ elif code == b"t" : # resp3 true
148
+ yield True , rest
149
+
150
+ elif code == b"f" : # resp3 false
151
+ yield False , rest
152
+
153
+ elif code == b"_" : # resp3 null
154
+ yield None , rest
155
+
156
+ elif code == b"," : # resp3 double
157
+ yield float (arg ), rest
158
+
159
+ elif code == b"+" : # simple string
160
+ # we decode them automatically
161
+ yield arg .decode (), rest
162
+
163
+ elif code == b"$" : # bulk string
164
+ count = int (arg )
165
+ expect = count + 2 # +2 for the trailing CRNL
166
+ while len (rest ) < expect :
167
+ incoming = yield (None )
168
+ assert incoming is not None
169
+ rest += incoming
170
+ bulkstr = rest [:count ]
171
+ # bulk strings are not decoded, could contain binary data
172
+ yield bulkstr , rest [expect :]
173
+
174
+ elif code == b"=" : # verbatim strings
175
+ count = int (arg )
176
+ expect = count + 4 + 2 # 4 type and colon +2 for the trailing CRNL
177
+ while len (rest ) < expect :
178
+ incoming = yield (None )
179
+ assert incoming is not None
180
+ rest += incoming
181
+ hint = rest [:3 ]
182
+ result = rest [4 : (count + 4 )]
183
+ # verbatim strings are not decoded, could contain binary data
184
+ yield VerbatimString (result , hint .decode ()), rest [expect :]
185
+
186
+ elif code in b"*>" : # array or push data
187
+ count = int (arg )
188
+ result_array = []
189
+ for _ in range (count ):
190
+ # recursively parse the next array item
191
+ with closing (resp_parse (rest )) as parser :
192
+ parsed = parser .send (None )
193
+ while parsed is None :
194
+ incoming = yield None
195
+ parsed = parser .send (incoming )
196
+ value , rest = parsed
197
+ result_array .append (value )
198
+ if code == b">" :
199
+ yield PushData (result_array ), rest
200
+ else :
201
+ yield result_array , rest
202
+
203
+ elif code == b"~" : # set
204
+ count = int (arg )
205
+ result_set = set ()
206
+ for _ in range (count ):
207
+ # recursively parse the next set item
208
+ with closing (resp_parse (rest )) as parser :
209
+ parsed = parser .send (None )
210
+ while parsed is None :
211
+ incoming = yield None
212
+ parsed = parser .send (incoming )
213
+ value , rest = parsed
214
+ result_set .add (value )
215
+ yield result_set , rest
216
+
217
+ elif code == b"%" : # map
218
+ count = int (arg )
219
+ result_map = {}
220
+ for _ in range (count ):
221
+ # recursively parse the next key, and value
222
+ with closing (resp_parse (rest )) as parser :
223
+ parsed = parser .send (None )
224
+ while parsed is None :
225
+ incoming = yield None
226
+ parsed = parser .send (incoming )
227
+ key , rest = parsed
228
+ with closing (resp_parse (rest )) as parser :
229
+ parsed = parser .send (None )
230
+ while parsed is None :
231
+ incoming = yield None
232
+ parsed = parser .send (incoming )
233
+ value , rest = parsed
234
+ result_map [key ] = value
235
+ yield result_map , rest
236
+ else :
237
+ if code in b"-!" :
238
+ raise NotImplementedError (f"resp opcode '{ code .decode ()} ' not implemented" )
239
+ raise ValueError (f"Unknown opcode '{ code .decode ()} '" )
240
+
241
+
242
+ class NeedMoreData (RuntimeError ):
243
+ """
244
+ Raised when more data is needed to complete a parse
245
+ """
246
+
247
+
248
+ class RespParser :
249
+ """
250
+ A class for simple RESP protocol decoding for unit tests
251
+ """
252
+
253
+ def __init__ (self ) -> None :
254
+ self .parser : Optional [
255
+ Generator [Optional [Tuple [Any , bytes ]], Union [None , bytes ], None ]
256
+ ] = None
257
+ # which has not resulted in a parsed value
258
+ self .consumed : List [bytes ] = []
259
+
260
+ def parse (self , buffer : bytes ) -> Optional [Any ]:
261
+ """
262
+ Parse a buffer of data, return a tuple of a single top-level primitive and the
263
+ remaining buffer or raise NeedMoreData if more data is needed
264
+ """
265
+ if self .parser is None :
266
+ # create a new parser generator, initializing it with
267
+ # any unparsed data from previous calls
268
+ buffer = b"" .join (self .consumed ) + buffer
269
+ del self .consumed [:]
270
+ self .parser = resp_parse (buffer )
271
+ parsed = self .parser .send (None )
272
+ else :
273
+ # sen more data to the parser
274
+ parsed = self .parser .send (buffer )
275
+
276
+ if parsed is None :
277
+ self .consumed .append (buffer )
278
+ raise NeedMoreData ()
279
+
280
+ # got a value, close the parser, store the remaining buffer
281
+ self .parser .close ()
282
+ self .parser = None
283
+ value , remaining = parsed
284
+ self .consumed = [remaining ]
285
+ return value
286
+
287
+ def get_unparsed (self ) -> bytes :
288
+ return b"" .join (self .consumed )
289
+
290
+ def close (self ) -> None :
291
+ if self .parser is not None :
292
+ self .parser .close ()
293
+ self .parser = None
294
+ del self .consumed [:]
295
+
296
+
297
+ def parse_all (buffer : bytes ) -> Tuple [List [Any ], bytes ]:
298
+ """
299
+ Parse all the data in the buffer, returning the list of top-level objects and the
300
+ remaining buffer
301
+ """
302
+ with closing (RespParser ()) as parser :
303
+ result : List [Any ] = []
304
+ while True :
305
+ try :
306
+ result .append (parser .parse (buffer ))
307
+ buffer = b""
308
+ except NeedMoreData :
309
+ return result , parser .get_unparsed ()
310
+
311
+
312
+ def parse_chunks (buffers : List [bytes ]) -> Tuple [List [Any ], bytes ]:
313
+ """
314
+ Parse all the data in the buffers, returning the list of top-level objects and the
315
+ remaining buffer.
316
+ Used primarily for testing, since it will parse the data in chunks
317
+ """
318
+ result : List [Any ] = []
319
+ with closing (RespParser ()) as parser :
320
+ for buffer in buffers :
321
+ while True :
322
+ try :
323
+ result .append (parser .parse (buffer ))
324
+ buffer = b""
325
+ except NeedMoreData :
326
+ break
327
+ return result , parser .get_unparsed ()
0 commit comments