Skip to content

Commit 8868195

Browse files
committed
Add errors, use strings
1 parent c891648 commit 8868195

File tree

2 files changed

+138
-34
lines changed

2 files changed

+138
-34
lines changed

tests/resp.py

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,37 @@
66
CRNL = b"\r\n"
77

88

9-
class VerbatimString(bytes):
9+
class VerbatimStr(str):
1010
"""
1111
A string that is encoded as a resp3 verbatim string
1212
"""
1313

14-
def __new__(cls, value: bytes, hint: str) -> "VerbatimString":
15-
return bytes.__new__(cls, value)
14+
def __new__(cls, value: str, hint: str) -> "VerbatimStr":
15+
return str.__new__(cls, value)
1616

17-
def __init__(self, value: bytes, hint: str) -> None:
17+
def __init__(self, value: str, hint: str) -> None:
1818
self.hint = hint
1919

2020
def __repr__(self) -> str:
21-
return f"VerbatimString({super().__repr__()}, {self.hint!r})"
21+
return f"VerbatimStr({super().__repr__()}, {self.hint!r})"
22+
23+
24+
class ErrorStr(str):
    """
    A string to be encoded as a resp3 error.

    The string value itself is the error message; the error code is kept
    separately in the ``code`` attribute, normalised to upper case.
    ``str(err)`` produces ``"<CODE> <message>"``.

    Note that comparison and hashing are inherited from ``str`` and
    therefore look at the message only, not at the code.
    """

    def __new__(cls, code: str, value: str) -> "ErrorStr":
        # str is immutable, so the message must be supplied at creation time;
        # the code is attached afterwards in __init__.
        return str.__new__(cls, value)

    def __init__(self, code: str, value: str) -> None:
        self.code = code.upper()

    def __repr__(self) -> str:
        # Use the real class name so the repr matches the constructor call.
        return f"ErrorStr({self.code!r}, {super().__repr__()})"

    def __str__(self) -> str:
        # "<CODE> <message>", the text that follows the error opcode.
        return f"{self.code} {super().__str__()}"
2240

2341

2442
class PushData(list):
@@ -30,19 +48,39 @@ def __repr__(self) -> str:
3048
return f"PushData({super().__repr__()})"
3149

3250

51+
class Attribute(dict):
52+
"""
53+
A special type of map indicating data from an attribute response
54+
"""
55+
56+
def __repr__(self) -> str:
57+
return f"Attribute({super().__repr__()})"
58+
59+
3360
class RespEncoder:
3461
"""
35-
A class for simple RESP protocol encoding for unit tests
62+
A simple RESP protocol encoder for unit tests
3663
"""
3764

38-
def __init__(self, protocol: int = 2, encoding: str = "utf-8") -> None:
65+
def __init__(self, protocol: int = 2) -> None:
3966
self.protocol = protocol
40-
self.encoding = encoding
67+
68+
def has_crnl(self, value: bytes) -> bool:
69+
"""check if either cr or nl is in the value"""
70+
return b"\r" in value or b"\n" in value
71+
72+
def strip_crnl(self, value: bytes) -> bytes:
73+
"""remove any cr or nl from the value"""
74+
return value.replace(b"\r", b"").replace(b"\n", b"")
75+
76+
def encodestrip(self, value: str) -> bytes:
77+
return self.strip_crnl(value.encode())
4178

4279
def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
4380
if isinstance(data, dict):
4481
if self.protocol > 2:
45-
result = f"%{len(data)}\r\n".encode()
82+
code = "|" if isinstance(data, Attribute) else "%"
83+
result = f"{code}{len(data)}\r\n".encode()
4684
for key, val in data.items():
4785
result += self.encode(key) + self.encode(val)
4886
return result
@@ -54,10 +92,8 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
5492
return self.encode(mylist)
5593

5694
elif isinstance(data, list):
57-
if isinstance(data, PushData) and self.protocol > 2:
58-
result = f">{len(data)}\r\n".encode()
59-
else:
60-
result = f"*{len(data)}\r\n".encode()
95+
code = ">" if isinstance(data, PushData) and self.protocol > 2 else "*"
96+
result = f"{code}{len(data)}\r\n".encode()
6197
for val in data:
6298
result += self.encode(val)
6399
return result
@@ -71,11 +107,18 @@ def encode(self, data: Any, hint: Optional[str] = None) -> bytes:
71107
else:
72108
return self.encode(list(data))
73109

