Skip to content

Commit 6e96054

Browse files
committed
Implement flash_write
1 parent 1a5a405 commit 6e96054

File tree

4 files changed

+163
-53
lines changed

4 files changed

+163
-53
lines changed

tests/test_tools_flash.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,30 @@
1+
import random
12
import zigpy_znp.types as t
23
import zigpy_znp.commands as c
34

4-
from zigpy_znp.tools.flash_backup import main as flash_backup
5+
from zigpy_znp.tools.flash_read import main as flash_read
6+
from zigpy_znp.tools.flash_write import main as flash_write
57

68
from test_api import pytest_mark_asyncio_timeout # noqa: F401
79
from test_application import znp_server # noqa: F401
810
from test_tools_nvram import openable_serial_znp_server # noqa: F401
911

1012

11-
# Just random bytes
12-
FAKE_FLASH = bytes.fromhex(
13-
"""
14-
a66ea64b2299ef91102c692c8739433776ac1f7967b2d7be3b532db5255dee88f49cad134ef4155375d2
15-
67acecbe64637bd1df47ce1cb8b776caad7a7cd2b39892b69fbf2420176e598f689df05a3554400efb99
16-
60dcedfb3416fe72b1570b6eb4aa877213afb92c7a6fc8b755e7457072a8c4d4ac9ec727b7748b267fda
17-
241334ab9195b4eb52cb50b396859c355dfad136e1c56b18f6599e08a7464524587a44ea0caaeb2b0a79
18-
44ff74576db0c16b133f862de8ee8b6b37181a897416b40c589a645c62bbc6b2b4e993a6ee39ca1141bb
19-
7baeb7bb85476c7b905fa8f3f2148fe1162a218fb575eb3ed9849bc63212f7332a27f83c75e6590a25ad
20-
8ad3d13b212da0142bc257851afcc7c87c80c23d9f741f7159ccc89fed58ff2369523af224369df39224
21-
a4154dc2932958d3289d387356af931aa6e02d8216bffc3972674cf060de50c10e0705b2f80d7b54c763
22-
0999d2f28f8e3b1917d89e960a1893ebdaa1695c5b2f1fc36efb144b326d4cb8119803ea327f2848b45a
23-
a6e3e1ca93459eb848a8333826b12d87949be6cf652b1265a7c74e2b750303ee25f6296ed687393cb1a1
24-
64648ae92eb2c426ea3f35770f6d64fefcd87fc9835ab39134be9a5d325cc2839a47515f15ce5b2072fe
25-
808a5e897a273f883751d029bec9fe89797fd2940603537770c745c17e817e495e4d8741e744b652254b
26-
2b776c1d313ca30a
27-
"""
28-
)
13+
random.seed(12345)
14+
FAKE_IMAGE_SIZE = 2 ** 10
15+
FAKE_FLASH = random.getrandbits(FAKE_IMAGE_SIZE * 8).to_bytes(FAKE_IMAGE_SIZE, "little")
16+
random.seed()
2917

3018

