Skip to content

Commit 5b8ed8e

Browse files
committed
expand the RespServer
1 parent 11500ab commit 5b8ed8e

File tree

1 file changed

+103
-15
lines changed

1 file changed

+103
-15
lines changed

tests/resp.py

Lines changed: 103 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ def decode_bytes(self, data: bytes) -> str:
192192
"""
193193
return data.decode(self.encoding, errors=self.errorhandler)
194194

195-
# a stateful RESP parser implemented via a generator
196195
def parse(
197196
self,
198197
buffer: bytes,
@@ -340,8 +339,9 @@ class NeedMoreData(RuntimeError):
340339

341340
class RespParser:
342341
"""
343-
A class for simple RESP protocol decoding for unit tests
344-
Uses a RespGeneratorParser to produce data.
342+
A class for simple RESP protocol decoding for unit tests.
343+
Uses a RespGeneratorParser to produce data, and can
344+
produce top-level objects for as long as there is data available.
345345
"""
346346

347347
def __init__(self) -> None:
@@ -355,7 +355,8 @@ def __init__(self) -> None:
355355
def parse(self, buffer: bytes) -> Optional[Any]:
356356
"""
357357
Parse a buffer of data, return a tuple of a single top-level primitive and the
358-
remaining buffer or raise NeedMoreData if more data is needed
358+
remaining buffer or raise NeedMoreData if more data is needed to
359+
produce a value.
359360
"""
360361
if self.generator is None:
361362
# create a new parser generator, initializing it with
@@ -372,7 +373,7 @@ def parse(self, buffer: bytes) -> Optional[Any]:
372373
self.consumed.append(buffer)
373374
raise NeedMoreData()
374375

375-
# got a value, close the parser, store the remaining buffer
376+
# got a value, close the generator, store the remaining buffer
376377
self.generator.close()
377378
self.generator = None
378379
value, remaining = parsed
@@ -427,18 +428,105 @@ class RespServer:
427428
Accepts RESP commands and returns RESP responses.
428429
"""
429430

430-
_CLIENT_NAME = "test-suite-client"
431-
_SUCCESS_RESP = b"+OK" + CRNL
432-
_ERROR_RESP = b"-ERR" + CRNL
433-
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP}
431+
handlers = {}
432+
433+
def __init__(self):
434+
self.protocol = 2
435+
self.server_ver = self.get_server_version()
436+
self.auth = []
437+
self.client_name = ""
438+
439+
# patchable methods for testing
440+
441+
def get_server_version(self):
442+
return 6
443+
444+
def on_auth(self, auth):
445+
pass
446+
447+
def on_setname(self, name):
448+
pass
449+
450+
def on_protocol(self, proto):
451+
pass
434452

435453
def command(self, cmd: Any) -> bytes:
436454
"""Process a single command and return the response"""
455+
result = self._command(cmd)
456+
return RespEncoder(self.protocol).encode(result)
457+
458+
def _command(self, cmd: Any) -> Any:
437459
if not isinstance(cmd, list):
438-
return f"-ERR unknown command {cmd!r}\r\n".encode()
460+
return ErrorStr("ERR", "unknown command {cmd!r}")
461+
462+
# handle registered commands
463+
command = cmd[0].upper()
464+
args = cmd[1:]
465+
if command in self.handlers:
466+
return self.handlers[command](self, args)
467+
468+
return ErrorStr("ERR", "unknown command {cmd!r}")
469+
470+
def handle_auth(self, args):
471+
self.auth = args[:]
472+
self.on_auth(self.auth)
473+
expect = 2 if self.server_ver >= 6 else 1
474+
if len(args) != expect:
475+
return ErrorStr("ERR", "wrong number of arguments" " for 'AUTH' command")
476+
return "OK"
477+
478+
handlers["AUTH"] = handle_auth
479+
480+
def handle_client(self, args):
481+
if args[0] == "SETNAME":
482+
return self.handle_setname(args[1:])
483+
return ErrorStr("ERR", "unknown subcommand or wrong number of arguments")
484+
485+
handlers["CLIENT"] = handle_client
486+
487+
def handle_setname(self, args):
488+
if len(args) != 1:
489+
return ErrorStr("ERR", "wrong number of arguments")
490+
self.client_name = args[0]
491+
self.on_setname(self.client_name)
492+
return "OK"
493+
494+
def handle_hello(self, args):
495+
if self.server_ver < 6:
496+
return ErrorStr("ERR", "unknown command 'HELLO'")
497+
proto = self.protocol
498+
if args:
499+
proto = args.pop(0)
500+
if str(proto) not in ["2", "3"]:
501+
return ErrorStr(
502+
"NOPROTO", "sorry this protocol version is not supported"
503+
)
439504

440-
# currently supports only a single command
441-
command = " ".join(cmd)
442-
if command in self._SUPPORTED_CMDS:
443-
return self._SUPPORTED_CMDS[command]
444-
return self._ERROR_RESP
505+
while args:
506+
cmd = args.pop(0).upper()
507+
if cmd == "AUTH":
508+
auth_args = args[:2]
509+
args = args[2:]
510+
res = self.handle_auth(auth_args)
511+
if res != "OK":
512+
return res
513+
continue
514+
if cmd == "SETNAME":
515+
setname_args = args[:1]
516+
args = args[1:]
517+
res = self.handle_setname(setname_args)
518+
if res != "OK":
519+
return res
520+
continue
521+
return ErrorStr("ERR", "unknown subcommand or wrong number of arguments")
522+
523+
self.protocol = int(proto)
524+
self.on_protocol(self.protocol)
525+
result = {
526+
"server": "redistester",
527+
"version": "0.0.1",
528+
"proto": self.protocol,
529+
}
530+
return result
531+
532+
handlers["HELLO"] = handle_hello

0 commit comments

Comments
 (0)