Skip to content

Commit 6fd92ab

Browse files
committed
Add preprocess to ci
ghstack-source-id: 4441650 Pull Request resolved: #6544
1 parent e8d6118 commit 6fd92ab

File tree

3 files changed

+155
-114
lines changed

3 files changed

+155
-114
lines changed

.github/workflows/pull.yml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,34 @@ jobs:
231231
# run e2e (export, tokenizer and runner)
232232
PYTHON_EXECUTABLE=python bash .ci/scripts/test_llava.sh
233233
234+
test-preprocess-linux:
235+
name: test-preprocess-linux
236+
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
237+
strategy:
238+
fail-fast: false
239+
with:
240+
runner: linux.24xlarge
241+
docker-image: executorch-ubuntu-22.04-clang12
242+
submodules: 'true'
243+
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
244+
timeout: 90
245+
script: |
246+
# The generic Linux job chooses to use base env, not the one setup by the image
247+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
248+
conda activate "${CONDA_ENV}"
249+
250+
PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh "cmake"
251+
252+
# install pybind
253+
bash install_requirements.sh --pybind xnnpack
254+
255+
# install preprocess requirements
256+
bash examples/models/llama3_2_vision/install_requirements.sh
257+
258+
# run python unittest
259+
python -m unittest examples.models.llama3_2_vision.preprocess.test_preprocess
260+
261+
234262
test-quantized-aot-lib-linux:
235263
name: test-quantized-aot-lib-linux
236264
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main

examples/models/llama3_2_vision/preprocess/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ class PreprocessConfig:
2626
max_num_tiles: int = 4
2727
tile_size: int = 224
2828
antialias: bool = False
29+
# Used for reference eager model from torchtune.
30+
resize_to_max_canvas: bool = False
31+
possible_resolutions: Optional[List[Tuple[int, int]]] = None
2932

3033

3134
class CLIPImageTransformModel(EagerModelBase):

examples/models/llama3_2_vision/preprocess/test_preprocess.py

Lines changed: 124 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,32 @@
33
#
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
6-
76
import unittest
87

9-
from dataclasses import dataclass
10-
from typing import List, Optional, Tuple
8+
from typing import Any, Dict, List, Tuple
119

1210
import numpy as np
1311
import PIL
1412
import torch
1513

