Skip to content

Refactor preprocess to use EagerModelBase #6530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions examples/models/llama3_2_vision/preprocess/export_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,47 @@
# LICENSE file in the root directory of this source tree.

import torch
from executorch.examples.models.llama3_2_vision.preprocess.export_preprocess_lib import (
export_preprocess,
get_example_inputs,
lower_to_executorch_preprocess,
from executorch.examples.models.llama3_2_vision.preprocess.model import (
CLIPImageTransformModel,
PreprocessConfig,
)
from executorch.exir import EdgeCompileConfig, to_edge


def main():
# Eager model.
model = CLIPImageTransformModel(PreprocessConfig())

# ExecuTorch
ep_et = export_preprocess()
et = lower_to_executorch_preprocess(ep_et)
with open("preprocess_et.pte", "wb") as file:
et.write_to_file(file)

# AOTInductor
ep_aoti = export_preprocess()
torch._inductor.aot_compile(
ep_aoti.module(),
get_example_inputs(),
options={"aot_inductor.output_path": "preprocess_aoti.so"},
# Export.
ep = torch.export.export(
model.get_eager_model(),
model.get_example_inputs(),
dynamic_shapes=model.get_dynamic_shapes(),
strict=False,
)

# Executorch
edge_program = to_edge(
ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
)
et_program = edge_program.to_executorch()
with open("preprocess_et.pte", "wb") as file:
et_program.write_to_file(file)

# Export.
# ep = torch.export.export(
# model.get_eager_model(),
# model.get_example_inputs(),
# dynamic_shapes=model.get_dynamic_shapes(),
# strict=False,
# )
#
# # AOTInductor
# torch._inductor.aot_compile(
# ep.module(),
# model.get_example_inputs(),
# options={"aot_inductor.output_path": "preprocess_aoti.so"},
# )
Comment on lines +35 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why commented out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#6553

Aoti still requires pytorch/pytorch#137063, Richard said will try to land eow.



if __name__ == "__main__":
Expand Down

This file was deleted.

69 changes: 69 additions & 0 deletions examples/models/llama3_2_vision/preprocess/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from executorch.extension.llm.custom_ops import op_tile_crop_aot # noqa
from torch.export import Dim
from torchtune.models.clip.inference._transform import _CLIPImageTransform

from ...model_base import EagerModelBase


@dataclass
class PreprocessConfig:
image_mean: Optional[List[float]] = None
image_std: Optional[List[float]] = None
resample: str = "bilinear"
max_num_tiles: int = 4
tile_size: int = 224
antialias: bool = False


class CLIPImageTransformModel(EagerModelBase):
def __init__(
self,
config: PreprocessConfig,
):
super().__init__()

# Eager model.
self.model = _CLIPImageTransform(
image_mean=config.image_mean,
image_std=config.image_std,
resample=config.resample,
max_num_tiles=config.max_num_tiles,
tile_size=config.tile_size,
antialias=config.antialias,
)

# Replace non-exportable ops with custom ops.
self.model.tile_crop = torch.ops.preprocess.tile_crop.default

def get_eager_model(self) -> torch.nn.Module:
return self.model

def get_example_inputs(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
image = torch.ones(3, 800, 600)
target_size = torch.tensor([448, 336])
canvas_size = torch.tensor([448, 448])
return (image, target_size, canvas_size)

def get_dynamic_shapes(self) -> Dict[str, Dict[int, Dim]]:
img_h = Dim("img_h", min=1, max=4000)
img_w = Dim("img_w", min=1, max=4000)

dynamic_shapes = {
"image": {1: img_h, 2: img_w},
"target_size": None,
"canvas_size": None,
}
return dynamic_shapes
Loading