Commit e8d6118

refactor preprocess to use EagerModelBase
ghstack-source-id: d2ac7e1
Pull Request resolved: #6530
1 parent 85d3ff6 commit e8d6118

File tree

3 files changed: +104 -101 lines
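
The commit title refers to the EagerModelBase interface shared by the example models. Its definition is not part of this diff (the new model imports it from `...model_base`); the snippet below is only a hypothetical minimal sketch, inferred from the methods the new CLIPImageTransformModel exposes and the export script calls.

```python
# Hypothetical sketch of the EagerModelBase contract assumed by this commit.
# The real class is imported from examples/models/model_base.py and may differ.
from abc import ABC, abstractmethod
from typing import Tuple

import torch


class EagerModelBase(ABC):
    @abstractmethod
    def get_eager_model(self) -> torch.nn.Module:
        """Return the eager-mode nn.Module to export."""

    @abstractmethod
    def get_example_inputs(self) -> Tuple:
        """Return example inputs matching the model's forward signature."""

    # get_dynamic_shapes() appears to be an optional extra: the export script
    # below calls it on the preprocess model, so that model defines it itself.
```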

examples/models/llama3_2_vision/preprocess/export_preprocess.py

Lines changed: 35 additions & 16 deletions
```diff
@@ -5,28 +5,47 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from executorch.examples.models.llama3_2_vision.preprocess.export_preprocess_lib import (
-    export_preprocess,
-    get_example_inputs,
-    lower_to_executorch_preprocess,
+from executorch.examples.models.llama3_2_vision.preprocess.model import (
+    CLIPImageTransformModel,
+    PreprocessConfig,
 )
+from executorch.exir import EdgeCompileConfig, to_edge
 
 
 def main():
+    # Eager model.
+    model = CLIPImageTransformModel(PreprocessConfig())
 
-    # ExecuTorch
-    ep_et = export_preprocess()
-    et = lower_to_executorch_preprocess(ep_et)
-    with open("preprocess_et.pte", "wb") as file:
-        et.write_to_file(file)
-
-    # AOTInductor
-    ep_aoti = export_preprocess()
-    torch._inductor.aot_compile(
-        ep_aoti.module(),
-        get_example_inputs(),
-        options={"aot_inductor.output_path": "preprocess_aoti.so"},
+    # Export.
+    ep = torch.export.export(
+        model.get_eager_model(),
+        model.get_example_inputs(),
+        dynamic_shapes=model.get_dynamic_shapes(),
+        strict=False,
+    )
+
+    # Executorch
+    edge_program = to_edge(
+        ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
     )
+    et_program = edge_program.to_executorch()
+    with open("preprocess_et.pte", "wb") as file:
+        et_program.write_to_file(file)
+
+    # Export.
+    # ep = torch.export.export(
+    #     model.get_eager_model(),
+    #     model.get_example_inputs(),
+    #     dynamic_shapes=model.get_dynamic_shapes(),
+    #     strict=False,
+    # )
+    #
+    # # AOTInductor
+    # torch._inductor.aot_compile(
+    #     ep.module(),
+    #     model.get_example_inputs(),
+    #     options={"aot_inductor.output_path": "preprocess_aoti.so"},
+    # )
 
 
 if __name__ == "__main__":
```
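
The script above only exercises the model through torch.export; a quick eager-mode smoke test can catch input or op issues earlier. This is a minimal sketch using only names introduced in this commit; that the transform is callable with the same positional inputs passed to torch.export.export is an assumption.

```python
# Hypothetical eager smoke test for the refactored preprocess model.
from executorch.examples.models.llama3_2_vision.preprocess.model import (
    CLIPImageTransformModel,
    PreprocessConfig,
)

model = CLIPImageTransformModel(PreprocessConfig())
eager = model.get_eager_model()
inputs = model.get_example_inputs()  # (image, target_size, canvas_size)

# Assumes forward(image, target_size, canvas_size), i.e. the same signature
# that export_preprocess.py traces with torch.export.export().
outputs = eager(*inputs)
print(type(outputs))
```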

examples/models/llama3_2_vision/preprocess/export_preprocess_lib.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

examples/models/llama3_2_vision/preprocess/model.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from executorch.extension.llm.custom_ops import op_tile_crop_aot  # noqa
from torch.export import Dim
from torchtune.models.clip.inference._transform import _CLIPImageTransform

from ...model_base import EagerModelBase


@dataclass
class PreprocessConfig:
    image_mean: Optional[List[float]] = None
    image_std: Optional[List[float]] = None
    resample: str = "bilinear"
    max_num_tiles: int = 4
    tile_size: int = 224
    antialias: bool = False


class CLIPImageTransformModel(EagerModelBase):
    def __init__(
        self,
        config: PreprocessConfig,
    ):
        super().__init__()

        # Eager model.
        self.model = _CLIPImageTransform(
            image_mean=config.image_mean,
            image_std=config.image_std,
            resample=config.resample,
            max_num_tiles=config.max_num_tiles,
            tile_size=config.tile_size,
            antialias=config.antialias,
        )

        # Replace non-exportable ops with custom ops.
        self.model.tile_crop = torch.ops.preprocess.tile_crop.default

    def get_eager_model(self) -> torch.nn.Module:
        return self.model

    def get_example_inputs(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        image = torch.ones(3, 800, 600)
        target_size = torch.tensor([448, 336])
        canvas_size = torch.tensor([448, 448])
        return (image, target_size, canvas_size)

    def get_dynamic_shapes(self) -> Dict[str, Dict[int, Dim]]:
        img_h = Dim("img_h", min=1, max=4000)
        img_w = Dim("img_w", min=1, max=4000)

        dynamic_shapes = {
            "image": {1: img_h, 2: img_w},
            "target_size": None,
            "canvas_size": None,
        }
        return dynamic_shapes
```
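
get_dynamic_shapes() marks only the image height and width as dynamic (1 to 4000) and pins target_size and canvas_size as static. As a hedged sketch (not part of this commit), one way to confirm the exported graph honors those dims is to re-run it on an image size other than the 800x600 example input; the 512x512 value below is arbitrary.

```python
# Hypothetical check that the exported program accepts a different image size
# thanks to the dynamic img_h/img_w dims declared above.
import torch

from executorch.examples.models.llama3_2_vision.preprocess.model import (
    CLIPImageTransformModel,
    PreprocessConfig,
)

model = CLIPImageTransformModel(PreprocessConfig())
ep = torch.export.export(
    model.get_eager_model(),
    model.get_example_inputs(),
    dynamic_shapes=model.get_dynamic_shapes(),
    strict=False,
)

# target_size and canvas_size stay static (None in the dynamic_shapes dict),
# so reuse the example values and only swap the image.
_, target_size, canvas_size = model.get_example_inputs()
image = torch.ones(3, 512, 512)  # arbitrary size within the 1..4000 Dim range
out = ep.module()(image, target_size, canvas_size)
```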
