Skip to content

Commit ce65339

Browse files
committed
Add errors, use strings
1 parent c891648 commit ce65339

File tree

2 files changed

+140
-32
lines changed

2 files changed

+140
-32
lines changed

tests/resp.py

Lines changed: 90 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,37 @@
66
CRNL = b"\r\n"
77

88

9-
class VerbatimString(bytes):
9+
class VerbatimStr(str):
1010
"""
1111
A string that is encoded as a resp3 verbatim string
1212
"""
1313

14-
def __new__(cls, value: bytes, hint: str) -> "VerbatimString":
15-
return bytes.__new__(cls, value)
14+
def __new__(cls, value: str, hint: str) -> "VerbatimStr":
15+
return str.__new__(cls, value)
1616

17-
def __init__(self, value: bytes, hint: str) -> None:
17+
def __init__(self, value: str, hint: str) -> None:
1818
self.hint = hint
1919

2020
def __repr__(self) -> str:
21-
return f"VerbatimString({super().__repr__()}, {self.hint!r})"
21+
return f"VerbatimStr({super().__repr__()}, {self.hint!r})"
22+
23+
24+
class ErrorStr(str):
    """
    A string to be encoded as a resp3 error

    The underlying ``str`` value is the error message only; the error code is
    kept separately, upper-cased, in ``self.code``.  Comparison and hashing are
    inherited from ``str`` and therefore consider the message alone, while
    ``str(err)`` renders the combined ``"CODE message"`` form.
    """

    def __new__(cls, code: str, value: str) -> "ErrorStr":
        # str is immutable, so the message must be supplied at creation time;
        # ``code`` is accepted here only to mirror the __init__ signature.
        return str.__new__(cls, value)

    def __init__(self, code: str, value: str) -> None:
        # Codes are normalised to upper case.
        self.code = code.upper()

    def __repr__(self) -> str:
        # Report the actual class name (previously the stale "ErrorString").
        return f"ErrorStr({self.code!r}, {super().__repr__()})"

    def __str__(self) -> str:
        # Combined form: the code, a single space, then the message.
        return f"{self.code} {super().__str__()}"
2240

2341

2442
class PushData(list):
@@ -30,19 +48,43 @@ def __repr__(self) -> str:
3048
return f"PushData({super().__repr__()})"
3149

3250

51+
class Attribute(dict):
52+
"""
53+
A special type of map indicating data from an attribute response
54+
"""
55+
56+
def __repr__(self) -> str:
57+
return f"Attribute({super().__repr__()})"
58+
59+
3360
class RespEncoder:
3461
"""
35-
A class for simple RESP protocol encoding for unit tests
62+
A simple RESP protocol encoder class for unit tests
3663
"""
3764

38-
def __init__(self, protocol: int = 2, encoding: str = "utf-8") -> None:
65+
def __init__(
66+
self, protocol: int = 2, encoding: str = "utf-8", errorhander="strict"
67+
) -> None:
3968
self.protocol = protocol
4069
self.encoding = encoding
70+
self.errorhandler = errorhander
71+
72+
def apply_encoding(self, value: str) -> bytes:
73+
return value.encode(self.encoding, errors=self.errorhandler)
74+
75+
def has_crnl(self, value: bytes) -> bool:
76+
"""check if either cr or nl is in the value"""
77+
return b"\r" in value or b"\n" in value
78+
79+
def escape_crln(self, value: bytes) -> bytes:
80+
"""escape any cr or nl in the value"""
81+
return value.replace(b"\r", b"\\r").replace(b"\n", b"\\n")
4182

4283
def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
4384
if isinstance(data, dict):
4485
if self.protocol > 2:
45-
result = f"%{len(data)}\r\n".encode()
86+
code = "|" if isinstance(data, Attribute) else "%"
87+
result = f"{code}{len(data)}\r\n".encode()
4688
for key, val in data.items():
4789
result += self.encode(key) + self.encode(val)
4890
return result
@@ -54,10 +96,8 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
5496
return self.encode(mylist)
5597

5698
elif isinstance(data, list):
57-
if isinstance(data, PushData) and self.protocol > 2:
58-
result = f">{len(data)}\r\n".encode()
59-
else:
60-
result = f"*{len(data)}\r\n".encode()
99+
code = ">" if isinstance(data, PushData) and self.protocol > 2 else "*"
100+
result = f"{code}{len(data)}\r\n".encode()
61101
for val in data:
62102
result += self.encode(val)
63103
return result
@@ -71,11 +111,18 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
71111
else:
72112
return self.encode(list(data))
73113

