Skip to content

Commit 6e788c7

Browse files
pytorchbotangelayi
andauthored
Add fake mode in verifier (#6132)
Add fake mode in verifier (#5805) Summary: Pull Request resolved: #5805 Hopefully fixes https://fb.workplace.com/groups/pytorch.edge.users/permalink/1605630670307220/ Reviewed By: larryliu0820 Differential Revision: D63734251 fbshipit-source-id: 854750227b64125a3609245c6cfcbff26b71f26a (cherry picked from commit c10c96a) Co-authored-by: Angela Yi <[email protected]>
1 parent b077801 commit 6e788c7

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

exir/tests/test_tracer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,17 @@ def f(x: torch.Tensor) -> torch.Tensor:
100100
any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
101101
)
102102

103+
def test_ones(self) -> None:
104+
class M(torch.nn.Module):
105+
def forward(self, x):
106+
y = torch.ones(x.shape[0])
107+
return x + y
108+
109+
ep = torch.export.export(
110+
M(), (torch.ones(3),), dynamic_shapes={"x": {0: torch.export.Dim("x")}}
111+
)
112+
exir.to_edge(ep)
113+
103114
def test_possible_input_mutation(self) -> None:
104115
def f(x: torch.Tensor) -> torch.Tensor:
105116
return torch.add(torch.ones(5), torch.ones(5), out=x)

exir/verification/verifier.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import itertools
88
import operator
99
import types
10+
from contextlib import nullcontext
1011
from typing import Any, List, Optional, Tuple, Type
1112

1213
import torch
@@ -19,6 +20,7 @@
1920
RunHigherOrderOperatorError,
2021
)
2122
from torch._dispatch.python import enable_python_dispatcher
23+
from torch._export.utils import _detect_fake_mode_from_gm
2224

2325
from torch._export.verifier import SpecViolationError, Verifier
2426
from torch._ops import OpOverload
@@ -161,8 +163,9 @@ def extract_input(node: torch.fx.Node) -> Optional[FakeTensor]:
161163
def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
162164
validator = EdgeOpArgValidator(gm)
163165
inputs = _get_inputs(gm)
166+
fake_mode = _detect_fake_mode_from_gm(gm) or nullcontext()
164167
try:
165-
with enable_python_dispatcher():
168+
with enable_python_dispatcher(), fake_mode:
166169
validator.run(*inputs)
167170
except RunHigherOrderOperatorError:
168171
# NB: ignore higher order operator in the graph.

0 commit comments

Comments
 (0)