Skip to content

Commit 9843e39

Browse files
committed
Add a class around the parser, to hold parsing rules.
1 parent fd56ac8 commit 9843e39

File tree

1 file changed

+166
-138
lines changed

1 file changed

+166
-138
lines changed

tests/resp.py

Lines changed: 166 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -169,141 +169,167 @@ def encode(value: Any, protocol: int = 2, hint: Optional[str] = None) -> bytes:
169169
return RespEncoder(protocol).encode(value, hint)
170170

171171

172-
# a stateful RESP parser implemented via a generator
173-
def resp_parse(
174-
buffer: bytes,
175-
) -> Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]:
172+
class RespGeneratorParser:
176173
"""
177-
A stateful, generator based, RESP parser.
178-
Returns a generator producing at most a single top-level primitive.
179-
Yields tuple of (data_item, unparsed), or None if more data is needed.
180-
It is fed more data with generator.send()
174+
A wrapper class around a stateful RESP parsing generator,
175+
allowing custom string decoding rules.
181176
"""
182-
# Read the first line of resp or yield to get more data
183-
while CRNL not in buffer:
184-
incoming = yield None
185-
assert incoming is not None
186-
buffer += incoming
187-
cmd, rest = buffer.split(CRNL, 1)
188177

189-
code, arg = cmd[:1], cmd[1:]
190-
191-
if code == b":" or code == b"(": # integer, resp3 large int
192-
yield int(arg), rest
193-
194-
elif code == b"t": # resp3 true
195-
yield True, rest
196-
197-
elif code == b"f": # resp3 false
198-
yield False, rest
199-
200-
elif code == b"_": # resp3 null
201-
yield None, rest
202-
203-
elif code == b",": # resp3 double
204-
yield float(arg), rest
178+
def __init__(
    self,
    encoding: Optional[str] = "utf-8",
    errorhandler: str = "surrogateescape",
) -> None:
    """
    Create a new parser, optionally specifying the encoding and errorhandler.

    `encoding` may be None to request that bytes are returned as-is; the
    value is only stored here, so that behaviour depends on decode_bytes()
    honouring a None encoding.
    The default settings are utf-8 encoding and surrogateescape errorhandler,
    which can decode all possible byte sequences,
    allowing decoded data to be re-encoded back to bytes.
    """
    # Both settings are only recorded; they are consumed when decoding.
    self.encoding = encoding
    self.errorhandler = errorhandler
205188

206-
elif code == b"+": # simple string
207-
# we decode them automatically
208-
yield arg.decode(errors="surrogateescape"), rest
189+
def decode_bytes(self, data: bytes) -> Union[str, bytes]:
    """
    Decode the data as a string, using self.encoding and self.errorhandler.

    If self.encoding is None no decoding is attempted and the data is
    returned unchanged, as bytes.
    """
    if self.encoding is None:
        # bytes.decode() rejects a None encoding, so pass the data through
        return data
    return data.decode(self.encoding, errors=self.errorhandler)
209194

