@@ -26,6 +26,7 @@
 )

 from PIL import Image
+from torch._inductor.package import package_aoti

 from torchtune.models.clip.inference._transform import CLIPImageTransform

@@ -55,31 +56,46 @@ def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
         possible_resolutions=None,
     )

+    # Eager model.
     model = CLIPImageTransformModel(config)

+    # Exported model.
     exported_model = torch.export.export(
         model.get_eager_model(),
         model.get_example_inputs(),
         dynamic_shapes=model.get_dynamic_shapes(),
         strict=False,
     )

-    # aoti_path = torch._inductor.aot_compile(
-    #     exported_model.module(),
-    #     model.get_example_inputs(),
-    # )
+    # AOTInductor model.
+    so = torch._export.aot_compile(
+        exported_model.module(),
+        args=model.get_example_inputs(),
+        options={"aot_inductor.package": True},
+        dynamic_shapes=model.get_dynamic_shapes(),
+    )
+    aoti_path = "preprocess.pt2"
+    package_aoti(aoti_path, so)

     edge_program = to_edge(
         exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
     )
     executorch_model = edge_program.to_executorch()

+    # Re-export as ExecuTorch edits the ExportedProgram.
+    exported_model = torch.export.export(
+        model.get_eager_model(),
+        model.get_example_inputs(),
+        dynamic_shapes=model.get_dynamic_shapes(),
+        strict=False,
+    )
+
     return {
         "config": config,
         "reference_model": reference_model,
         "model": model,
         "exported_model": exported_model,
-        # "aoti_path": aoti_path,
+        "aoti_path": aoti_path,
         "executorch_model": executorch_model,
     }

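For context, the compile-and-package flow added above reduces to the round trip sketched below. This is a minimal, self-contained sketch with a hypothetical Double module standing in for the PR's CLIP image transform; it assumes a PyTorch build where torch._export.aot_compile honors the "aot_inductor.package" option, as the patch relies on.

import torch
from torch._inductor.package import package_aoti


class Double(torch.nn.Module):
    # Hypothetical stand-in for the CLIP image transform.
    def forward(self, x):
        return x * 2


example_inputs = (torch.randn(4),)
exported = torch.export.export(Double(), example_inputs, strict=False)

# With "aot_inductor.package", aot_compile emits packaging artifacts
# rather than a bare shared-library path.
artifacts = torch._export.aot_compile(
    exported.module(),
    args=example_inputs,
    options={"aot_inductor.package": True},
)

# Bundle the artifacts into a single .pt2 archive on disk.
package_aoti("double.pt2", artifacts)
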
@@ -265,11 +281,13 @@ def run_preprocess(
         ), f"Executorch model: expected {reference_ar} but got {et_ar.tolist()}"

         # Run aoti model and check it matches reference model.
-        # aoti_path = models["aoti_path"]
-        # aoti_model = torch._export.aot_load(aoti_path, "cpu")
-        # aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
-        # self.assertTrue(torch.allclose(reference_image, aoti_image))
-        # self.assertEqual(reference_ar, aoti_ar.tolist())
+        aoti_path = models["aoti_path"]
+        aoti_model = torch._inductor.aoti_load_package(aoti_path)
+        aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
+        assert_expected(aoti_image, reference_image, rtol=0, atol=1e-4)
+        assert (
+            reference_ar == aoti_ar.tolist()
+        ), f"AOTI model: expected {reference_ar} but got {aoti_ar.tolist()}"

         # This test setup mirrors the one in torchtune:
         # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
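
The new assertions load that archive back and call it like a regular Python callable. A matching sketch, reusing the hypothetical double.pt2 produced in the note above:

import torch
from torch._inductor import aoti_load_package

# Load the packaged AOTInductor model and run it on fresh input.
loaded = aoti_load_package("double.pt2")
x = torch.randn(4)
out = loaded(x)

# The packaged model should agree with eager execution.
assert torch.allclose(out, x * 2), "AOTI package diverged from eager"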