Skip to content

Commit ac75602

Browse files
committed
run preprocess pte file
1 parent 7a59904 commit ac75602

File tree

3 files changed

+34
-19
lines changed

3 files changed

+34
-19
lines changed

examples/models/flamingo/export_preprocess_lib.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
11+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
1112
from executorch.exir.program._program import ExecutorchProgramManager
1213

1314
from executorch.extension.llm.custom_ops import preprocess_custom_ops # noqa
@@ -76,5 +77,9 @@ def lower_to_executorch_preprocess(
7677
exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False)
7778
)
7879

79-
et_program = edge_program.to_executorch(ExecutorchBackendConfig())
80+
et_program = edge_program.to_executorch(
81+
ExecutorchBackendConfig(
82+
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
83+
)
84+
)
8085
return et_program
Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1+
import torch
2+
from executorch.extension.pybindings.portable_lib import _load_for_executorch
3+
from executorch.extension.pybindings import portable_lib
4+
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache
5+
model = _load_for_executorch("preprocess.pte")
16

7+
image = torch.ones(3, 800, 600)
8+
target_size = torch.tensor([448, 336])
9+
canvas_size = torch.tensor([448, 448])
210

3-
class NativePreprocessRunner():
4-
"""
5-
Runs preprocess via ExecuTorch with provided pte file.
6-
"""
7-
def __init__(self, pte_file):
8-
super().__init__()
9-
self.model = _load_for_executorch(pte_file)
10-
11-
def forward(
12-
self,
13-
image: torch.Tensor,
14-
target_size: torch.Tensor,
15-
canvas_size: torch.Tensor,
16-
):
17-
return self.model.forward((image, target_size, canvas_size))
11+
res = model.forward((image, target_size, canvas_size))
12+
print(res[0].shape)
13+
14+
image2 = torch.ones(3, 200, 200)
15+
target_size2 = torch.tensor([224, 224])
16+
canvas_size2 = torch.tensor([224, 224])
17+
res2 = model.forward((image2, target_size2, canvas_size2))
18+
print(res2[0].shape)

extension/llm/custom_ops/preprocess_custom_ops.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515

1616
# Register and define tile_crop and out variant.
1717
preprocess_op_lib.define("tile_crop(Tensor input, int tile_size) -> Tensor")
18+
from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size, ShapeEnv
19+
20+
# Keep this in sync with model config.
21+
MAX_NUM_TILES = 4
1822

1923

2024
@impl(preprocess_op_lib, "tile_crop", dispatch_key="CompositeExplicitAutograd")
@@ -56,6 +60,11 @@ def tile_crop_out_impl(
5660
# Register meta kernel to prevent export tracing into the tile_crop impl.
5761
@torch.library.register_fake("preprocess::tile_crop")
5862
def tile_crop(output: torch.Tensor, tile_size: int) -> torch.Tensor:
59-
# Returned tensor is of size [n, 3, 224, 224], where n is the number of tiles.
60-
# We should export with n = max_num_tiles. Set 50 for now.
61-
return torch.empty([50, output.size(0), 224, 224])
63+
# Returned tensor is of size [n, 3, 224, 224], where n = number of tiles.
64+
# Use an unbacked symint to create an upper-bounded dynamic shape output.
65+
# Otherwise, output is set to a static shape, and we can only output
66+
# tensors of shape [MAX_NUM_TILES, 3, 224, 224].
67+
shape_env = ShapeEnv()
68+
s0 = shape_env.create_unbacked_symint()
69+
_constrain_range_for_size(s0, 0, MAX_NUM_TILES)
70+
return torch.empty([s0, output.size(0), tile_size, tile_size])

0 commit comments

Comments
 (0)