|
6 | 6 | # LICENSE file in the root directory of this source tree.
|
7 | 7 |
|
8 | 8 | import os
|
| 9 | +import re |
| 10 | +import shutil |
9 | 11 | import tempfile
|
10 | 12 | import unittest
|
11 | 13 | from typing import Dict, Optional, Sequence
|
12 | 14 | from unittest.mock import patch
|
13 | 15 |
|
14 | 16 | from executorch.exir._serialize import _flatbuffer
|
15 |
| -from executorch.exir._serialize._flatbuffer import _ResourceFiles, _SchemaInfo |
| 17 | +from executorch.exir._serialize._flatbuffer import ( |
| 18 | + _program_json_to_flatbuffer, |
| 19 | + _ResourceFiles, |
| 20 | + _SchemaInfo, |
| 21 | +) |
16 | 22 |
|
17 | 23 |
|
18 | 24 | def read_file(dir: str, filename: str) -> bytes:
|
@@ -266,3 +272,60 @@ def test_bad_delegate_alignment_fails(self) -> None:
|
266 | 272 | out_dir,
|
267 | 273 | delegate_alignment=bad_alignment,
|
268 | 274 | )
|
| 275 | + |
| 276 | + |
| 277 | +class TestProgramJsonToFlatbuffer(unittest.TestCase): |
| 278 | + @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"}) |
| 279 | + def test_save_json_on_failure(self) -> None: |
| 280 | + err_msg: Optional[str] = None |
| 281 | + try: |
| 282 | + _program_json_to_flatbuffer("} some bad json {") |
| 283 | + self.fail("Should have raised an exception") |
| 284 | + except RuntimeError as err: |
| 285 | + err_msg = err.args[0] |
| 286 | + |
| 287 | + self.assertIsNotNone(err_msg) |
| 288 | + match = re.search(r"Moved input files to '(.*?)'", err_msg) |
| 289 | + self.assertTrue(match, msg=f"Unexpected error message: {err_msg}") |
| 290 | + path = match.group(1) |
| 291 | + |
| 292 | + files = frozenset(os.listdir(path)) |
| 293 | + # Delete the files otherwise they'll accumulate every time the |
| 294 | + # test is run. |
| 295 | + shutil.rmtree(path) |
| 296 | + # Check for a couple of the files that should be there. |
| 297 | + self.assertIn("data.json", files) |
| 298 | + self.assertIn("program.fbs", files) |
| 299 | + |
| 300 | + @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: "1"}) |
| 301 | + def test_unable_to_save_json_on_failure(self) -> None: |
| 302 | + err_msg: Optional[str] = None |
| 303 | + try: |
| 304 | + with patch.object( |
| 305 | + _flatbuffer.shutil, |
| 306 | + "move", |
| 307 | + side_effect=Exception("shutil.move mock failure"), |
| 308 | + ): |
| 309 | + _program_json_to_flatbuffer("} some bad json {") |
| 310 | + self.fail("Should have raised an exception") |
| 311 | + except RuntimeError as err: |
| 312 | + err_msg = err.args[0] |
| 313 | + |
| 314 | + self.assertIsNotNone(err_msg) |
| 315 | + self.assertIn("Failed to save input files", err_msg) |
| 316 | + |
| 317 | + @patch.dict(os.environ, {_flatbuffer._SAVE_FLATC_ENV: ""}) |
| 318 | + def test_no_save_json_on_failure(self) -> None: |
| 319 | + err_msg: Optional[str] = None |
| 320 | + try: |
| 321 | + _program_json_to_flatbuffer("} some bad json {") |
| 322 | + self.fail("Should have raised an exception") |
| 323 | + except RuntimeError as err: |
| 324 | + err_msg = err.args[0] |
| 325 | + |
| 326 | + self.assertIsNotNone(err_msg) |
| 327 | + self.assertIn( |
| 328 | + f"Set {_flatbuffer._SAVE_FLATC_ENV}=1 to save input files", err_msg |
| 329 | + ) |
| 330 | + self.assertNotIn("Moved input files", err_msg) |
| 331 | + self.assertNotIn("Failed to save input files", err_msg) |
0 commit comments