110+
elif isinstance(data, ErrorStr):
111+
enc = str(data).encode()
112+
if self.protocol > 2:
113+
if len(enc) > 80 or self.has_crnl(enc):
114+
return f"!{len(enc)}\r\n".encode() + enc + b"\r\n"
115+
return b"-" + self.strip_crnl(enc) + b"\r\n"
116+
74117
elif isinstance(data, str):
75-
enc = data.encode(self.encoding)
118+
enc = data.encode()
76119
# long strings or strings with control characters must be encoded as bulk
77120
# strings
78-
if hint or len(enc) > 20 or b"\r" in enc or b"\n" in enc:
121+
if hint or len(enc) > 80 or self.has_crnl(enc):
79122
return self.encode_bulkstr(enc, hint)
80123
return b"+" + enc + b"\r\n"
81124

@@ -158,7 +201,7 @@ def resp_parse(
158201

159202
elif code == b"+": # simple string
160203
# we decode them automatically
161-
yield arg.decode(), rest
204+
yield arg.decode(errors="surrogateescape"), rest
162205

163206
elif code == b"$": # bulk string
164207
count = int(arg)
@@ -168,8 +211,9 @@ def resp_parse(
168211
assert incoming is not None
169212
rest += incoming
170213
bulkstr = rest[:count]
171-
# bulk strings are not decoded, could contain binary data
172-
yield bulkstr, rest[expect:]
214+
# we decode them automatically. Can be encoded
215+
# back to binary if necessary with "surrogateescape"
216+
yield bulkstr.decode(errors="surrogateescape"), rest[expect:]
173217

174218
elif code == b"=": # verbatim strings
175219
count = int(arg)
@@ -179,9 +223,9 @@ def resp_parse(
179223
assert incoming is not None
180224
rest += incoming
181225
hint = rest[:3]
182-
result = rest[4 : (count + 4)]
183-
# verbatim strings are not decoded, could contain binary data
184-
yield VerbatimString(result, hint.decode()), rest[expect:]
226+
result = rest[4: (count + 4)]
227+
yield VerbatimStr(result.decode(errors="surrogateescape"),
228+
hint.decode()), rest[expect:]
185229

186230
elif code in b"*>": # array or push data
187231
count = int(arg)
@@ -214,7 +258,7 @@ def resp_parse(
214258
result_set.add(value)
215259
yield result_set, rest
216260

217-
elif code == b"%": # map
261+
elif code in b"%|": # map or attribute
218262
count = int(arg)
219263
result_map = {}
220264
for _ in range(count):
@@ -232,10 +276,29 @@ def resp_parse(
232276
parsed = parser.send(incoming)
233277
value, rest = parsed
234278
result_map[key] = value
279+
if code == b"|":
280+
yield Attribute(result_map), rest
235281
yield result_map, rest
282+
283+
elif code == b"-": # error
284+
# we decode them automatically
285+
decoded = arg.decode(errors="surrogateescape")
286+
code, value = decoded.split(" ", 1)
287+
yield ErrorStr(code, value), rest
288+
289+
elif code == b"!": # resp3 error
290+
count = int(arg)
291+
expect = count + 2 # +2 for the trailing CRNL
292+
while len(rest) < expect:
293+
incoming = yield (None)
294+
assert incoming is not None
295+
rest += incoming
296+
bulkstr = rest[:count]
297+
decoded = bulkstr.decode(errors="surrogateescape")
298+
code, value = decoded.split(" ", 1)
299+
yield ErrorStr(code, value), rest[expect:]
300+
236301
else:
237-
if code in b"-!":
238-
raise NotImplementedError(f"resp opcode '{code.decode()}' not implemented")
239302
raise ValueError(f"Unknown opcode '{code.decode()}'")
240303

241304

tests/test_resp.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
import pytest
22

3-
from .resp import PushData, VerbatimString, encode, parse_all, parse_chunks
3+
from .resp import (
4+
Attribute,
5+
ErrorStr,
6+
PushData,
7+
VerbatimStr,
8+
encode,
9+
parse_all,
10+
parse_chunks,
11+
)
412

513

614
@pytest.fixture(params=[2, 3])
@@ -13,9 +21,9 @@ def test_simple_str(self):
1321
assert encode("foo") == b"+foo\r\n"
1422

1523
def test_long_str(self):
16-
text = "fooling around with the sword in the mud"
17-
assert len(text) == 40
18-
assert encode(text) == b"$40\r\n" + text.encode() + b"\r\n"
24+
text = 3 * "fooling around with the sword in the mud"
25+
assert len(text) == 120
26+
assert encode(text) == b"$120\r\n" + text.encode() + b"\r\n"
1927

2028
# test strings with control characters
2129
def test_str_with_ctrl_chars(self):
@@ -66,6 +74,13 @@ def test_map(self, resp_version):
6674
else:
6775
assert data == b"%2\r\n:1\r\n:2\r\n:3\r\n:4\r\n"
6876

77+
def test_attribute(self, resp_version):
78+
data = encode(Attribute({1: 2, 3: 4}), protocol=resp_version)
79+
if resp_version == 2:
80+
assert data == b"*4\r\n:1\r\n:2\r\n:3\r\n:4\r\n"
81+
else:
82+
assert data == b"|2\r\n:1\r\n:2\r\n:3\r\n:4\r\n"
83+
6984
def test_nested_array(self):
7085
assert encode([1, [2, 3]]) == b"*2\r\n:1\r\n*2\r\n:2\r\n:3\r\n"
7186

@@ -103,6 +118,14 @@ def test_bool(self, resp_version):
103118
else:
104119
assert data == b"f\r\n"
105120

121+
def test_errorstr(self, resp_version):
122+
err = ErrorStr("foo", "bar\r\nbaz")
123+
data = encode(err, protocol=resp_version)
124+
if resp_version == 2:
125+
assert data == b"-FOO barbaz\r\n"
126+
else:
127+
assert data == b"!12\r\nFOO bar\r\nbaz\r\n"
128+
106129

107130
@pytest.mark.parametrize("chunk_size", [0, 1, 2, -2])
108131
class TestParser:
@@ -111,7 +134,7 @@ def breakup_bytes(self, data, chunk_size=2):
111134
if chunk_size < 0:
112135
insert_empty = True
113136
chunk_size = -chunk_size
114-
chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]
137+
chunks = [data[i: i + chunk_size] for i in range(0, len(data), chunk_size)]
115138
if insert_empty:
116139
empty = len(chunks) * [b""]
117140
chunks = [item for pair in zip(chunks, empty) for item in pair]
@@ -154,7 +177,7 @@ def test_incomplete_list(self, chunk_size):
154177
def test_invalid_token(self, chunk_size):
155178
with pytest.raises(ValueError):
156179
self.parse_data(chunk_size, b")foo\r\n")
157-
with pytest.raises(NotImplementedError):
180+
with pytest.raises(ValueError):
158181
self.parse_data(chunk_size, b"!foo\r\n")
159182

160183
def test_multiple_ints(self, chunk_size):
@@ -185,12 +208,30 @@ def test_simple_string(self, chunk_size):
185208

186209
def test_bulk_string(self, chunk_size):
187210
parsed = parse_all(b"$3\r\nfoo\r\nbar")
188-
assert parsed == ([b"foo"], b"bar")
211+
assert parsed == (["foo"], b"bar")
189212

190213
def test_bulk_string_with_ctrl_chars(self, chunk_size):
191214
parsed = self.parse_data(chunk_size, b"$8\r\nfoo\r\nbar\r\n")
192-
assert parsed == ([b"foo\r\nbar"], b"")
215+
assert parsed == (["foo\r\nbar"], b"")
193216

194-
def test_verbatim_string(self, chunk_size):
217+
def test_verbatimstr(self, chunk_size):
195218
parsed = self.parse_data(chunk_size, b"=3\r\ntxt:foo\r\nbar")
196-
assert parsed == ([VerbatimString(b"foo", "txt")], b"bar")
219+
assert parsed == ([VerbatimStr("foo", "txt")], b"bar")
220+
221+
def test_errorstr(self, chunk_size):
222+
parsed = self.parse_data(chunk_size, b"-FOO bar\r\nbaz")
223+
assert parsed == ([ErrorStr("foo", "bar")], b"baz")
224+
225+
def test_errorstr_resp3(self, chunk_size):
226+
parsed = self.parse_data(chunk_size, b"!12\r\nFOO bar\r\nbaz\r\n")
227+
assert parsed == ([ErrorStr("foo", "bar\r\nbaz")], b"")
228+
229+
def test_attribute_map(self, chunk_size):
230+
parsed = self.parse_data(chunk_size, b"|2\r\n:1\r\n:2\r\n:3\r\n:4\r\n")
231+
assert parsed == ([Attribute({1: 2, 3: 4})], b"")
232+
233+
def test_surrogateescape(self, chunk_size):
234+
data = b"foo\xff"
235+
parsed = self.parse_data(chunk_size, b"$4\r\n" + data + b"\r\nbar")
236+
assert parsed == ([data.decode(errors="surrogateescape")], b"bar")
237+
assert parsed[0][0].encode("utf-8", "surrogateescape") == data

0 commit comments

Comments
 (0)