Skip to content

Commit c891648

Browse files
committed
Add resp parser, tests
1 parent 3bc0165 commit c891648

File tree

2 files changed

+350
-17
lines changed

2 files changed

+350
-17
lines changed

tests/resp.py

Lines changed: 249 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,38 @@
11
import itertools
2+
from contextlib import closing
23
from types import NoneType
3-
from typing import Any, Optional
4+
from typing import Any, Generator, List, Optional, Tuple, Union
5+
6+
CRNL = b"\r\n"
7+
8+
9+
class VerbatimString(bytes):
    """
    A bytes subclass representing a resp3 verbatim string.

    The instance itself holds the payload bytes, the accompanying format
    hint is stored in the ``hint`` attribute.
    """

    hint: str

    def __new__(cls, value: bytes, hint: str) -> "VerbatimString":
        # bytes is immutable, so the payload has to be supplied when the
        # object is created.  The hint is attached to the new instance
        # right away, no separate __init__ is needed.
        instance = super().__new__(cls, value)
        instance.hint = hint
        return instance

    def __repr__(self) -> str:
        # show both the payload and the hint, mirroring the constructor call
        return f"VerbatimString({super().__repr__()}, {self.hint!r})"
22+
23+
24+
class PushData(list):
    """
    A list subclass used to mark data that belongs to a push response,
    so that it can be told apart from a regular array
    """

    def __repr__(self) -> str:
        # wrap the plain list representation in the class name
        inner = super().__repr__()
        return f"PushData({inner})"
431

532

633
class RespEncoder:
734
"""
8-
A class for simple RESP protocol encodign for unit tests
35+
A class for simple RESP protocol encoding for unit tests
936
"""
1037

1138
def __init__(self, protocol: int = 2, encoding: str = "utf-8") -> None:
@@ -27,7 +54,10 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
2754
return self.encode(mylist)
2855

2956
elif isinstance(data, list):
30-
result = f"*{len(data)}\r\n".encode()
57+
if isinstance(data, PushData) and self.protocol > 2:
58+
result = f">{len(data)}\r\n".encode()
59+
else:
60+
result = f"*{len(data)}\r\n".encode()
3161
for val in data:
3262
result += self.encode(val)
3363
return result
@@ -55,39 +85,243 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
5585
elif isinstance(data, bool):
5686
if self.protocol == 2:
5787
return b":1\r\n" if data else b":0\r\n"
58-
else:
59-
return b"t\r\n" if data else b"f\r\n"
88+
return b"t\r\n" if data else b"f\r\n"
6089

6190
elif isinstance(data, int):
6291
if (data > 2**63 - 1) or (data < -(2**63)):
6392
if self.protocol > 2:
6493
return f"({data}\r\n".encode() # resp3 big int
65-
else:
66-
return f"+{data}\r\n".encode() # force to simple string
94+
return f"+{data}\r\n".encode() # force to simple string
6795
return f":{data}\r\n".encode()
6896
elif isinstance(data, float):
6997
if self.protocol > 2:
7098
return f",{data}\r\n".encode() # resp3 double
71-
else:
72-
return f"+{data}\r\n".encode() # simple string
99+
return f"+{data}\r\n".encode() # simple string
73100

74101
elif isinstance(data, NoneType):
75102
if self.protocol > 2:
76103
return b"_\r\n" # resp3 null
77-
else:
78-
return b"$-1\r\n" # Null bulk string
79-
# some commands return null array: b"*-1\r\n"
104+
return b"$-1\r\n" # Null bulk string
105+
# some commands return null array: b"*-1\r\n"
80106

81107
else:
82-
raise NotImplementedError
108+
raise NotImplementedError(f"encode not implemented for {type(data)}")
83109

84110
def encode_bulkstr(self, bstr: bytes, hint: Optional[str]) -> bytes:
85111
if self.protocol > 2 and hint is not None:
86112
# a resp3 verbatim string
87113
return f"={len(bstr)}\r\n{hint}:".encode() + bstr + b"\r\n"
88-
else:
89-
return f"${len(bstr)}\r\n".encode() + bstr + b"\r\n"
114+
# regular bulk string
115+
return f"${len(bstr)}\r\n".encode() + bstr + b"\r\n"
90116

91117

92118
def encode(value: Any, protocol: int = 2, hint: Optional[str] = None) -> bytes:
    """
    Convenience function to encode a single value using the RESP protocol.

    A RespEncoder for the requested ``protocol`` version is created and
    used to encode ``value``.  The optional ``hint`` is passed on to the
    encoder unchanged.
    """
    encoder = RespEncoder(protocol)
    return encoder.encode(value, hint)