210-
elif code == b"$": # bulk string
211-
count = int(arg)
212-
expect = count + 2 # +2 for the trailing CRNL
213-
while len(rest) < expect:
214-
incoming = yield (None)
215-
assert incoming is not None
216-
rest += incoming
217-
bulkstr = rest[:count]
218-
# we decode them automatically. Can be encoded
219-
# back to binary if necessary with "surrogateescape"
220-
yield bulkstr.decode(errors="surrogateescape"), rest[expect:]
221-
222-
elif code == b"=": # verbatim strings
223-
count = int(arg)
224-
expect = count + 4 + 2 # 4 type and colon +2 for the trailing CRNL
225-
while len(rest) < expect:
226-
incoming = yield (None)
227-
assert incoming is not None
228-
rest += incoming
229-
hint = rest[:3]
230-
result = rest[4: (count + 4)]
231-
yield VerbatimStr(result.decode(errors="surrogateescape"),
232-
hint.decode()), rest[expect:]
233-
234-
elif code in b"*>": # array or push data
235-
count = int(arg)
236-
result_array = []
237-
for _ in range(count):
238-
# recursively parse the next array item
239-
with closing(resp_parse(rest)) as parser:
240-
parsed = parser.send(None)
241-
while parsed is None:
242-
incoming = yield None
243-
parsed = parser.send(incoming)
244-
value, rest = parsed
245-
result_array.append(value)
246-
if code == b">":
247-
yield PushData(result_array), rest
248-
else:
249-
yield result_array, rest
250-
251-
elif code == b"~": # set
252-
count = int(arg)
253-
result_set = set()
254-
for _ in range(count):
255-
# recursively parse the next set item
256-
with closing(resp_parse(rest)) as parser:
257-
parsed = parser.send(None)
258-
while parsed is None:
259-
incoming = yield None
260-
parsed = parser.send(incoming)
261-
value, rest = parsed
262-
result_set.add(value)
263-
yield result_set, rest
264-
265-
elif code in b"%|": # map or attribute
266-
count = int(arg)
267-
result_map = {}
268-
for _ in range(count):
269-
# recursively parse the next key, and value
270-
with closing(resp_parse(rest)) as parser:
271-
parsed = parser.send(None)
272-
while parsed is None:
273-
incoming = yield None
274-
parsed = parser.send(incoming)
275-
key, rest = parsed
276-
with closing(resp_parse(rest)) as parser:
277-
parsed = parser.send(None)
278-
while parsed is None:
279-
incoming = yield None
280-
parsed = parser.send(incoming)
281-
value, rest = parsed
282-
result_map[key] = value
283-
if code == b"|":
284-
yield Attribute(result_map), rest
285-
yield result_map, rest
286-
287-
elif code == b"-": # error
288-
# we decode them automatically
289-
decoded = arg.decode(errors="surrogateescape")
290-
code, value = decoded.split(" ", 1)
291-
yield ErrorStr(code, value), rest
292-
293-
elif code == b"!": # resp3 error
294-
count = int(arg)
295-
expect = count + 2 # +2 for the trailing CRNL
296-
while len(rest) < expect:
297-
incoming = yield (None)
195+
# a stateful RESP parser implemented via a generator
196+
def parse(
    self,
    buffer: bytes,
) -> Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]:
    """
    A stateful, generator based, RESP parser.

    Returns a generator producing at most a single top-level primitive.
    Yields a tuple of (data_item, unparsed), or None if more data is needed.
    It is fed more data with generator.send().

    Simple strings, bulk strings, verbatim strings and error payloads are
    decoded with self.decode_bytes(). Aggregates (arrays, push data, sets,
    maps and attributes) are parsed by recursing into self.parse() for
    every contained item.

    Raises ValueError if the type marker is unknown (including an empty
    line), or if the hint of a verbatim string is not followed by a colon.
    """
    # Read the first line of resp or yield to get more data
    while CRNL not in buffer:
        incoming = yield None
        assert incoming is not None
        buffer += incoming
    cmd, rest = buffer.split(CRNL, 1)

    # the first byte is the type marker, the remainder is its argument
    code, arg = cmd[:1], cmd[1:]

    if code == b":" or code == b"(":  # integer, resp3 large int
        yield int(arg), rest

    elif code == b"t":  # resp3 true
        yield True, rest

    elif code == b"f":  # resp3 false
        yield False, rest

    elif code == b"_":  # resp3 null
        yield None, rest

    elif code == b",":  # resp3 double
        yield float(arg), rest

    elif code == b"+":  # simple string
        # we decode them automatically
        yield self.decode_bytes(arg), rest

    elif code == b"$":  # bulk string
        count = int(arg)
        expect = count + 2  # +2 for the trailing CRNL
        while len(rest) < expect:
            incoming = yield (None)
            assert incoming is not None
            rest += incoming
        bulkstr = rest[:count]
        yield self.decode_bytes(bulkstr), rest[expect:]

    elif code == b"=":  # verbatim strings
        count = int(arg)
        expect = count + 4 + 2  # 4 type and colon +2 for the trailing CRNL
        while len(rest) < expect:
            incoming = yield (None)
            assert incoming is not None
            rest += incoming
        string = self.decode_bytes(rest[: (count + 4)])
        if string[3] != ":":
            # report the offending character of the decoded payload
            raise ValueError(f"Expected colon after hint, got {string[3]!r}")
        hint = string[:3]
        string = string[4 : (count + 4)]
        yield VerbatimStr(string, hint), rest[expect:]

    # a tuple is used for membership (not a bytes substring test) so that
    # an empty type marker does not match and reaches the final else
    elif code in (b"*", b">"):  # array or push data
        count = int(arg)
        result_array = []
        for _ in range(count):
            # recursively parse the next array item
            with closing(self.parse(rest)) as parser:
                parsed = parser.send(None)
                while parsed is None:
                    # the nested parser needs more data: ask our caller
                    incoming = yield None
                    parsed = parser.send(incoming)
                value, rest = parsed
                result_array.append(value)
        if code == b">":
            yield PushData(result_array), rest
        else:
            yield result_array, rest

    elif code == b"~":  # set
        count = int(arg)
        result_set = set()
        for _ in range(count):
            # recursively parse the next set item
            with closing(self.parse(rest)) as parser:
                parsed = parser.send(None)
                while parsed is None:
                    incoming = yield None
                    parsed = parser.send(incoming)
                value, rest = parsed
                result_set.add(value)
        yield result_set, rest

    elif code in (b"%", b"|"):  # map or attribute
        count = int(arg)
        result_map = {}
        for _ in range(count):
            # recursively parse the next key, and value
            with closing(self.parse(rest)) as parser:
                parsed = parser.send(None)
                while parsed is None:
                    incoming = yield None
                    parsed = parser.send(incoming)
                key, rest = parsed
            with closing(self.parse(rest)) as parser:
                parsed = parser.send(None)
                while parsed is None:
                    incoming = yield None
                    parsed = parser.send(incoming)
                value, rest = parsed
            result_map[key] = value
        if code == b"|":
            yield Attribute(result_map), rest
        else:
            # exactly one result per generator: never follow an Attribute
            # with a second, plain map
            yield result_map, rest

    elif code == b"-":  # error
        # we decode them automatically
        decoded = self.decode_bytes(arg)
        assert isinstance(decoded, str)
        error_code, value = decoded.split(" ", 1)
        yield ErrorStr(error_code, value), rest

    elif code == b"!":  # resp3 error
        count = int(arg)
        expect = count + 2  # +2 for the trailing CRNL
        while len(rest) < expect:
            incoming = yield (None)
            assert incoming is not None
            rest += incoming
        bulkstr = rest[:count]
        decoded = self.decode_bytes(bulkstr)
        assert isinstance(decoded, str)
        error_code, value = decoded.split(" ", 1)
        yield ErrorStr(error_code, value), rest[expect:]

    else:
        raise ValueError(f"Unknown opcode '{code.decode()}'")
