Commit 1cea0ee

preprocess e2e test
Differential Revision: D61780713
Pull Request resolved: #4887
Parent: 37db39a

4 files changed, +50 −12 lines

examples/models/flamingo/export_preprocess_lib.py

Lines changed: 6 additions & 1 deletion
@@ -8,6 +8,7 @@
 
 import torch
 from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
+from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.program._program import ExecutorchProgramManager
 
 from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa
@@ -76,5 +77,9 @@ def lower_to_executorch_preprocess(
         exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False)
     )
 
-    et_program = edge_program.to_executorch(ExecutorchBackendConfig())
+    et_program = edge_program.to_executorch(
+        ExecutorchBackendConfig(
+            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
+        )
+    )
     return et_program
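
Note: the sym_shape_eval_pass switch matters because tile_crop (see preprocess_custom_ops.py below) now produces an upper-bounded dynamic output shape; ConstraintBasedSymShapeEvalPass uses those bounds to size output buffers instead of requiring fully static shapes. A minimal usage sketch of this lowering path, not part of the commit — it assumes export_preprocess is callable with defaults and uses an illustrative output path:

    from examples.models.flamingo.export_preprocess_lib import (
        export_preprocess,
        lower_to_executorch_preprocess,
    )

    exported = export_preprocess()  # assumption: defaults suffice
    et_program = lower_to_executorch_preprocess(exported)

    # Serialize the program; .buffer is the same attribute the new e2e
    # test loads from. "preprocess.pte" is an illustrative filename.
    with open("preprocess.pte", "wb") as f:
        f.write(et_program.buffer)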

examples/models/flamingo/test_preprocess.py

Lines changed: 32 additions & 8 deletions
@@ -13,6 +13,12 @@
 import PIL
 import torch
 
+from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
+from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
+
 from parameterized import parameterized
 from PIL import Image
 
@@ -21,14 +27,17 @@
     CLIPImageTransform,
 )
 
-from torchtune.modules.transforms import (
+from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import (
     find_supported_resolutions,
     get_canvas_best_fit,
+)
+
+from torchtune.modules.transforms.vision_utils.get_inscribed_size import (
     get_inscribed_size,
 )
 from torchvision.transforms.v2 import functional as F
 
-from .export_preprocess_lib import export_preprocess
+from .export_preprocess_lib import export_preprocess, lower_to_executorch_preprocess
 
 
 @dataclass
@@ -74,6 +83,13 @@ def prepare_inputs(
             F.grayscale_to_rgb_image(F.to_image(image)), scale=True
         )
 
+        # The above converts the PIL image into a torchvision tv_tensor.
+        # Convert the tv_tensor into a torch.Tensor.
+        image_tensor = image_tensor + 0
+
+        # Ensure tensor is contiguous for executorch.
+        image_tensor = image_tensor.contiguous()
+
         # Calculate possible resolutions.
         possible_resolutions = config.possible_resolutions
         if possible_resolutions is None:
@@ -187,6 +203,9 @@ def test_preprocess(
             max_num_tiles=config.max_num_tiles,
         )
 
+        executorch_model = lower_to_executorch_preprocess(exported_model)
+        executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
+
         # Prepare image input.
         image = (
             np.random.randint(0, 256, np.prod(image_size))
@@ -225,20 +244,25 @@ def test_preprocess(
             image=image, config=config
         )
 
-        # Run eager and exported models.
+        # Run eager model and check it matches reference model.
        eager_image, eager_ar = eager_model(
             image_tensor, inscribed_size, best_resolution
         )
         eager_ar = eager_ar.tolist()
+        self.assertTrue(torch.allclose(reference_image, eager_image))
+        self.assertEqual(reference_ar, eager_ar)
 
+        # Run exported model and check it matches reference model.
         exported_image, exported_ar = exported_model.module()(
             image_tensor, inscribed_size, best_resolution
         )
         exported_ar = exported_ar.tolist()
-
-        # Check eager and exported models match reference model.
-        self.assertTrue(torch.allclose(reference_image, eager_image))
         self.assertTrue(torch.allclose(reference_image, exported_image))
+        self.assertEqual(reference_ar, exported_ar)
 
-        self.assertTrue(reference_ar, eager_ar)
-        self.assertTrue(reference_ar, exported_ar)
+        # Run executorch model and check it matches reference model.
+        et_image, et_ar = executorch_module.forward(
+            (image_tensor, inscribed_size, best_resolution)
+        )
+        self.assertTrue(torch.allclose(reference_image, et_image))
+        self.assertEqual(reference_ar, et_ar.tolist())
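
Note: the new end-to-end path in this test, condensed from the hunks above (variable names are the test's own):

    from executorch.extension.pybindings.portable_lib import (
        _load_for_executorch_from_buffer,
    )

    # Lower the exported program and load the serialized buffer in-process.
    executorch_model = lower_to_executorch_preprocess(exported_model)
    executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)

    # Unlike the eager and exported models, which take positional arguments,
    # the loaded ExecuTorch module takes a single tuple of inputs.
    et_image, et_ar = executorch_module.forward(
        (image_tensor, inscribed_size, best_resolution)
    )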

extension/llm/custom_ops/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT)
   add_library(
     custom_ops_aot_lib SHARED ${_custom_ops__srcs}
     ${CMAKE_CURRENT_SOURCE_DIR}/op_sdpa_aot.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/op_tile_crop.cpp
   )
   target_include_directories(
     custom_ops_aot_lib PUBLIC "${_common_include_directories}"
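
Note: adding op_tile_crop.cpp to custom_ops_aot_lib means the AOT shared library now carries the tile_crop kernel alongside sdpa. A hedged sketch of loading it ahead of export; the path and extension (.so/.dylib) depend on the build directory and platform:

    import torch

    # Illustrative path; adjust to your CMake output directory.
    torch.ops.load_library(
        "cmake-out/extension/llm/custom_ops/libcustom_ops_aot_lib.so"
    )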

extension/llm/custom_ops/preprocess_custom_ops.py

Lines changed: 11 additions & 3 deletions
@@ -16,6 +16,9 @@
 # Register and define tile_crop and out variant.
 preprocess_op_lib.define("tile_crop(Tensor input, int tile_size) -> Tensor")
 
+# Keep this in sync with model config.
+MAX_NUM_TILES = 4
+
 
 @impl(preprocess_op_lib, "tile_crop", dispatch_key="CompositeExplicitAutograd")
 def tile_crop_impl(input: torch.Tensor, tile_size: int) -> torch.Tensor:
@@ -56,6 +59,11 @@ def tile_crop_out_impl(
 # Register meta kernel to prevent export tracing into the tile_crop impl.
 @torch.library.register_fake("preprocess::tile_crop")
 def tile_crop(output: torch.Tensor, tile_size: int) -> torch.Tensor:
-    # Returned tensor is of size [n, 3, 224, 224], where n is the number of tiles.
-    # We should export with n = max_num_tiles. Set 50 for now.
-    return torch.empty([50, output.size(0), 224, 224])
+    # Returned tensor is of size [n, 3, 224, 224], where n = number of tiles.
+    # Use an unbacked symint to create an upper-bounded dynamic shape output.
+    # Otherwise, output is set to a static shape, and we can only output
+    # tensors of shape [MAX_NUM_TILES, 3, 224, 224].
+    ctx = torch._custom_ops.get_ctx()
+    s0 = ctx.create_unbacked_symint()
+    torch._constrain_as_size(s0, 0, MAX_NUM_TILES)
+    return torch.empty([s0, output.size(0), tile_size, tile_size])
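
Note: the same bounded-unbacked-symint pattern in isolation, for reference. Everything here (the demo:: namespace, bounded_nonzero, MAX_ROWS) is hypothetical; it only illustrates the technique the hunk above applies to tile_crop:

    import torch

    # Hypothetical custom op whose output row count is data-dependent.
    demo_lib = torch.library.Library("demo", "DEF")
    demo_lib.define("bounded_nonzero(Tensor x) -> Tensor")

    MAX_ROWS = 4  # upper bound, analogous to MAX_NUM_TILES

    @torch.library.impl(demo_lib, "bounded_nonzero", "CompositeExplicitAutograd")
    def bounded_nonzero_impl(x: torch.Tensor) -> torch.Tensor:
        return torch.nonzero(x)

    @torch.library.register_fake("demo::bounded_nonzero")
    def bounded_nonzero_fake(x: torch.Tensor) -> torch.Tensor:
        # At trace time the data is unavailable, so allocate an unbacked
        # symint for the row count and give it an upper bound that shape
        # evaluation passes can use to size output buffers.
        ctx = torch._custom_ops.get_ctx()
        n = ctx.create_unbacked_symint()
        torch._constrain_as_size(n, 0, MAX_ROWS)
        return torch.empty([n, x.dim()], dtype=torch.long)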
