@@ -192,7 +192,6 @@ def decode_bytes(self, data: bytes) -> str:
192
192
"""
193
193
return data .decode (self .encoding , errors = self .errorhandler )
194
194
195
- # a stateful RESP parser implemented via a generator
196
195
def parse (
197
196
self ,
198
197
buffer : bytes ,
@@ -340,8 +339,9 @@ class NeedMoreData(RuntimeError):
340
339
341
340
class RespParser :
342
341
"""
343
- A class for simple RESP protocol decoding for unit tests
344
- Uses a RespGeneratorParser to produce data.
342
+ A class for simple RESP protocol decoding for unit tests.
343
+ Uses a RespGeneratorParser to produce data, and can
344
+ produce top-level objects for as long as there is data available.
345
345
"""
346
346
347
347
def __init__ (self ) -> None :
@@ -355,7 +355,8 @@ def __init__(self) -> None:
355
355
def parse (self , buffer : bytes ) -> Optional [Any ]:
356
356
"""
357
357
Parse a buffer of data, return a tuple of a single top-level primitive and the
358
- remaining buffer or raise NeedMoreData if more data is needed
358
+ remaining buffer or raise NeedMoreData if more data is needed to
359
+ produce a value.
359
360
"""
360
361
if self .generator is None :
361
362
# create a new parser generator, initializing it with
@@ -372,7 +373,7 @@ def parse(self, buffer: bytes) -> Optional[Any]:
372
373
self .consumed .append (buffer )
373
374
raise NeedMoreData ()
374
375
375
- # got a value, close the parser , store the remaining buffer
376
+ # got a value, close the generator , store the remaining buffer
376
377
self .generator .close ()
377
378
self .generator = None
378
379
value , remaining = parsed
@@ -427,18 +428,105 @@ class RespServer:
427
428
Accepts RESP commands and returns RESP responses.
428
429
"""
429
430
430
- _CLIENT_NAME = "test-suite-client"
431
- _SUCCESS_RESP = b"+OK" + CRNL
432
- _ERROR_RESP = b"-ERR" + CRNL
433
- _SUPPORTED_CMDS = {f"CLIENT SETNAME { _CLIENT_NAME } " : _SUCCESS_RESP }
431
+ handlers = {}
432
+
433
+ def __init__ (self ):
434
+ self .protocol = 2
435
+ self .server_ver = self .get_server_version ()
436
+ self .auth = []
437
+ self .client_name = ""
438
+
439
+ # patchable methods for testing
440
+
441
+ def get_server_version (self ):
442
+ return 6
443
+
444
+ def on_auth (self , auth ):
445
+ pass
446
+
447
+ def on_setname (self , name ):
448
+ pass
449
+
450
+ def on_protocol (self , proto ):
451
+ pass
434
452
435
453
def command (self , cmd : Any ) -> bytes :
436
454
"""Process a single command and return the response"""
455
+ result = self ._command (cmd )
456
+ return RespEncoder (self .protocol ).encode (result )
457
+
458
+ def _command (self , cmd : Any ) -> Any :
437
459
if not isinstance (cmd , list ):
438
- return f"-ERR unknown command { cmd !r} \r \n " .encode ()
460
+ return ErrorStr ("ERR" , "unknown command {cmd!r}" )
461
+
462
+ # handle registered commands
463
+ command = cmd [0 ].upper ()
464
+ args = cmd [1 :]
465
+ if command in self .handlers :
466
+ return self .handlers [command ](self , args )
467
+
468
+ return ErrorStr ("ERR" , "unknown command {cmd!r}" )
469
+
470
+ def handle_auth (self , args ):
471
+ self .auth = args [:]
472
+ self .on_auth (self .auth )
473
+ expect = 2 if self .server_ver >= 6 else 1
474
+ if len (args ) != expect :
475
+ return ErrorStr ("ERR" , "wrong number of arguments" " for 'AUTH' command" )
476
+ return "OK"
477
+
478
+ handlers ["AUTH" ] = handle_auth
479
+
480
+ def handle_client (self , args ):
481
+ if args [0 ] == "SETNAME" :
482
+ return self .handle_setname (args [1 :])
483
+ return ErrorStr ("ERR" , "unknown subcommand or wrong number of arguments" )
484
+
485
+ handlers ["CLIENT" ] = handle_client
486
+
487
+ def handle_setname (self , args ):
488
+ if len (args ) != 1 :
489
+ return ErrorStr ("ERR" , "wrong number of arguments" )
490
+ self .client_name = args [0 ]
491
+ self .on_setname (self .client_name )
492
+ return "OK"
493
+
494
+ def handle_hello (self , args ):
495
+ if self .server_ver < 6 :
496
+ return ErrorStr ("ERR" , "unknown command 'HELLO'" )
497
+ proto = self .protocol
498
+ if args :
499
+ proto = args .pop (0 )
500
+ if str (proto ) not in ["2" , "3" ]:
501
+ return ErrorStr (
502
+ "NOPROTO" , "sorry this protocol version is not supported"
503
+ )
439
504
440
- # currently supports only a single command
441
- command = " " .join (cmd )
442
- if command in self ._SUPPORTED_CMDS :
443
- return self ._SUPPORTED_CMDS [command ]
444
- return self ._ERROR_RESP
505
+ while args :
506
+ cmd = args .pop (0 ).upper ()
507
+ if cmd == "AUTH" :
508
+ auth_args = args [:2 ]
509
+ args = args [2 :]
510
+ res = self .handle_auth (auth_args )
511
+ if res != "OK" :
512
+ return res
513
+ continue
514
+ if cmd == "SETNAME" :
515
+ setname_args = args [:1 ]
516
+ args = args [1 :]
517
+ res = self .handle_setname (setname_args )
518
+ if res != "OK" :
519
+ return res
520
+ continue
521
+ return ErrorStr ("ERR" , "unknown subcommand or wrong number of arguments" )
522
+
523
+ self .protocol = int (proto )
524
+ self .on_protocol (self .protocol )
525
+ result = {
526
+ "server" : "redistester" ,
527
+ "version" : "0.0.1" ,
528
+ "proto" : self .protocol ,
529
+ }
530
+ return result
531
+
532
+ handlers ["HELLO" ] = handle_hello
0 commit comments