307333

308334

309335
class NeedMoreData(RuntimeError):
@@ -315,10 +341,12 @@ class NeedMoreData(RuntimeError):
315341
class RespParser:
316342
"""
317343
A class for simple RESP protocol decoding for unit tests
344+
Uses a RespGeneratorParser to produce data.
318345
"""
319346

320347
def __init__(self) -> None:
321-
self.parser: Optional[
348+
self.parser = RespGeneratorParser()
349+
self.generator: Optional[
322350
Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]
323351
] = None
324352
# which has not resulted in a parsed value
@@ -329,24 +357,24 @@ def parse(self, buffer: bytes) -> Optional[Any]:
329357
Parse a buffer of data, return a tuple of a single top-level primitive and the
330358
remaining buffer or raise NeedMoreData if more data is needed
331359
"""
332-
if self.parser is None:
360+
if self.generator is None:
333361
# create a new parser generator, initializing it with
334362
# any unparsed data from previous calls
335363
buffer = b"".join(self.consumed) + buffer
336364
del self.consumed[:]
337-
self.parser = resp_parse(buffer)
338-
parsed = self.parser.send(None)
365+
self.generator = self.parser.parse(buffer)
366+
parsed = self.generator.send(None)
339367
else:
340368
# send more data to the parser
341-
parsed = self.parser.send(buffer)
369+
parsed = self.generator.send(buffer)
342370

343371
if parsed is None:
344372
self.consumed.append(buffer)
345373
raise NeedMoreData()
346374

347375
# got a value, close the parser, store the remaining buffer
348-
self.parser.close()
349-
self.parser = None
376+
self.generator.close()
377+
self.generator = None
350378
value, remaining = parsed
351379
self.consumed = [remaining]
352380
return value
@@ -355,9 +383,9 @@ def get_unparsed(self) -> bytes:
355383
return b"".join(self.consumed)
356384

357385
def close(self) -> None:
358-
if self.parser is not None:
359-
self.parser.close()
360-
self.parser = None
386+
if self.generator is not None:
387+
self.generator.close()
388+
self.generator = None
361389
del self.consumed[:]
362390

363391

0 commit comments

Comments
 (0)