Skip to content

Commit 672f4b4

Browse files
committed
Turn on python dispatcher for EdgeOpArgValidator (#3809)
Summary: Pull Request resolved: #3809 Possibly fixes #3659 We need to enable the python dispatcher so that expand_copy and view_copy will go through the correct meta kernels Reviewed By: larryliu0820 Differential Revision: D58091304 fbshipit-source-id: f8907ee130720b01c629d55f222eb5a7e63a34bd (cherry picked from commit ab6f177)
1 parent 50d1da2 commit 672f4b4

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

exir/program/test/test_program.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from executorch.exir.lowered_backend_module import get_lowered_submodules
1818
from executorch.exir.pass_base import ExportPass
1919
from executorch.exir.program._program import (
20+
EdgeCompileConfig,
2021
EdgeProgramManager,
2122
ExecutorchProgramManager,
2223
to_edge,
@@ -26,7 +27,7 @@
2627
from executorch.extension.pybindings.portable_lib import (
2728
_load_for_executorch_from_buffer,
2829
)
29-
from torch.export import export, ExportedProgram
30+
from torch.export import Dim, export, ExportedProgram
3031

3132
from torch.library import impl, Library
3233

@@ -225,6 +226,38 @@ def test_edge_manager_transform(self):
225226
original_res, # x * y + x
226227
)
227228

229+
def test_issue_3659(self):
230+
231+
class Mul(torch.nn.Module):
232+
def __init__(self):
233+
super(Mul, self).__init__()
234+
235+
def forward(self, x: torch.Tensor, y: torch.Tensor):
236+
return torch.matmul(x, y)
237+
238+
def get_eager_model(self) -> torch.nn.Module:
239+
return self
240+
241+
def get_example_inputs(self):
242+
return (torch.randn(1, 3, 10), torch.randn(1, 10, 3))
243+
244+
def get_dynamic_shapes(self):
245+
dim1_x = Dim("Dot_dim1_x", min=2, max=100)
246+
dim2_x = Dim("Dot_dim2_x", min=2, max=100)
247+
return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}}
248+
249+
model = Mul()
250+
ep = torch.export.export(
251+
model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
252+
)
253+
254+
to_edge(
255+
ep,
256+
compile_config=EdgeCompileConfig(
257+
_check_ir_validity=True,
258+
),
259+
)
260+
228261
def test_transform_dict_api(self):
229262
edge_manager = to_edge(get_exported_programs(), get_config_methods())
230263

exir/verification/verifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
EdgeOpArgValidator,
1818
RunHigherOrderOperatorError,
1919
)
20+
from torch._dispatch.python import enable_python_dispatcher
2021

2122
from torch._export.verifier import SpecViolationError, Verifier
2223
from torch._ops import OpOverload
@@ -119,7 +120,8 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
119120
validator = EdgeOpArgValidator(gm)
120121
inputs = _get_inputs(gm)
121122
try:
122-
validator.run(*inputs)
123+
with enable_python_dispatcher():
124+
validator.run(*inputs)
123125
except RunHigherOrderOperatorError:
124126
# NB: ignore higher order operator in the graph.
125127
# If we lower a graph module to delegate and then compose it with some other graph module, retrace it,

0 commit comments

Comments
 (0)