3119
@pytest_mark_asyncio_timeout(seconds=5)
32-
async def test_flash_backup(openable_serial_znp_server, tmp_path): # noqa: F811
20+
async def test_flash_backup_write(
21+
openable_serial_znp_server, tmp_path, mocker # noqa: F811
22+
):
23+
# It takes too long otherwise
24+
mocker.patch("zigpy_znp.commands.ubl.IMAGE_SIZE", FAKE_IMAGE_SIZE)
25+
26+
WRITABLE_FLASH = bytearray(len(FAKE_FLASH))
27+
3328
openable_serial_znp_server.reply_to(
3429
request=c.UBL.HandshakeReq.Req(partial=True),
3530
responses=[
@@ -46,7 +41,7 @@ async def test_flash_backup(openable_serial_znp_server, tmp_path): # noqa: F811
4641

4742
def read_flash(req):
4843
offset = req.FlashWordAddr * 4
49-
data = FAKE_FLASH[offset : offset + 64]
44+
data = WRITABLE_FLASH[offset : offset + 64]
5045

5146
# We should not read partial blocks
5247
assert len(data) in (0, 64)
@@ -60,11 +55,37 @@ def read_flash(req):
6055
Data=t.TrailingBytes(data),
6156
)
6257

58+
def write_flash(req):
59+
offset = req.FlashWordAddr * 4
60+
61+
assert len(req.Data) == 64
62+
63+
WRITABLE_FLASH[offset : offset + 64] = req.Data
64+
assert len(WRITABLE_FLASH) == FAKE_IMAGE_SIZE
65+
66+
return c.UBL.WriteRsp.Callback(Status=c.ubl.BootloaderStatus.SUCCESS)
67+
6368
openable_serial_znp_server.reply_to(
6469
request=c.UBL.ReadReq.Req(partial=True), responses=[read_flash]
6570
)
6671

72+
openable_serial_znp_server.reply_to(
73+
request=c.UBL.WriteReq.Req(partial=True), responses=[write_flash]
74+
)
75+
76+
openable_serial_znp_server.reply_to(
77+
request=c.UBL.EnableReq.Req(partial=True),
78+
responses=[c.UBL.EnableRsp.Callback(Status=c.ubl.BootloaderStatus.SUCCESS)],
79+
)
80+
81+
# First we write the flash
82+
firmware_file = tmp_path / "firmware.bin"
83+
firmware_file.write_bytes(FAKE_FLASH)
84+
await flash_write([openable_serial_znp_server._port_path, "-i", str(firmware_file)])
85+
86+
# And then make a backup
6787
backup_file = tmp_path / "backup.bin"
68-
await flash_backup([openable_serial_znp_server._port_path, "-o", str(backup_file)])
88+
await flash_read([openable_serial_znp_server._port_path, "-o", str(backup_file)])
6989

90+
# They should be identical
7091
assert backup_file.read_bytes() == FAKE_FLASH

zigpy_znp/commands/ubl.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
import zigpy_znp.types as t
55

66

7+
# Size of internal flash less 4 pages for boot loader,
8+
# 6 pages for NV, & 1 page for lock bits.
9+
IMAGE_SIZE = 0x40000 - 0x2000 - 0x3000 - 0x0800
10+
IMAGE_CRC_OFFSET = 0x90
11+
712
FLASH_WORD_SIZE = 4
813

914

zigpy_znp/tools/flash_backup.py renamed to zigpy_znp/tools/flash_read.py

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,31 +16,6 @@
1616
LOGGER = logging.getLogger(__name__)
1717

1818

19-
async def get_firmware_size(znp: ZNP, block_size: int) -> int:
20-
valid_index = 0x0000
21-
22-
# Z-Stack lets you read beyond the end of the flash (???) if you go too high,
23-
# instead of throwing an error. We need to be careful.
24-
invalid_index = 0xFFFF // block_size
25-
26-
while invalid_index - valid_index > 1:
27-
midpoint = (valid_index + invalid_index) // 2
28-
29-
read_rsp = await znp.request_callback_rsp(
30-
request=c.UBL.ReadReq.Req(FlashWordAddr=midpoint * block_size),
31-
callback=c.UBL.ReadRsp.Callback(partial=True),
32-
)
33-
34-
if read_rsp.Status == c.ubl.BootloaderStatus.SUCCESS:
35-
valid_index = midpoint
36-
elif read_rsp.Status == c.ubl.BootloaderStatus.FAILURE:
37-
invalid_index = midpoint
38-
else:
39-
raise ValueError(f"Unexpected read response: {read_rsp}")
40-
41-
return invalid_index * block_size
42-
43-
4419
async def read_firmware(radio_path: str) -> bytearray:
4520
znp = ZNP(CONFIG_SCHEMA({"device": {"path": radio_path}}))
4621

@@ -65,15 +40,12 @@ async def read_firmware(radio_path: str) -> bytearray:
6540

6641
# All reads and writes are this size
6742
buffer_size = handshake_rsp.BufferSize
68-
block_size = buffer_size // c.ubl.FLASH_WORD_SIZE
69-
firmware_size = await get_firmware_size(znp, buffer_size)
70-
71-
LOGGER.info("Total firmware size is %d", firmware_size)
7243

7344
data = bytearray()
7445

75-
for address in range(0, firmware_size, block_size):
76-
LOGGER.info("Progress: %0.2f%%", (100.0 * address) / firmware_size)
46+
for offset in range(0, c.ubl.IMAGE_SIZE, buffer_size):
47+
address = offset // c.ubl.FLASH_WORD_SIZE
48+
LOGGER.info("Progress: %0.2f%%", (100.0 * offset) / c.ubl.IMAGE_SIZE)
7749

7850
read_rsp = await znp.request_callback_rsp(
7951
request=c.UBL.ReadReq.Req(FlashWordAddr=address),

zigpy_znp/tools/flash_write.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import sys
2+
import asyncio
3+
import logging
4+
import argparse
5+
import coloredlogs
6+
import async_timeout
7+
8+
import zigpy_znp.types as t
9+
import zigpy_znp.commands as c
10+
11+
from zigpy_znp.api import ZNP
12+
from zigpy_znp.config import CONFIG_SCHEMA
13+
14+
coloredlogs.install(level=logging.DEBUG)
15+
logging.getLogger("zigpy_znp").setLevel(logging.DEBUG)
16+
17+
LOGGER = logging.getLogger(__name__)
18+
19+
20+
async def write_firmware(firmware: bytes, radio_path: str):
21+
if len(firmware) != c.ubl.IMAGE_SIZE:
22+
raise ValueError(
23+
f"Firmware is the wrong size."
24+
f" Expected {c.ubl.IMAGE_SIZE}, got {len(firmware)}"
25+
)
26+
27+
znp = ZNP(CONFIG_SCHEMA({"device": {"path": radio_path}}))
28+
29+
# The bootloader handshake must be the very first command
30+
await znp.connect(test_port=False)
31+
32+
try:
33+
async with async_timeout.timeout(5):
34+
handshake_rsp = await znp.request_callback_rsp(
35+
request=c.UBL.HandshakeReq.Req(),
36+
callback=c.UBL.HandshakeRsp.Callback(partial=True),
37+
)
38+
except asyncio.TimeoutError:
39+
raise RuntimeError(
40+
"Did not receive a bootloader handshake response!"
41+
" Make sure your adapter has just been plugged in and"
42+
" nothing else has had a chance to communicate with it."
43+
)
44+
45+
if handshake_rsp.Status != c.ubl.BootloaderStatus.SUCCESS:
46+
raise RuntimeError(f"Bad bootloader handshake response: {handshake_rsp}")
47+
48+
# All reads and writes are this size
49+
buffer_size = handshake_rsp.BufferSize
50+
51+
for offset in range(0, c.ubl.IMAGE_SIZE, buffer_size):
52+
address = offset // c.ubl.FLASH_WORD_SIZE
53+
LOGGER.info("Write progress: %0.2f%%", (100.0 * offset) / c.ubl.IMAGE_SIZE)
54+
55+
write_rsp = await znp.request_callback_rsp(
56+
request=c.UBL.WriteReq.Req(
57+
FlashWordAddr=address,
58+
Data=t.TrailingBytes(firmware[offset : offset + buffer_size]),
59+
),
60+
callback=c.UBL.WriteRsp.Callback(partial=True),
61+
)
62+
63+
assert write_rsp.Status == c.ubl.BootloaderStatus.SUCCESS
64+
65+
# Now we have to read it all back
66+
# TODO: figure out how the CRC is computed!
67+
for offset in range(0, c.ubl.IMAGE_SIZE, buffer_size):
68+
address = offset // c.ubl.FLASH_WORD_SIZE
69+
LOGGER.info(
70+
"Verification progress: %0.2f%%", (100.0 * offset) / c.ubl.IMAGE_SIZE
71+
)
72+
73+
read_rsp = await znp.request_callback_rsp(
74+
request=c.UBL.ReadReq.Req(FlashWordAddr=address,),
75+
callback=c.UBL.ReadRsp.Callback(partial=True),
76+
)
77+
78+
assert read_rsp.Status == c.ubl.BootloaderStatus.SUCCESS
79+
assert read_rsp.FlashWordAddr == address
80+
assert read_rsp.Data == firmware[offset : offset + buffer_size]
81+
82+
# This seems to cause the firmware to compute and verify the CRC
83+
enable_rsp = await znp.request_callback_rsp(
84+
request=c.UBL.EnableReq.Req(), callback=c.UBL.EnableRsp.Callback(partial=True),
85+
)
86+
87+
assert enable_rsp.Status == c.ubl.BootloaderStatus.SUCCESS
88+
89+
90+
async def main(argv):
91+
parser = argparse.ArgumentParser(description="Write firmware to a radio")
92+
parser.add_argument("serial", type=argparse.FileType("rb"), help="Serial port path")
93+
parser.add_argument(
94+
"--input",
95+
"-i",
96+
type=argparse.FileType("rb"),
97+
help="Input .bin file",
98+
required=True,
99+
)
100+
101+
args = parser.parse_args(argv)
102+
103+
# We just want to make sure it exists
104+
args.serial.close()
105+
106+
await write_firmware(args.input.read(), args.serial.name)
107+
108+
LOGGER.info("Unplug your adapter to leave bootloader mode!")
109+
110+
111+
if __name__ == "__main__":
112+
asyncio.run(main(sys.argv[1:])) # pragma: no cover

0 commit comments

Comments
 (0)