13 | 13 |  import PIL
14 | 14 |  import torch
15 | 15 |
   | 16 | +from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
   | 17 | +from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
   | 18 | +from executorch.extension.pybindings.portable_lib import (
   | 19 | +    _load_for_executorch_from_buffer,
   | 20 | +)
   | 21 | +
16 | 22 |  from parameterized import parameterized
17 | 23 |  from PIL import Image
18 | 24 |
21 | 27 |      CLIPImageTransform,
22 | 28 |  )
23 | 29 |
24 |    | -from torchtune.modules.transforms import (
   | 30 | +from torchtune.modules.transforms.vision_utils.get_canvas_best_fit import (
25 | 31 |      find_supported_resolutions,
26 | 32 |      get_canvas_best_fit,
   | 33 | +)
   | 34 | +
   | 35 | +from torchtune.modules.transforms.vision_utils.get_inscribed_size import (
27 | 36 |      get_inscribed_size,
28 | 37 |  )
29 | 38 |  from torchvision.transforms.v2 import functional as F
30 | 39 |
31 |    | -from .export_preprocess_lib import export_preprocess
   | 40 | +from .export_preprocess_lib import export_preprocess, lower_to_executorch_preprocess
32 | 41 |
33 | 42 |
34 | 43 |  @dataclass
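The first two new imports are side-effect imports (hence the `# noqa` markers): pulling in `portable_lib` loads the pybindings runtime, and `sdpa_with_kv_cache` appears to be there to register the custom SDPA kernel so programs referencing it can be loaded. A minimal sketch of the load-and-run flow these imports enable; `pte_buffer` and `run_program` are illustrative names, not part of this diff:

```python
# Minimal sketch: load a serialized ExecuTorch program and run it in-process.
import torch

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)


def run_program(pte_buffer: bytes, *inputs: torch.Tensor):
    module = _load_for_executorch_from_buffer(pte_buffer)
    # Pybindings modules take all inputs as one tuple and return a list of outputs.
    return module.forward(tuple(inputs))
```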
@@ -74,6 +83,13 @@ def prepare_inputs(
74 | 83 |              F.grayscale_to_rgb_image(F.to_image(image)), scale=True
75 | 84 |          )
76 | 85 |
   | 86 | +        # The above converts the PIL image into a torchvision tv_tensor.
   | 87 | +        # Convert the tv_tensor into a torch.Tensor.
   | 88 | +        image_tensor = image_tensor + 0
   | 89 | +
   | 90 | +        # Ensure tensor is contiguous for executorch.
   | 91 | +        image_tensor = image_tensor.contiguous()
   | 92 | +
77 | 93 |          # Calculate possible resolutions.
78 | 94 |          possible_resolutions = config.possible_resolutions
79 | 95 |          if possible_resolutions is None:
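The `image_tensor + 0` line leans on torchvision v2 behavior worth spelling out: `F.to_image` returns a `tv_tensors.Image`, a `torch.Tensor` subclass, and by default pure torch operations on tv_tensors return plain `torch.Tensor`s, so adding zero is a cheap way to unwrap the subclass. A standalone sketch of that behavior (not part of the diff):

```python
# Demonstrates the tv_tensor unwrapping the diff relies on (torchvision v2).
from PIL import Image
from torchvision.transforms.v2 import functional as F

pil_image = Image.new("RGB", (8, 8))
tv_image = F.to_image(pil_image)  # tv_tensors.Image, a torch.Tensor subclass
plain = tv_image + 0              # plain torch op -> plain torch.Tensor

print(type(tv_image).__name__)    # Image
print(type(plain).__name__)       # Tensor
```

An arguably more explicit alternative would be `tv_image.as_subclass(torch.Tensor)`, which unwraps without an arithmetic op.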
@@ -187,6 +203,9 @@ def test_preprocess(
187 | 203 |              max_num_tiles=config.max_num_tiles,
188 | 204 |          )
189 | 205 |
    | 206 | +        executorch_model = lower_to_executorch_preprocess(exported_model)
    | 207 | +        executorch_module = _load_for_executorch_from_buffer(executorch_model.buffer)
    | 208 | +
190 | 209 |          # Prepare image input.
191 | 210 |          image = (
192 | 211 |              np.random.randint(0, 256, np.prod(image_size))
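`lower_to_executorch_preprocess` comes from `export_preprocess_lib`, which this diff does not show. Under the standard exir flow, a helper of this shape plausibly looks like the sketch below; the real implementation may configure passes or backends differently:

```python
# Hedged sketch of a lowering helper in the spirit of
# lower_to_executorch_preprocess; the actual implementation is not in this
# diff. Assumes `exported_program` came from torch.export.
from executorch.exir import to_edge


def lower(exported_program):
    edge = to_edge(exported_program)  # ATen dialect -> edge dialect
    # Serialize to an ExecuTorch program; the flatbuffer bytes consumed by
    # _load_for_executorch_from_buffer live on its `.buffer` attribute.
    return edge.to_executorch()
```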
@@ -225,20 +244,25 @@ def test_preprocess(
225 | 244 |              image=image, config=config
226 | 245 |          )
227 | 246 |
228 |     | -        # Run eager and exported models.
    | 247 | +        # Run eager model and check it matches reference model.
229 | 248 |          eager_image, eager_ar = eager_model(
230 | 249 |              image_tensor, inscribed_size, best_resolution
231 | 250 |          )
232 | 251 |          eager_ar = eager_ar.tolist()
    | 252 | +        self.assertTrue(torch.allclose(reference_image, eager_image))
    | 253 | +        self.assertEqual(reference_ar, eager_ar)
233 | 254 |
    | 255 | +        # Run exported model and check it matches reference model.
234 | 256 |          exported_image, exported_ar = exported_model.module()(
235 | 257 |              image_tensor, inscribed_size, best_resolution
236 | 258 |          )
237 | 259 |          exported_ar = exported_ar.tolist()
238 |     | -
239 |     | -        # Check eager and exported models match reference model.
240 |     | -        self.assertTrue(torch.allclose(reference_image, eager_image))
241 | 260 |          self.assertTrue(torch.allclose(reference_image, exported_image))
    | 261 | +        self.assertEqual(reference_ar, exported_ar)
242 | 262 |
243 |     | -        self.assertTrue(reference_ar, eager_ar)
244 |     | -        self.assertTrue(reference_ar, exported_ar)
    | 263 | +        # Run executorch model and check it matches reference model.
    | 264 | +        et_image, et_ar = executorch_module.forward(
    | 265 | +            (image_tensor, inscribed_size, best_resolution)
    | 266 | +        )
    | 267 | +        self.assertTrue(torch.allclose(reference_image, et_image))
    | 268 | +        self.assertEqual(reference_ar, et_ar.tolist())
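Two details in the rewritten assertions are easy to miss. First, the removed `self.assertTrue(reference_ar, eager_ar)` calls were no-ops as checks: `assertTrue` treats its second argument as the failure message, so they passed whenever `reference_ar` was truthy; `assertEqual` performs the intended comparison. Second, the executorch module is invoked with a single tuple of inputs (hence the doubled parentheses), and `et_ar` comes back as a tensor, so it is converted with `.tolist()` before being compared against the plain-list `reference_ar`.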