|
1 | 1 | import asyncio
|
2 | 2 | import logging
|
3 |
| -import re |
4 | 3 | import socket
|
5 | 4 | import ssl
|
6 | 5 |
|
|
11 | 10 | UnixDomainSocketConnection,
|
12 | 11 | )
|
13 | 12 |
|
| 13 | +from .. import resp |
14 | 14 | from ..ssl_utils import get_ssl_filename
|
15 | 15 |
|
16 | 16 | _logger = logging.getLogger(__name__)
|
17 | 17 |
|
18 | 18 |
|
19 | 19 | _CLIENT_NAME = "test-suite-client"
|
20 |
| -_CMD_SEP = b"\r\n" |
21 |
| -_SUCCESS_RESP = b"+OK" + _CMD_SEP |
22 |
| -_ERROR_RESP = b"-ERR" + _CMD_SEP |
23 |
| -_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} |
24 | 20 |
|
25 | 21 |
|
26 | 22 | @pytest.fixture
|
@@ -100,46 +96,34 @@ async def _handler(reader, writer):
|
100 | 96 |
|
101 | 97 |
|
102 | 98 | async def _redis_request_handler(reader, writer, stop_event):
|
| 99 | + parser = resp.RespParser() |
| 100 | + server = resp.RespServer() |
103 | 101 | buffer = b""
|
104 |
| - command = None |
105 |
| - command_ptr = None |
106 |
| - fragment_length = None |
107 |
| - while not stop_event.is_set() or buffer: |
108 |
| - _logger.info(str(stop_event.is_set())) |
109 |
| - try: |
110 |
| - buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5) |
111 |
| - except TimeoutError: |
112 |
| - continue |
113 |
| - if not buffer: |
114 |
| - continue |
115 |
| - parts = re.split(_CMD_SEP, buffer) |
116 |
| - buffer = parts[-1] |
117 |
| - for fragment in parts[:-1]: |
118 |
| - fragment = fragment.decode() |
119 |
| - _logger.info("Command fragment: %s", fragment) |
120 |
| - |
121 |
| - if fragment.startswith("*") and command is None: |
122 |
| - command = [None for _ in range(int(fragment[1:]))] |
123 |
| - command_ptr = 0 |
124 |
| - fragment_length = None |
125 |
| - continue |
126 |
| - |
127 |
| - if fragment.startswith("$") and command[command_ptr] is None: |
128 |
| - fragment_length = int(fragment[1:]) |
129 |
| - continue |
130 |
| - |
131 |
| - assert len(fragment) == fragment_length |
132 |
| - command[command_ptr] = fragment |
133 |
| - command_ptr += 1 |
134 |
| - |
135 |
| - if command_ptr < len(command): |
| 102 | + try: |
| 103 | + # if client performs pipelining, we may need |
| 104 | + # to adjust this code to not block when sending |
| 105 | + # responses. |
| 106 | + while not stop_event.is_set() or buffer: |
| 107 | + _logger.info(str(stop_event.is_set())) |
| 108 | + try: |
| 109 | + command = parser.parse(buffer) |
| 110 | + buffer = b"" |
| 111 | + except resp.NeedMoreData: |
| 112 | + try: |
| 113 | + buffer = await asyncio.wait_for(reader.read(1024), timeout=0.5) |
| 114 | + except TimeoutError: |
| 115 | + buffer = b"" |
| 116 | + continue |
| 117 | + if not buffer: |
| 118 | + break # EOF |
136 | 119 | continue
|
137 | 120 |
|
138 |
| - command = " ".join(command) |
139 | 121 | _logger.info("Command %s", command)
|
140 |
| - resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) |
141 |
| - _logger.info("Response from %s", resp) |
142 |
| - writer.write(resp) |
| 122 | + response = server.command(command) |
| 123 | + _logger.info("Response %s", response) |
| 124 | + writer.write(response) |
143 | 125 | await writer.drain()
|
144 |
| - command = None |
145 |
| - _logger.info("Exit handler") |
| 126 | + except Exception: |
| 127 | + _logger.exception("Error in handler") |
| 128 | + finally: |
| 129 | + _logger.info("Exit handler") |
0 commit comments