@@ -64,22 +64,30 @@ def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
64
64
strict = False ,
65
65
)
66
66
67
- # aoti_path = torch._inductor.aot_compile(
68
- # exported_model.module(),
69
- # model.get_example_inputs(),
70
- # )
67
+ aoti_path = torch ._inductor .aot_compile (
68
+ exported_model .module (),
69
+ model .get_example_inputs (),
70
+ )
71
71
72
72
edge_program = to_edge (
73
73
exported_model , compile_config = EdgeCompileConfig (_check_ir_validity = False )
74
74
)
75
75
executorch_model = edge_program .to_executorch ()
76
76
77
+ # Re-export as ExecuTorch edits the ExportedProgram.
78
+ exported_model = torch .export .export (
79
+ model .get_eager_model (),
80
+ model .get_example_inputs (),
81
+ dynamic_shapes = model .get_dynamic_shapes (),
82
+ strict = False ,
83
+ )
84
+
77
85
return {
78
86
"config" : config ,
79
87
"reference_model" : reference_model ,
80
88
"model" : model ,
81
89
"exported_model" : exported_model ,
82
- # "aoti_path": aoti_path,
90
+ "aoti_path" : aoti_path ,
83
91
"executorch_model" : executorch_model ,
84
92
}
85
93
@@ -265,11 +273,13 @@ def run_preprocess(
265
273
), f"Executorch model: expected { reference_ar } but got { et_ar .tolist ()} "
266
274
267
275
# Run aoti model and check it matches reference model.
268
- # aoti_path = models["aoti_path"]
269
- # aoti_model = torch._export.aot_load(aoti_path, "cpu")
270
- # aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
271
- # self.assertTrue(torch.allclose(reference_image, aoti_image))
272
- # self.assertEqual(reference_ar, aoti_ar.tolist())
276
+ aoti_path = models ["aoti_path" ]
277
+ aoti_model = torch ._export .aot_load (aoti_path , "cpu" )
278
+ aoti_image , aoti_ar = aoti_model (image_tensor , inscribed_size , best_resolution )
279
+ assert torch .allclose (reference_image , aoti_image )
280
+ assert (
281
+ reference_ar == aoti_ar .tolist ()
282
+ ), f"AOTI model: expected { reference_ar } but got { aoti_ar .tolist ()} "
273
283
274
284
# This test setup mirrors the one in torchtune:
275
285
# https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
0 commit comments