@@ -169,141 +169,167 @@ def encode(value: Any, protocol: int = 2, hint: Optional[str] = None) -> bytes:
169
169
return RespEncoder (protocol ).encode (value , hint )
170
170
171
171
172
class RespGeneratorParser:
    """
    A wrapper class around a stateful RESP parsing generator,
    allowing custom string decoding rules.

    The instance only holds the decoding configuration; all parsing state
    lives in the generators returned by `parse()`.
    """

    def __init__(
        self,
        encoding: Optional[str] = "utf-8",
        errorhandler: str = "surrogateescape",
    ) -> None:
        """
        Create a new parser, optionally specifying the encoding and errorhandler.
        If `encoding` is None, bytes will be returned as-is.
        The default settings are utf-8 encoding and surrogateescape errorhandler,
        which can decode all possible byte sequences,
        allowing decoded data to be re-encoded back to bytes.
        """
        self.encoding = encoding
        self.errorhandler = errorhandler

    def decode_bytes(self, data: bytes) -> Union[str, bytes]:
        """
        Decode string payload data using the configured encoding and
        errorhandler.  When `encoding` is None the data is returned
        unchanged, as bytes.
        """
        if self.encoding is None:
            return data
        return data.decode(self.encoding, errors=self.errorhandler)

    def _decode_text(self, data: bytes) -> str:
        """
        Decode protocol text which must always be a `str` (error replies,
        verbatim string hints), even when `encoding` is None.
        """
        decoded = self.decode_bytes(data)
        if isinstance(decoded, bytes):
            # raw mode: payload strings stay bytes, but protocol text is
            # still split and compared as str, so fall back to utf-8
            decoded = decoded.decode("utf-8", errors=self.errorhandler)
        return decoded

    def _make_error(self, data: bytes) -> Any:
        """
        Build an ErrorStr from the raw bytes of an error reply.
        """
        # partition() rather than split(): an error consisting only of
        # a code ("-ERR") yields an empty message instead of failing to unpack
        err_code, _, message = self._decode_text(data).partition(" ")
        return ErrorStr(err_code, message)

    def _fill(
        self, rest: bytes, expect: int
    ) -> Generator[None, Optional[bytes], bytes]:
        """
        Sub-generator which yields None until `rest` holds at least
        `expect` bytes, then returns the (possibly extended) buffer.
        """
        while len(rest) < expect:
            incoming = yield None
            assert incoming is not None
            rest += incoming
        return rest

    def _parse_nested(
        self, rest: bytes
    ) -> Generator[None, Optional[bytes], Tuple[Any, bytes]]:
        """
        Sub-generator which parses a single nested item from `rest`.
        Yields None while more data is needed and returns the
        (data_item, unparsed) tuple once the item is complete.
        """
        # closing() makes sure the nested generator is finalized, also when
        # the outer generator is closed while waiting for more data
        with closing(self.parse(rest)) as parser:
            parsed = parser.send(None)
            while parsed is None:
                incoming = yield None
                parsed = parser.send(incoming)
        return parsed

    # a stateful RESP parser implemented via a generator
    def parse(
        self,
        buffer: bytes,
    ) -> Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]:
        """
        A stateful, generator based, RESP parser.
        Returns a generator producing at most a single top-level primitive.
        Yields tuple of (data_item, unparsed), or None if more data is needed.
        It is fed more data with generator.send()

        Raises ValueError for an unknown opcode or a malformed value.
        """
        # Read the first line of resp or yield to get more data
        while CRNL not in buffer:
            incoming = yield None
            assert incoming is not None
            buffer += incoming
        cmd, rest = buffer.split(CRNL, 1)

        code, arg = cmd[:1], cmd[1:]

        if code == b":" or code == b"(":  # integer, resp3 large int
            yield int(arg), rest

        elif code == b"t":  # resp3 true
            yield True, rest

        elif code == b"f":  # resp3 false
            yield False, rest

        elif code == b"#":  # resp3 boolean in its "#t" / "#f" wire form
            if arg == b"t":
                yield True, rest
            elif arg == b"f":
                yield False, rest
            else:
                raise ValueError(f"Invalid boolean value {arg!r}")

        elif code == b"_":  # resp3 null
            yield None, rest

        elif code == b",":  # resp3 double
            yield float(arg), rest

        elif code == b"+":  # simple string
            # we decode them automatically
            yield self.decode_bytes(arg), rest

        elif code == b"$":  # bulk string
            count = int(arg)
            if count < 0:
                # resp2 null bulk string, "$-1", carries no payload
                yield None, rest
            else:
                expect = count + 2  # +2 for the trailing CRNL
                rest = yield from self._fill(rest, expect)
                yield self.decode_bytes(rest[:count]), rest[expect:]

        elif code == b"=":  # verbatim strings
            count = int(arg)
            # NOTE(review): the length is treated as excluding the 4 byte
            # "xxx:" prefix - confirm that this matches the encoder in use.
            expect = count + 4 + 2  # 4 type and colon +2 for the trailing CRNL
            rest = yield from self._fill(rest, expect)
            payload = rest[: (count + 4)]
            # validate on the raw bytes, so that the check is independent
            # of the configured string decoding
            if payload[3:4] != b":":
                raise ValueError(
                    f"Expected colon after hint, got {payload[3:4]!r}"
                )
            hint = self._decode_text(payload[:3])
            string = self.decode_bytes(payload[4:])
            yield VerbatimStr(string, hint), rest[expect:]

        elif code in (b"*", b">"):  # array or push data
            count = int(arg)
            if count < 0 and code == b"*":
                # resp2 null array, "*-1"
                yield None, rest
            else:
                result_array = []
                for _ in range(count):
                    # recursively parse the next array item
                    value, rest = yield from self._parse_nested(rest)
                    result_array.append(value)
                if code == b">":
                    yield PushData(result_array), rest
                else:
                    yield result_array, rest

        elif code == b"~":  # set
            count = int(arg)
            result_set = set()
            for _ in range(count):
                # recursively parse the next set item
                value, rest = yield from self._parse_nested(rest)
                result_set.add(value)
            yield result_set, rest

        elif code in (b"%", b"|"):  # map or attribute
            count = int(arg)
            result_map = {}
            for _ in range(count):
                # recursively parse the next key, and value
                key, rest = yield from self._parse_nested(rest)
                value, rest = yield from self._parse_nested(rest)
                result_map[key] = value
            if code == b"|":
                yield Attribute(result_map), rest
            else:
                yield result_map, rest

        elif code == b"-":  # error
            # we decode them automatically
            yield self._make_error(arg), rest

        elif code == b"!":  # resp3 error
            count = int(arg)
            expect = count + 2  # +2 for the trailing CRNL
            rest = yield from self._fill(rest, expect)
            yield self._make_error(rest[:count]), rest[expect:]

        else:
            raise ValueError(f"Unknown opcode '{code.decode()}'")
