Skip to content

Commit a04a87f

Browse files
[et][dim order] Make edge verifier support empty operator
Differential Revision: [D66801315](https://our.internmc.facebook.com/intern/diff/D66801315/) ghstack-source-id: 256642462 Pull Request resolved: #7188 --------- Co-authored-by: gasoonjia <[email protected]>
1 parent b886504 commit a04a87f

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

exir/verification/test/test_verifier.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,9 @@ def __init__(self):
117117

118118
def forward(self, x: torch.Tensor) -> torch.Tensor:
119119
t1 = x.to(dtype=torch.double, memory_format=torch.channels_last)
120-
t2 = t1 + t1
121-
return t1 * t2
120+
t2 = torch.empty(t1.size(), memory_format=torch.channels_last)
121+
t2.copy_(t1)
122+
return t2
122123

123124
m = Model().eval()
124125

exir/verification/verifier.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1616
from executorch.exir.error import ExportError, ExportErrorType
1717
from executorch.exir.lowered_backend_module import LoweredBackendModule
18+
from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
1819
from executorch.exir.verification.arg_validator import (
1920
EdgeOpArgValidator,
2021
RunHigherOrderOperatorError,
2122
)
23+
2224
from torch._dispatch.python import enable_python_dispatcher
2325
from torch._export.utils import _detect_fake_mode_from_gm
2426

@@ -44,7 +46,7 @@ def _check_tensors_are_contiguous(gm: GraphModule) -> None:
4446

4547
def _check_valid_dim_order_ops(op, use_dim_order) -> None:
4648
if use_dim_order:
47-
if op in (torch.ops.aten._to_copy.default,):
49+
if op in DimOrderOpsMap:
4850
raise SpecViolationError(f"{op} should not be used in dim_order mode")
4951
else: # not using dim_order
5052
if op.namespace in ("dim_order_ops",):
@@ -249,7 +251,7 @@ def check_valid_edge_op(self, op):
249251
)
250252
)
251253
if isinstance(op, EdgeOpOverload):
252-
_check_valid_dim_order_ops(op._op, self.use_dim_order)
254+
_check_valid_dim_order_ops(op, self.use_dim_order)
253255
self.check_valid_aten_op(op._op)
254256

255257
if isinstance(op, types.FunctionType):

0 commit comments

Comments
 (0)