14+
# Import these first. Otherwise, the custom ops are not registered.
1615
from executorch.extension.pybindings import portable_lib # noqa # usort: skip
17-
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
18-
from executorch.examples.models.llama3_2_vision.preprocess.export_preprocess_lib import (
19-
export_preprocess,
20-
get_example_inputs,
21-
lower_to_executorch_preprocess,
16+
from executorch.extension.llm.custom_ops import op_tile_crop_aot # noqa # usort: skip
17+
18+
from executorch.examples.models.llama3_2_vision.preprocess.model import (
19+
CLIPImageTransformModel,
20+
PreprocessConfig,
2221
)
22+
23+
from executorch.exir import EdgeCompileConfig, to_edge
24+
2325
from executorch.extension.pybindings.portable_lib import (
2426
_load_for_executorch_from_buffer,
2527
)
2628

27-
from parameterized import parameterized
2829
from PIL import Image
2930

30-
from torchtune.models.clip.inference._transform import (
31-
_CLIPImageTransform,
32-
CLIPImageTransform,
33-
)
31+
from torchtune.models.clip.inference._transform import CLIPImageTransform
3432

3533
from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import (
3634
find_supported_resolutions,
@@ -43,18 +41,6 @@
4341
from torchvision.transforms.v2 import functional as F
4442

4543

46-
@dataclass
47-
class PreprocessConfig:
48-
image_mean: Optional[List[float]] = None
49-
image_std: Optional[List[float]] = None
50-
resize_to_max_canvas: bool = True
51-
resample: str = "bilinear"
52-
antialias: bool = False
53-
tile_size: int = 224
54-
max_num_tiles: int = 4
55-
possible_resolutions = None
56-
57-
5844
class TestImageTransform(unittest.TestCase):
5945
"""
6046
This unittest checks that the exported image transform model produces the
@@ -66,6 +52,58 @@ class TestImageTransform(unittest.TestCase):
6652
https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L26
6753
"""
6854

55+
@staticmethod
56+
def initialize_models(resize_to_max_canvas: bool) -> Dict[str, Any]:
57+
config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)
58+
59+
reference_model = CLIPImageTransform(
60+
image_mean=config.image_mean,
61+
image_std=config.image_std,
62+
resample=config.resample,
63+
antialias=config.antialias,
64+
tile_size=config.tile_size,
65+
max_num_tiles=config.max_num_tiles,
66+
resize_to_max_canvas=config.resize_to_max_canvas,
67+
possible_resolutions=None,
68+
)
69+
70+
model = CLIPImageTransformModel(config)
71+
72+
exported_model = torch.export.export(
73+
model.get_eager_model(),
74+
model.get_example_inputs(),
75+
dynamic_shapes=model.get_dynamic_shapes(),
76+
strict=False,
77+
)
78+
79+
# aoti_path = torch._inductor.aot_compile(
80+
# exported_model.module(),
81+
# model.get_example_inputs(),
82+
# )
83+
84+
edge_program = to_edge(
85+
exported_model, compile_config=EdgeCompileConfig(_check_ir_validity=False)
86+
)
87+
executorch_model = edge_program.to_executorch()
88+
89+
return {
90+
"config": config,
91+
"reference_model": reference_model,
92+
"model": model,
93+
"exported_model": exported_model,
94+
# "aoti_path": aoti_path,
95+
"executorch_model": executorch_model,
96+
}
97+
98+
@classmethod
99+
def setUpClass(cls):
100+
cls.models_no_resize = TestImageTransform.initialize_models(
101+
resize_to_max_canvas=False
102+
)
103+
cls.models_resize = TestImageTransform.initialize_models(
104+
resize_to_max_canvas=True
105+
)
106+
69107
def setUp(self):
70108
np.random.seed(0)
71109

@@ -121,51 +159,7 @@ def prepare_inputs(
121159

122160
return image_tensor, inscribed_size, best_resolution
123161

124-
# This test setup mirrors the one in torchtune:
125-
# https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
126-
# The values are slightly different, as torchtune uses antialias=True,
127-
# and this test uses antialias=False, which is exportable (has a portable kernel).
128-
@parameterized.expand(
129-
[
130-
(
131-
(100, 400, 3), # image_size
132-
torch.Size([2, 3, 224, 224]), # expected shape
133-
False, # resize_to_max_canvas
134-
[0.2230, 0.1763], # expected_tile_means
135-
[1.0, 1.0], # expected_tile_max
136-
[0.0, 0.0], # expected_tile_min
137-
[1, 2], # expected_aspect_ratio
138-
),
139-
(
140-
(1000, 300, 3), # image_size
141-
torch.Size([4, 3, 224, 224]), # expected shape
142-
True, # resize_to_max_canvas
143-
[0.5005, 0.4992, 0.5004, 0.1651], # expected_tile_means
144-
[0.9976, 0.9940, 0.9936, 0.9906], # expected_tile_max
145-
[0.0037, 0.0047, 0.0039, 0.0], # expected_tile_min
146-
[4, 1], # expected_aspect_ratio
147-
),
148-
(
149-
(200, 200, 3), # image_size
150-
torch.Size([4, 3, 224, 224]), # expected shape
151-
True, # resize_to_max_canvas
152-
[0.5012, 0.5020, 0.5010, 0.4991], # expected_tile_means
153-
[0.9921, 0.9925, 0.9969, 0.9908], # expected_tile_max
154-
[0.0056, 0.0069, 0.0059, 0.0032], # expected_tile_min
155-
[2, 2], # expected_aspect_ratio
156-
),
157-
(
158-
(600, 200, 3), # image_size
159-
torch.Size([3, 3, 224, 224]), # expected shape
160-
False, # resize_to_max_canvas
161-
[0.4472, 0.4468, 0.3031], # expected_tile_means
162-
[1.0, 1.0, 1.0], # expected_tile_max
163-
[0.0, 0.0, 0.0], # expected_tile_min
164-
[3, 1], # expected_aspect_ratio
165-
),
166-
]
167-
)
168-
def test_preprocess(
162+
def run_preprocess(
169163
self,
170164
image_size: Tuple[int],
171165
expected_shape: torch.Size,
@@ -175,45 +169,7 @@ def test_preprocess(
175169
expected_tile_min: List[float],
176170
expected_ar: List[int],
177171
) -> None:
178-
config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)
179-
180-
reference_model = CLIPImageTransform(
181-
image_mean=config.image_mean,
182-
image_std=config.image_std,
183-
resize_to_max_canvas=config.resize_to_max_canvas,
184-
resample=config.resample,
185-
antialias=config.antialias,
186-
tile_size=config.tile_size,
187-
max_num_tiles=config.max_num_tiles,
188-
possible_resolutions=None,
189-
)
190-
191-
eager_model = _CLIPImageTransform(
192-
image_mean=config.image_mean,
193-
image_std=config.image_std,
194-
resample=config.resample,
195-
antialias=config.antialias,
196-
tile_size=config.tile_size,
197-
max_num_tiles=config.max_num_tiles,
198-
)
199-
200-
exported_model = export_preprocess(
201-
image_mean=config.image_mean,
202-
image_std=config.image_std,
203-
resample=config.resample,
204-
antialias=config.antialias,
205-
tile_size=config.tile_size,
206-
max_num_tiles=config.max_num_tiles,
207-
)
208-
209-
executorch_model = lower_to_executorch_preprocess(exported_model)
210-
executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
211-
212-
aoti_path = torch._inductor.aot_compile(
213-
exported_model.module(),
214-
get_example_inputs(),
215-
)
216-
172+
models = self.models_resize if resize_to_max_canvas else self.models_no_resize
217173
# Prepare image input.
218174
image = (
219175
np.random.randint(0, 256, np.prod(image_size))
@@ -223,6 +179,7 @@ def test_preprocess(
223179
image = PIL.Image.fromarray(image)
224180

225181
# Run reference model.
182+
reference_model = models["reference_model"]
226183
reference_output = reference_model(image=image)
227184
reference_image = reference_output["image"]
228185
reference_ar = reference_output["aspect_ratio"].tolist()
@@ -249,10 +206,11 @@ def test_preprocess(
249206
# Pre-work for eager and exported models. The reference model performs these
250207
# calculations and passes the result to _CLIPImageTransform, the exportable model.
251208
image_tensor, inscribed_size, best_resolution = self.prepare_inputs(
252-
image=image, config=config
209+
image=image, config=models["config"]
253210
)
254211

255212
# Run eager model and check it matches reference model.
213+
eager_model = models["model"].get_eager_model()
256214
eager_image, eager_ar = eager_model(
257215
image_tensor, inscribed_size, best_resolution
258216
)
@@ -261,6 +219,7 @@ def test_preprocess(
261219
self.assertEqual(reference_ar, eager_ar)
262220

263221
# Run exported model and check it matches reference model.
222+
exported_model = models["exported_model"]
264223
exported_image, exported_ar = exported_model.module()(
265224
image_tensor, inscribed_size, best_resolution
266225
)
@@ -269,14 +228,65 @@ def test_preprocess(
269228
self.assertEqual(reference_ar, exported_ar)
270229

271230
# Run executorch model and check it matches reference model.
231+
executorch_model = models["executorch_model"]
232+
executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
272233
et_image, et_ar = executorch_module.forward(
273234
(image_tensor, inscribed_size, best_resolution)
274235
)
275236
self.assertTrue(torch.allclose(reference_image, et_image))
276237
self.assertEqual(reference_ar, et_ar.tolist())
277238

278239
# Run aoti model and check it matches reference model.
279-
aoti_model = torch._export.aot_load(aoti_path, "cpu")
280-
aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
281-
self.assertTrue(torch.allclose(reference_image, aoti_image))
282-
self.assertEqual(reference_ar, aoti_ar.tolist())
240+
# aoti_path = models["aoti_path"]
241+
# aoti_model = torch._export.aot_load(aoti_path, "cpu")
242+
# aoti_image, aoti_ar = aoti_model(image_tensor, inscribed_size, best_resolution)
243+
# self.assertTrue(torch.allclose(reference_image, aoti_image))
244+
# self.assertEqual(reference_ar, aoti_ar.tolist())
245+
246+
# This test setup mirrors the one in torchtune:
247+
# https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
248+
# The values are slightly different, as torchtune uses antialias=True,
249+
# and this test uses antialias=False, which is exportable (has a portable kernel).
250+
def test_preprocess1(self):
251+
self.run_preprocess(
252+
(100, 400, 3), # image_size
253+
torch.Size([2, 3, 224, 224]), # expected shape
254+
False, # resize_to_max_canvas
255+
[0.2230, 0.1763], # expected_tile_means
256+
[1.0, 1.0], # expected_tile_max
257+
[0.0, 0.0], # expected_tile_min
258+
[1, 2], # expected_aspect_ratio
259+
)
260+
261+
def test_preprocess2(self):
262+
self.run_preprocess(
263+
(1000, 300, 3), # image_size
264+
torch.Size([4, 3, 224, 224]), # expected shape
265+
True, # resize_to_max_canvas
266+
[0.5005, 0.4992, 0.5004, 0.1651], # expected_tile_means
267+
[0.9976, 0.9940, 0.9936, 0.9906], # expected_tile_max
268+
[0.0037, 0.0047, 0.0039, 0.0], # expected_tile_min
269+
[4, 1], # expected_aspect_ratio
270+
)
271+
272+
def test_preprocess3(self):
273+
self.run_preprocess(
274+
(200, 200, 3), # image_size
275+
torch.Size([4, 3, 224, 224]), # expected shape
276+
True, # resize_to_max_canvas
277+
[0.5012, 0.5020, 0.5010, 0.4991], # expected_tile_means
278+
[0.9921, 0.9925, 0.9969, 0.9908], # expected_tile_max
279+
[0.0056, 0.0069, 0.0059, 0.0032], # expected_tile_min
280+
[2, 2], # expected_aspect_ratio
281+
)
282+
283+
def test_preprocess4(self):
284+
self.run_preprocess(
285+
(600, 200, 3), # image_size
286+
torch.Size([3, 3, 224, 224]), # expected shape
287+
False, # resize_to_max_canvas
288+
[0.4472, 0.4468, 0.3031], # expected_tile_means
289+
[1.0, 1.0, 1.0], # expected_tile_max
290+
[0.0, 0.0, 0.0], # expected_tile_min
291+
[3, 1], # expected_aspect_ratio
292+
)

0 commit comments

Comments
 (0)