Skip to content

Commit a166a25

Browse files
lucylqfacebook-github-bot
authored andcommitted
Export aoti for preprocess (#5354)
Summary: Export AOTI for preprocess. Requires: D62651605 Pull Request resolved: #5354 Test Plan: ``` python -m unittest examples/models/flamingo/preprocess/test_preprocess.py ``` Reviewed By: larryliu0820 Differential Revision: D62662418 Pulled By: lucylq fbshipit-source-id: a094a14870fcca820fb4b739684b5e75aefd36d3
1 parent 2b3cc27 commit a166a25

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

examples/models/flamingo/preprocess/export_preprocess.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,30 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from export_preprocess_lib import export_preprocess, lower_to_executorch_preprocess
7+
import torch
8+
from executorch.examples.models.flamingo.preprocess.export_preprocess_lib import (
9+
export_preprocess,
10+
get_example_inputs,
11+
lower_to_executorch_preprocess,
12+
)
813

914

1015
def main():
16+
# Export
1117
ep = export_preprocess()
12-
et = lower_to_executorch_preprocess(ep)
1318

14-
with open("preprocess.pte", "wb") as file:
19+
# ExecuTorch
20+
et = lower_to_executorch_preprocess(ep)
21+
with open("preprocess_et.pte", "wb") as file:
1522
et.write_to_file(file)
1623

24+
# AOTInductor
25+
torch._inductor.aot_compile(
26+
ep.module(),
27+
get_example_inputs(),
28+
options={"aot_inductor.output_path": "preprocess_aoti.so"},
29+
)
30+
1731

1832
if __name__ == "__main__":
1933
main()

examples/models/flamingo/preprocess/test_preprocess.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,11 @@
3737
)
3838
from torchvision.transforms.v2 import functional as F
3939

40-
from .export_preprocess_lib import export_preprocess, lower_to_executorch_preprocess
40+
from .export_preprocess_lib import (
41+
export_preprocess,
42+
get_example_inputs,
43+
lower_to_executorch_preprocess,
44+
)
4145

4246

4347
@dataclass
@@ -206,6 +210,11 @@ def test_preprocess(
206210
executorch_model = lower_to_executorch_preprocess(exported_model)
207211
executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
208212

213+
aoti_path = torch._inductor.aot_compile(
214+
exported_model.module(),
215+
get_example_inputs(),
216+
)
217+
209218
# Prepare image input.
210219
image = (
211220
np.random.randint(0, 256, np.prod(image_size))
@@ -266,3 +275,9 @@ def test_preprocess(
266275
)
267276
self.assertTrue(torch.allclose(reference_image, et_image))
268277
self.assertEqual(reference_ar, et_ar.tolist())
278+
279+
# Run aoti model and check it matches reference model.
280+
aoti_model = torch._export.aot_load(aoti_path, "cpu")
281+
aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
282+
self.assertTrue(torch.allclose(reference_image, aoti_image))
283+
self.assertEqual(reference_ar, aoti_ar.tolist())

0 commit comments

Comments
 (0)