Skip to content

Commit d863214

Browse files
larryliu0820facebook-github-bot
authored andcommitted
Do not error out when op is not in ATen namespace (#4229)
Summary: if `_check_ir_validity=True` we shouldn't error out if an operator is a custom op. Pull Request resolved: #4229 Reviewed By: JacobSzwejbka Differential Revision: D59652257 Pulled By: larryliu0820 fbshipit-source-id: a52739f4082756f741241b9ac6a31f6bf6967ecb
1 parent f9efb05 commit d863214

File tree

2 files changed

+7
-15
lines changed

2 files changed

+7
-15
lines changed

exir/program/test/test_program.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -537,16 +537,18 @@ def _test_edge_dialect_verifier(self, callable, validate_ir=True):
537537
_ = to_edge(exported_foo, compile_config=edge_compile_config)
538538

539539
def test_edge_dialect_custom_op(self):
540+
# We shouldn't error out if there's a custom op in the graph.
540541
def _use_foo_add(a: torch.Tensor, b: torch.Tensor):
541542
return torch.ops.exir_program_test_op.foo(a, b)
542543

543544
from torch._export.verifier import SpecViolationError
544545

545-
with self.assertRaises(SpecViolationError):
546+
try:
547+
# This should not raise error
546548
self._test_edge_dialect_verifier(_use_foo_add)
547-
548-
# This should not raise error
549-
self._test_edge_dialect_verifier(_use_foo_add, False)
549+
self._test_edge_dialect_verifier(_use_foo_add, False)
550+
except SpecViolationError:
551+
self.fail("Should not error out on custom op")
550552

551553
def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module):
552554
# This is the pre-dispatch export that we will be switching to primarily

exir/verification/verifier.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,7 @@ def _get_exception_list(self) -> List[torch._ops.OpOverload]:
9898
def check_valid_op(self, op):
9999
if isinstance(op, OpOverload):
100100
# TODO These special ops should be removable easily.
101-
if (
102-
op.namespace
103-
in [
104-
"quantized_decomposed",
105-
"boltnn_nimble",
106-
"nimble",
107-
"quantized",
108-
"dim_order_ops",
109-
]
110-
or op in self._get_exception_list()
111-
):
101+
if op.namespace != "aten" or op in self._get_exception_list():
112102
return
113103
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
114104
# NOTE(qihan): whether view_copy operators are marked as canonical is still under

0 commit comments

Comments
 (0)