|
1 | 1 | import asyncio
|
2 | 2 | import logging
|
3 |
| -import re |
4 | 3 | import socket
|
5 | 4 | import ssl
|
6 | 5 |
|
|
11 | 10 | UnixDomainSocketConnection,
|
12 | 11 | )
|
13 | 12 |
|
| 13 | +from .. import resp |
14 | 14 | from ..ssl_utils import get_ssl_filename
|
15 | 15 |
|
16 | 16 | _logger = logging.getLogger(__name__)
|
17 | 17 |
|
18 | 18 |
|
19 | 19 | _CLIENT_NAME = "test-suite-client"
|
20 |
| -_CMD_SEP = b"\r\n" |
21 |
| -_SUCCESS_RESP = b"+OK" + _CMD_SEP |
22 |
| -_ERROR_RESP = b"-ERR" + _CMD_SEP |
23 |
| -_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} |
24 | 20 |
|
25 | 21 |
|
26 | 22 | @pytest.fixture
|
@@ -102,46 +98,34 @@ async def _handler(reader, writer):
|
102 | 98 |
|
103 | 99 |
|
104 | 100 | async def _redis_request_handler(reader, writer, stop_event):
|
| 101 | + parser = resp.RespParser() |
| 102 | + server = resp.RespServer() |
105 | 103 | buffer = b""
|
106 |
| - command = None |
107 |
| - command_ptr = None |
108 |
| - fragment_length = None |
109 |
| - while not stop_event.is_set() or buffer: |
110 |
| - _logger.info(str(stop_event.is_set())) |
111 |
| - try: |
112 |
| - buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5) |
113 |
| - except TimeoutError: |
114 |
| - continue |
115 |
| - if not buffer: |
116 |
| - continue |
117 |
| - parts = re.split(_CMD_SEP, buffer) |
118 |
| - buffer = parts[-1] |
119 |
| - for fragment in parts[:-1]: |
120 |
| - fragment = fragment.decode() |
121 |
| - _logger.info("Command fragment: %s", fragment) |
122 |
| - |
123 |
| - if fragment.startswith("*") and command is None: |
124 |
| - command = [None for _ in range(int(fragment[1:]))] |
125 |
| - command_ptr = 0 |
126 |
| - fragment_length = None |
127 |
| - continue |
128 |
| - |
129 |
| - if fragment.startswith("$") and command[command_ptr] is None: |
130 |
| - fragment_length = int(fragment[1:]) |
131 |
| - continue |
132 |
| - |
133 |
| - assert len(fragment) == fragment_length |
134 |
| - command[command_ptr] = fragment |
135 |
| - command_ptr += 1 |
136 |
| - |
137 |
| - if command_ptr < len(command): |
| 104 | + try: |
| 105 | + # if client performs pipelining, we may need |
| 106 | + # to adjust this code to not block when sending |
| 107 | + # responses. |
| 108 | + while not stop_event.is_set() or buffer: |
| 109 | + _logger.info(str(stop_event.is_set())) |
| 110 | + try: |
| 111 | + command = parser.parse(buffer) |
| 112 | + buffer = b"" |
| 113 | + except resp.NeedMoreData: |
| 114 | + try: |
| 115 | + buffer = await asyncio.wait_for(reader.read(1024), timeout=0.5) |
| 116 | + except TimeoutError: |
| 117 | + buffer = b"" |
| 118 | + continue |
| 119 | + if not buffer: |
| 120 | + break # EOF |
138 | 121 | continue
|
139 | 122 |
|
140 |
| - command = " ".join(command) |
141 | 123 | _logger.info("Command %s", command)
|
142 |
| - resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) |
143 |
| - _logger.info("Response from %s", resp) |
144 |
| - writer.write(resp) |
| 124 | + response = server.command(command) |
| 125 | + _logger.info("Response %s", response) |
| 126 | + writer.write(response) |
145 | 127 | await writer.drain()
|
146 |
| - command = None |
147 |
| - _logger.info("Exit handler") |
| 128 | + except Exception: |
| 129 | + _logger.exception("Error in handler") |
| 130 | + finally: |
| 131 | + _logger.info("Exit handler") |
0 commit comments