307
333
308
334
309
335
class NeedMoreData (RuntimeError ):
@@ -315,10 +341,12 @@ class NeedMoreData(RuntimeError):
315
341
class RespParser :
316
342
"""
317
343
A class for simple RESP protocol decoding for unit tests
344
+ Uses a RespGeneratorParser to produce data.
318
345
"""
319
346
320
347
def __init__ (self ) -> None :
321
- self .parser : Optional [
348
+ self .parser = RespGeneratorParser ()
349
+ self .generator : Optional [
322
350
Generator [Optional [Tuple [Any , bytes ]], Union [None , bytes ], None ]
323
351
] = None
324
352
# which has not resulted in a parsed value
@@ -329,24 +357,24 @@ def parse(self, buffer: bytes) -> Optional[Any]:
329
357
Parse a buffer of data and return a single top-level primitive, retaining any
330
358
unparsed remainder for the next call, or raise NeedMoreData if more data is needed
331
359
"""
332
- if self .parser is None :
360
+ if self .generator is None :
333
361
# create a new parser generator, initializing it with
334
362
# any unparsed data from previous calls
335
363
buffer = b"" .join (self .consumed ) + buffer
336
364
del self .consumed [:]
337
- self .parser = resp_parse (buffer )
338
- parsed = self .parser .send (None )
365
+ self .generator = self . parser . parse (buffer )
366
+ parsed = self .generator .send (None )
339
367
else :
340
368
# send more data to the parser
341
- parsed = self .parser .send (buffer )
369
+ parsed = self .generator .send (buffer )
342
370
343
371
if parsed is None :
344
372
self .consumed .append (buffer )
345
373
raise NeedMoreData ()
346
374
347
375
# got a value, close the parser, store the remaining buffer
348
- self .parser .close ()
349
- self .parser = None
376
+ self .generator .close ()
377
+ self .generator = None
350
378
value , remaining = parsed
351
379
self .consumed = [remaining ]
352
380
return value
@@ -355,9 +383,9 @@ def get_unparsed(self) -> bytes:
355
383
return b"" .join (self .consumed )
356
384
357
385
def close (self ) -> None :
358
- if self .parser is not None :
359
- self .parser .close ()
360
- self .parser = None
386
+ if self .generator is not None :
387
+ self .generator .close ()
388
+ self .generator = None
361
389
del self .consumed [:]
362
390
363
391
0 commit comments