114+
elif isinstance(data, ErrorStr):
115+
enc = self.apply_encoding(str(data))
116+
if self.protocol > 2:
117+
if len(enc) > 80 or self.has_crnl(enc):
118+
return f"!{len(enc)}\r\n".encode() + enc + b"\r\n"
119+
return b"-" + self.escape_crln(enc) + b"\r\n"
120+
74121
elif isinstance(data, str):
75-
enc = data.encode(self.encoding)
122+
enc = self.apply_encoding(data)
76123
# long strings or strings with control characters must be encoded as bulk
77124
# strings
78-
if hint or len(enc) > 20 or b"\r" in enc or b"\n" in enc:
125+
if hint or len(enc) > 80 or self.has_crnl(enc):
79126
return self.encode_bulkstr(enc, hint)
80127
return b"+" + enc + b"\r\n"
81128

@@ -158,7 +205,7 @@ def resp_parse(
158205

159206
elif code == b"+": # simple string
160207
# we decode them automatically
161-
yield arg.decode(), rest
208+
yield arg.decode(errors="surrogateescape"), rest
162209

163210
elif code == b"$": # bulk string
164211
count = int(arg)
@@ -168,8 +215,9 @@ def resp_parse(
168215
assert incoming is not None
169216
rest += incoming
170217
bulkstr = rest[:count]
171-
# bulk strings are not decoded, could contain binary data
172-
yield bulkstr, rest[expect:]
218+
# we decode them automatically. Can be encoded
219+
# back to binary if necessary with "surrogateescape"
220+
yield bulkstr.decode(errors="surrogateescape"), rest[expect:]
173221

174222
elif code == b"=": # verbatim strings
175223
count = int(arg)
@@ -179,9 +227,9 @@ def resp_parse(
179227
assert incoming is not None
180228
rest += incoming
181229
hint = rest[:3]
182-
result = rest[4 : (count + 4)]
183-
# verbatim strings are not decoded, could contain binary data
184-
yield VerbatimString(result, hint.decode()), rest[expect:]
230+
result = rest[4: (count + 4)]
231+
yield VerbatimStr(result.decode(errors="surrogateescape"),
232+
hint.decode()), rest[expect:]
185233

186234
elif code in b"*>": # array or push data
187235
count = int(arg)
@@ -214,7 +262,7 @@ def resp_parse(
214262
result_set.add(value)
215263
yield result_set, rest
216264

217-
elif code == b"%": # map
265+
elif code in b"%|": # map or attribute
218266
count = int(arg)
219267
result_map = {}
220268
for _ in range(count):
@@ -232,10 +280,29 @@ def resp_parse(
232280
parsed = parser.send(incoming)
233281
value, rest = parsed
234282
result_map[key] = value
283+
if code == b"|":
284+
yield Attribute(result_map), rest
235285
yield result_map, rest
286+
287+
elif code == b"-": # error
288+
# we decode them automatically
289+
decoded = arg.decode(errors="surrogateescape")
290+
code, value = decoded.split(" ", 1)
291+
yield ErrorStr(code, value), rest
292+
293+
elif code == b"!": # resp3 error
294+
count = int(arg)
295+
expect = count + 2 # +2 for the trailing CRNL
296+
while len(rest) < expect:
297+
incoming = yield (None)
298+
assert incoming is not None
299+
rest += incoming
300+
bulkstr = rest[:count]
301+
decoded = bulkstr.decode(errors="surrogateescape")
302+
code, value = decoded.split(" ", 1)
303+
yield ErrorStr(code, value), rest[expect:]
304+
236305
else:
237-
if code in b"-!":
238-
raise NotImplementedError(f"resp opcode '{code.decode()}' not implemented")
239306
raise ValueError(f"Unknown opcode '{code.decode()}'")
240307

241308

tests/test_resp.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
import pytest
22

3-
from .resp import PushData, VerbatimString, encode, parse_all, parse_chunks
3+
from .resp import (
4+
Attribute,
5+
ErrorStr,
6+
PushData,
7+
VerbatimStr,
8+
encode,
9+
parse_all,
10+
parse_chunks,
11+
)
412

513

614
@pytest.fixture(params=[2, 3])
@@ -13,9 +21,9 @@ def test_simple_str(self):
1321
assert encode("foo") == b"+foo\r\n"
1422

1523
def test_long_str(self):
16-
text = "fooling around with the sword in the mud"
17-
assert len(text) == 40
18-
assert encode(text) == b"$40\r\n" + text.encode() + b"\r\n"
24+
text = 3 * "fooling around with the sword in the mud"
25+
assert len(text) == 120
26+
assert encode(text) == b"$120\r\n" + text.encode() + b"\r\n"
1927

2028
# test strings with control characters
2129
def test_str_with_ctrl_chars(self):
@@ -66,6 +74,13 @@ def test_map(self, resp_version):
6674
else:
6775
assert data == b"%2\r\n:1\r\n:2\r\n:3\r\n:4\r\n"
6876

77+
def test_attribute(self, resp_version):
78+
data = encode(Attribute({1: 2, 3: 4}), protocol=resp_version)
79+
if resp_version == 2:
80+
assert data == b"*4\r\n:1\r\n:2\r\n:3\r\n:4\r\n"
81+
else:
82+
assert data == b"|2\r\n:1\r\n:2\r\n:3\r\n:4\r\n"
83+
6984
def test_nested_array(self):
7085
assert encode([1, [2, 3]]) == b"*2\r\n:1\r\n*2\r\n:2\r\n:3\r\n"
7186

@@ -103,6 +118,14 @@ def test_bool(self, resp_version):
103118
else:
104119
assert data == b"f\r\n"
105120

121+
def test_errorstr(self, resp_version):
122+
err = ErrorStr("foo", "bar\r\nbaz")
123+
data = encode(err, protocol=resp_version)
124+
if resp_version == 2:
125+
assert data == b"-FOO bar\\r\\nbaz\r\n"
126+
else:
127+
assert data == b"!12\r\nFOO bar\r\nbaz\r\n"
128+
106129

107130
@pytest.mark.parametrize("chunk_size", [0, 1, 2, -2])
108131
class TestParser:
@@ -154,7 +177,7 @@ def test_incomplete_list(self, chunk_size):
154177
def test_invalid_token(self, chunk_size):
155178
with pytest.raises(ValueError):
156179
self.parse_data(chunk_size, b")foo\r\n")
157-
with pytest.raises(NotImplementedError):
180+
with pytest.raises(ValueError):
158181
self.parse_data(chunk_size, b"!foo\r\n")
159182

160183
def test_multiple_ints(self, chunk_size):
@@ -185,12 +208,30 @@ def test_simple_string(self, chunk_size):
185208

186209
def test_bulk_string(self, chunk_size):
187210
parsed = parse_all(b"$3\r\nfoo\r\nbar")
188-
assert parsed == ([b"foo"], b"bar")
211+
assert parsed == (["foo"], b"bar")
189212

190213
def test_bulk_string_with_ctrl_chars(self, chunk_size):
191214
parsed = self.parse_data(chunk_size, b"$8\r\nfoo\r\nbar\r\n")
192-
assert parsed == ([b"foo\r\nbar"], b"")
215+
assert parsed == (["foo\r\nbar"], b"")
193216

194-
def test_verbatim_string(self, chunk_size):
217+
def test_verbatimstr(self, chunk_size):
195218
parsed = self.parse_data(chunk_size, b"=3\r\ntxt:foo\r\nbar")
196-
assert parsed == ([VerbatimString(b"foo", "txt")], b"bar")
219+
assert parsed == ([VerbatimStr("foo", "txt")], b"bar")
220+
221+
def test_errorstr(self, chunk_size):
222+
parsed = self.parse_data(chunk_size, b"-FOO bar\r\nbaz")
223+
assert parsed == ([ErrorStr("foo", "bar")], b"baz")
224+
225+
def test_errorstr_resp3(self, chunk_size):
226+
parsed = self.parse_data(chunk_size, b"!12\r\nFOO bar\r\nbaz\r\n")
227+
assert parsed == ([ErrorStr("foo", "bar\r\nbaz")], b"")
228+
229+
def test_attribute_map(self, chunk_size):
230+
parsed = self.parse_data(chunk_size, b"|2\r\n:1\r\n:2\r\n:3\r\n:4\r\n")
231+
assert parsed == ([Attribute({1: 2, 3: 4})], b"")
232+
233+
def test_surrogateescape(self, chunk_size):
234+
data = b"foo\xff"
235+
parsed = self.parse_data(chunk_size, b"$4\r\n" + data + b"\r\nbar")
236+
assert parsed == ([data.decode(errors="surrogateescape")], b"bar")
237+
assert parsed[0][0].encode("utf-8", "surrogateescape") == data

0 commit comments

Comments
 (0)