@@ -76,22 +76,30 @@ def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
76
76
strict = False ,
77
77
)
78
78
79
- # aoti_path = torch._inductor.aot_compile(
80
- # exported_model.module(),
81
- # model.get_example_inputs(),
82
- # )
79
+ aoti_path = torch ._inductor .aot_compile (
80
+ exported_model .module (),
81
+ model .get_example_inputs (),
82
+ )
83
83
84
84
edge_program = to_edge (
85
85
exported_model , compile_config = EdgeCompileConfig (_check_ir_validity = False )
86
86
)
87
87
executorch_model = edge_program .to_executorch ()
88
88
89
+ # Re-export, as lowering to executorch changes the graph.
90
+ exported_model = torch .export .export (
91
+ model .get_eager_model (),
92
+ model .get_example_inputs (),
93
+ dynamic_shapes = model .get_dynamic_shapes (),
94
+ strict = False ,
95
+ )
96
+
89
97
return {
90
98
"config" : config ,
91
99
"reference_model" : reference_model ,
92
100
"model" : model ,
93
101
"exported_model" : exported_model ,
94
- # "aoti_path": aoti_path,
102
+ "aoti_path" : aoti_path ,
95
103
"executorch_model" : executorch_model ,
96
104
}
97
105
@@ -237,11 +245,11 @@ def run_preprocess(
237
245
self .assertEqual (reference_ar , et_ar .tolist ())
238
246
239
247
# Run aoti model and check it matches reference model.
240
- # aoti_path = models["aoti_path"]
241
- # aoti_model = torch._export.aot_load(aoti_path, "cpu")
242
- # aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
243
- # self.assertTrue(torch.allclose(reference_image, aoti_image))
244
- # self.assertEqual(reference_ar, aoti_ar.tolist())
248
+ aoti_path = models ["aoti_path" ]
249
+ aoti_model = torch ._export .aot_load (aoti_path , "cpu" )
250
+ aoti_image , aoti_ar = aoti_model (image_tensor , inscribed_size , best_resolution )
251
+ self .assertTrue (torch .allclose (reference_image , aoti_image ))
252
+ self .assertEqual (reference_ar , aoti_ar .tolist ())
245
253
246
254
# This test setup mirrors the one in torchtune:
247
255
# https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
0 commit comments