123+
124+
125+
# a stateful RESP parser implemented via a generator
126+
def resp_parse(
    buffer: bytes,
) -> Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]:
    """
    A stateful, generator based, RESP parser.
    Returns a generator producing at most a single top-level primitive.
    Yields tuple of (data_item, unparsed), or None if more data is needed.
    It is fed more data with generator.send()

    The generator must be started with send(None) (or next()).  Every time
    it yields None, the caller is expected to send() additional bytes.

    Notes on the produced values:
    - simple strings are decoded to str, bulk and verbatim strings are
      returned as bytes (verbatim strings as VerbatimString)
    - the resp2 null bulk string ("$-1") and null array ("*-1") both
      produce None, just like the resp3 null ("_")
    - booleans are accepted both as "t"/"f", which is what the encoder in
      this file emits, and in the "#t"/"#f" form
    - sets and maps are returned as set and dict, so set members and map
      keys must be hashable; aggregate members there raise TypeError

    Raises NotImplementedError for the error opcodes "-" and "!" and
    ValueError for unknown (or missing) opcodes and malformed numbers.

    NOTE(review): for verbatim strings the length is taken to cover the
    payload only, not the "<hint>:" prefix.  This matches what the encoder
    in this file writes, but the RESP3 specification appears to include
    the prefix in the length - confirm before parsing real server output.
    """
    # Read the first line of resp or yield to get more data
    while CRNL not in buffer:
        incoming = yield None
        assert incoming is not None
        buffer += incoming
    cmd, rest = buffer.split(CRNL, 1)

    # the first byte is the type code, the remainder of the line its argument.
    # code is empty if the line is empty.  It is always compared by equality
    # (or tuple membership), never with a bytes "in" test, since an empty
    # bytes object is a substring of any other bytes object.
    code, arg = cmd[:1], cmd[1:]

    if code in (b":", b"("):  # integer, resp3 large int
        yield int(arg), rest

    elif code == b"t":  # resp3 true, as emitted by the encoder in this file
        yield True, rest

    elif code == b"f":  # resp3 false, as emitted by the encoder in this file
        yield False, rest

    elif code == b"#":  # resp3 boolean in its "#t" / "#f" form
        if arg not in (b"t", b"f"):
            raise ValueError(f"Invalid boolean value '{arg.decode()}'")
        yield arg == b"t", rest

    elif code == b"_":  # resp3 null
        yield None, rest

    elif code == b",":  # resp3 double
        yield float(arg), rest

    elif code == b"+":  # simple string
        # we decode them automatically
        yield arg.decode(), rest

    elif code == b"$":  # bulk string
        count = int(arg)
        if count == -1:
            # resp2 null bulk string, there is no payload and no trailing CRNL
            yield None, rest
        else:
            expect = count + 2  # +2 for the trailing CRNL
            while len(rest) < expect:
                incoming = yield None
                assert incoming is not None
                rest += incoming
            # bulk strings are not decoded, could contain binary data
            yield rest[:count], rest[expect:]

    elif code == b"=":  # verbatim strings
        count = int(arg)
        expect = count + 4 + 2  # 4 type and colon +2 for the trailing CRNL
        while len(rest) < expect:
            incoming = yield None
            assert incoming is not None
            rest += incoming
        hint = rest[:3]
        result = rest[4 : (count + 4)]
        # verbatim strings are not decoded, could contain binary data
        yield VerbatimString(result, hint.decode()), rest[expect:]

    elif code in (b"*", b">"):  # array or push data
        count = int(arg)
        if code == b"*" and count == -1:
            # resp2 null array
            yield None, rest
        else:
            result_array = []
            for _ in range(count):
                value, rest = yield from _parse_nested(rest)
                result_array.append(value)
            if code == b">":
                yield PushData(result_array), rest
            else:
                yield result_array, rest

    elif code == b"~":  # set
        count = int(arg)
        result_set = set()
        for _ in range(count):
            value, rest = yield from _parse_nested(rest)
            result_set.add(value)
        yield result_set, rest

    elif code == b"%":  # map
        count = int(arg)
        result_map = {}
        for _ in range(count):
            # each entry consists of a key followed by its value
            key, rest = yield from _parse_nested(rest)
            value, rest = yield from _parse_nested(rest)
            result_map[key] = value
        yield result_map, rest

    elif code in (b"-", b"!"):  # simple error, resp3 bulk error
        raise NotImplementedError(f"resp opcode '{code.decode()}' not implemented")

    else:
        raise ValueError(f"Unknown opcode '{code.decode()}'")


