|
2 | 2 | import typing
|
3 | 3 | import asyncio
|
4 | 4 | import logging
|
| 5 | +import async_timeout |
5 | 6 |
|
6 | 7 | from collections import defaultdict
|
7 | 8 |
|
8 | 9 | import zigpy_znp.commands
|
9 | 10 | import zigpy_znp.types as t
|
| 11 | +import zigpy_znp.commands as c |
10 | 12 | from zigpy_znp.types import nvids
|
11 | 13 |
|
12 | 14 | from zigpy_znp import uart
|
|
16 | 18 |
|
17 | 19 |
|
18 | 20 | LOGGER = logging.getLogger(__name__)
|
| 21 | +RECONNECT_RETRY_TIME = 5 # seconds |
19 | 22 |
|
20 | 23 |
|
21 | 24 | def _deduplicate_commands(commands):
|
@@ -145,19 +148,67 @@ def cancel(self):
|
145 | 148 |
|
146 | 149 |
|
147 | 150 | class ZNP:
|
148 |
| - def __init__(self): |
| 151 | + def __init__(self, *, auto_reconnect=True): |
149 | 152 | self._uart = None
|
150 | 153 | self._response_listeners = defaultdict(list)
|
151 | 154 |
|
| 155 | + self._auto_reconnect = auto_reconnect |
| 156 | + self._device = None |
| 157 | + self._baudrate = None |
| 158 | + |
| 159 | + self._reconnect_task = None |
| 160 | + |
152 | 161 | def set_application(self, app):
|
153 | 162 | self._app = app
|
154 | 163 |
|
155 | 164 | async def connect(self, device, baudrate=115_200):
|
156 | 165 | assert self._uart is None
|
157 |
| - self._uart = await uart.connect(device, baudrate, self) |
| 166 | + |
| 167 | + self._uart, device = await uart.connect(device, baudrate, self) |
| 168 | + |
| 169 | + # Make sure that our port works |
| 170 | + with async_timeout.timeout(2): |
| 171 | + await self.command(c.SysCommands.Ping.Req()) |
| 172 | + |
| 173 | + # We want to reuse the same device when reconnecting |
| 174 | + self._device = device |
| 175 | + self._baudrate = baudrate |
| 176 | + |
| 177 | + def _cancel_all_listeners(self): |
| 178 | + for header, listeners in self._response_listeners.items(): |
| 179 | + for listener in listeners: |
| 180 | + listener.cancel() |
| 181 | + |
| 182 | + async def _reconnect(self): |
| 183 | + while True: |
| 184 | + assert self._device is not None and self._baudrate is not None |
| 185 | + assert self._uart is None |
| 186 | + |
| 187 | + try: |
| 188 | + self._cancel_all_listeners() |
| 189 | + |
| 190 | + await self.connect(self._device, self._baudrate) |
| 191 | + await self._app.startup() |
| 192 | + |
| 193 | + self._reconnect_task = None |
| 194 | + break |
| 195 | + except Exception as e: |
| 196 | + LOGGER.error("Failed to reconnect", exc_info=e) |
| 197 | + await asyncio.sleep(RECONNECT_RETRY_TIME) |
158 | 198 |
|
159 | 199 | def connection_lost(self, exc):
|
160 |
| - raise NotImplementedError() |
| 200 | + self._uart = None |
| 201 | + |
| 202 | + if not self._auto_reconnect: |
| 203 | + return |
| 204 | + |
| 205 | + self._cancel_all_listeners() |
| 206 | + |
| 207 | + assert self._reconnect_task is None |
| 208 | + |
| 209 | + # Reconnect in the background using our previous device info |
| 210 | + # Note that this will reuse the same port as before |
| 211 | + self._reconnect_task = asyncio.create_task(self._reconnect()) |
161 | 212 |
|
162 | 213 | def close(self):
|
163 | 214 | return self._uart.close()
|
|
0 commit comments