def _parse_nested(
    rest: bytes,
) -> Generator[None, Union[None, bytes], Tuple[Any, bytes]]:
    """
    Parse one member of an aggregate type (array, push, set or map).
    Meant to be driven with "yield from" by resp_parse(): yields None each
    time more data is needed, passing the data sent in on to a nested parser,
    and finally returns the tuple (value, unparsed) of that nested parser.
    """
    with closing(resp_parse(rest)) as parser:
        parsed = parser.send(None)
        while parsed is None:
            incoming = yield None
            parsed = parser.send(incoming)
    return parsed
240+
241+
242+
class NeedMoreData(RuntimeError):
    """
    Raised when more data is needed to complete a parse.

    RespParser.parse() raises this when the data it has been given so far
    does not yet contain a complete top-level value.
    """
246+
247+
248+
class RespParser:
    """
    A class for simple RESP protocol decoding for unit tests

    It wraps the resp_parse() generator, taking care of creating it,
    feeding it data and keeping track of data which has not been parsed yet.
    """

    def __init__(self) -> None:
        # the currently active parser generator, or None if no value is
        # being parsed at the moment
        self.parser: Optional[
            Generator[Optional[Tuple[Any, bytes]], Union[None, bytes], None]
        ] = None
        # data which has been handed to the parser but
        # which has not resulted in a parsed value
        self.consumed: List[bytes] = []

    def parse(self, buffer: bytes) -> Optional[Any]:
        """
        Parse a buffer of data and return a single top-level primitive.
        Data following the parsed value is kept and used by the next call,
        it is also available through get_unparsed().

        Raises NeedMoreData if more data is needed to complete the value.
        Errors raised by resp_parse() are propagated.  In that case the
        parser is reset and all the data involved is kept as unparsed data.
        """
        to_send: Optional[bytes]
        if self.parser is None:
            # create a new parser generator, initializing it with
            # any unparsed data from previous calls
            buffer = b"".join(self.consumed) + buffer
            del self.consumed[:]
            self.parser = resp_parse(buffer)
            to_send = None
        else:
            # send more data to the parser
            to_send = buffer

        # remember the data before handing it over, so that it is not lost
        # if the parser needs more data or fails
        self.consumed.append(buffer)
        try:
            parsed = self.parser.send(to_send)
        except Exception:
            # a generator which has raised is finished and cannot be reused,
            # a later send() would only raise StopIteration.  Discard it.
            self.parser.close()
            self.parser = None
            raise

        if parsed is None:
            raise NeedMoreData()

        # got a value, close the parser, store the remaining buffer
        self.parser.close()
        self.parser = None
        value, remaining = parsed
        self.consumed = [remaining]
        return value

    def get_unparsed(self) -> bytes:
        """
        Return all the data which has not yet resulted in a parsed value
        """
        return b"".join(self.consumed)

    def close(self) -> None:
        """
        Close any active parser generator and discard all unparsed data
        """
        if self.parser is not None:
            self.parser.close()
            self.parser = None
        del self.consumed[:]
295+
296+
297+
def parse_all(buffer: bytes) -> Tuple[List[Any], bytes]:
    """
    Parse as many top-level objects as possible from a single buffer.
    Returns a tuple of the list of parsed objects and whatever data was
    left over because it did not form a complete object.
    """
    items: List[Any] = []
    pending = buffer
    with closing(RespParser()) as parser:
        while True:
            try:
                item = parser.parse(pending)
            except NeedMoreData:
                # nothing more can be parsed from the data at hand
                break
            items.append(item)
            # all the data is with the parser now, keep parsing its leftovers
            pending = b""
        # must be fetched before the parser is closed, closing discards it
        leftover = parser.get_unparsed()
    return items, leftover
310+
311+
312+
def parse_chunks(buffers: List[bytes]) -> Tuple[List[Any], bytes]:
    """
    Parse data arriving as a sequence of buffers, feeding them to a single
    parser one at a time.  Returns a tuple of the list of parsed top-level
    objects and the data left over after the last buffer.
    Used primarily for testing, since it will parse the data in chunks
    """
    items: List[Any] = []
    with closing(RespParser()) as parser:
        for chunk in buffers:
            pending = chunk
            while True:
                try:
                    item = parser.parse(pending)
                except NeedMoreData:
                    # this chunk is used up, move on to the next one
                    break
                items.append(item)
                # the chunk is with the parser now, keep parsing its leftovers
                pending = b""
        # must be fetched before the parser is closed, closing discards it
        leftover = parser.get_unparsed()
    return items, leftover

0 commit comments

Comments
 (0)