Verifier for exported program #292

Closed · wants to merge 1 commit
44 changes: 9 additions & 35 deletions exir/program/test/test_program.py
@@ -6,21 +6,16 @@

 # pyre-strict

-import copy
 import unittest
-from typing import Any, Callable, Dict
+from typing import Any, Dict

 import torch
 from executorch.exir import ExecutorchBackendConfig
 from executorch.exir.backend.test.op_partitioner_demo import AddMulPartitionerDemo
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.error import ExportError
 from executorch.exir.lowered_backend_module import get_lowered_submodules
-from executorch.exir.pass_base import ExportPass, PassResult
-from executorch.exir.passes.replace_aten_with_edge_pass import (
-    aten_to_edge,
-    should_lower_to_edge,
-)
+from executorch.exir.pass_base import ExportPass
 from executorch.exir.program._program import (
     EdgeProgramManager,
     ExecutorchProgramManager,
@@ -31,9 +26,7 @@
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
-from torch import fx
 from torch.export import export, ExportedProgram
-from torch.fx import GraphModule, subgraph_rewriter


 def get_exported_programs() -> Dict[str, ExportedProgram]:
@@ -70,32 +63,13 @@ def bar():


 class AddToMulPassEdge(ExportPass):
-    def call(self, graph_module: GraphModule) -> PassResult:
-        """
-        Dummy pass that replaces add with mul
-        """
-
-        def _trace_and_lower_to_edge_ops(f: Callable) -> fx.GraphModule:
-            gm = fx.symbolic_trace(f)
-            for node in gm.graph.nodes:
-                if node.op == "call_function" and should_lower_to_edge(node.target):
-                    node.target = aten_to_edge(node.target)
-            gm.recompile()
-            return gm
-
-        def pattern(x: torch.Tensor, y: torch.Tensor):
-            return torch.ops.aten.add.Tensor(x, y)
-
-        def replacement(x: torch.Tensor, y: torch.Tensor):
-            return torch.ops.aten.mul.Tensor(x, y)
-
-        new_graph_module = copy.deepcopy(graph_module)
-        subgraph_rewriter.replace_pattern_with_filters(
-            new_graph_module,
-            _trace_and_lower_to_edge_ops(pattern),
-            _trace_and_lower_to_edge_ops(replacement),
-        )
-        return PassResult(new_graph_module, True)
+    def call_operator(self, op, args, kwargs, meta):
+        if op == exir_ops.edge.aten.add.Tensor:
+            return super().call_operator(
+                exir_ops.edge.aten.mul.Tensor, args, kwargs, meta
+            )
+        else:
+            return super().call_operator(op, args, kwargs, meta)


 class TestProgramManagers(unittest.TestCase):
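Note: the diff above swaps the old graph-rewriting implementation of AddToMulPassEdge (subgraph_rewriter plus a tracing helper) for ExportPass's call_operator hook, which the pass infrastructure invokes once per operator node. A minimal usage sketch, assuming an edge-dialect graph module edge_gm is in scope (not part of this diff):

    result = AddToMulPassEdge()(edge_gm)  # an ExportPass is callable and returns a PassResult
    transformed = result.graph_module     # same graph, with edge add ops rewritten to mul

A single call_operator override removes the deep copy, the pattern/replacement tracing, and the recompile that the deleted call() body performed.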
2 changes: 2 additions & 0 deletions exir/tests/test_verification.py
@@ -145,6 +145,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         self.assertEqual(len(res), len(res_interp))
         self.assertTrue(torch.allclose(res, res_interp))
+
+
 class TestEdgeVerification(unittest.TestCase):
     def test_edge_happy(self) -> None:
         class TestModel(torch.nn.Module):
             def __init__(self):
2 changes: 1 addition & 1 deletion exir/verification/TARGETS
@@ -51,8 +51,8 @@ python_library(
     ],
     deps = [
         "//caffe2:torch",
-        "//executorch/exir:delegate",
         "//executorch/exir:error",
+        "//executorch/exir:lowered_backend_module",
         "//executorch/exir/dialects/edge:lib",
         "//executorch/exir/emit:emit",
     ],
86 changes: 37 additions & 49 deletions exir/verification/verifier.py
@@ -4,20 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import List, Optional
+import itertools
+import operator
+from typing import Any, List, Optional, Tuple, Type

 import torch
-from executorch.exir.delegate import executorch_call_delegate
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.error import ExportError, ExportErrorType
+from executorch.exir.lowered_backend_module import LoweredBackendModule
 from executorch.exir.verification.arg_validator import (
     EdgeOpArgValidator,
     RunHigherOrderOperatorError,
 )

 from torch._export.verifier import (
     _check_has_fake_tensor,
-    _check_tensors_are_contiguous,
     ATenDialectVerifier,
     SpecViolationError,
     Verifier,
@@ -28,30 +29,21 @@


 ALLOWED_META_KEYS = {"spec", "stack_trace"}
-VALID_BUILTIN_FUNCS = [
-    executorch_call_delegate,
-]


-class EXIRATenDialectVerifier(ATenDialectVerifier):
-    def valid_builtin_funcs(self):
-        builtin_funcs = super().valid_builtin_funcs()
-        builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-        return builtin_funcs
-
-    # TODO(angelayi): Delete this function when we migrate all tests to
-    # because right now old tracer does not add ["val"] metadata
-    def check_valid(self, gm: GraphModule) -> None:  # noqa: C901
-
-        for node in gm.graph.nodes:
-            if node.op in {"call_module", "call_method"}:
-                raise SpecViolationError(
-                    "call_module is not valid: got a class '{}' ".format(node.target),
-                )
-
-            if node.op == "call_function":
-                if node.target not in self.valid_builtin_funcs():
-                    self.check_valid_op(node.target)
+def _check_tensors_are_contiguous(gm: GraphModule) -> None:
+    # Tensors be of contiguous format
+    for name, param in itertools.chain(gm.named_parameters(), gm.named_buffers()):
+        if isinstance(param, torch.Tensor):
+            if not param.is_contiguous():
+                raise SpecViolationError(
+                    f"Tensors in Aten dialect must be contiguous, {name} is not contiguous"
+                )
+
+
+class EXIRATenDialectVerifier(ATenDialectVerifier):
+    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
+        return (torch.fx.GraphModule, LoweredBackendModule, torch.Tensor)


 def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]:
@@ -97,15 +89,21 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:


 class EXIREdgeDialectVerifier(Verifier):
-    def __init__(self, check_edge_ops: bool = False) -> None:
+    def __init__(self, check_edge_ops: bool = True) -> None:
         self.check_edge_ops = check_edge_ops

-    def valid_builtin_funcs(self):
-        builtin_funcs = super().valid_builtin_funcs()
-        builtin_funcs.extend(VALID_BUILTIN_FUNCS)
-        return builtin_funcs
+        if self.check_edge_ops:
+            self.check_valid_op = self.check_valid_edge_op
+        else:
+            self.check_valid_op = self.check_valid_aten_op
+
+    def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
+        return (torch.fx.GraphModule, LoweredBackendModule, torch.Tensor)

     def check_valid_edge_op(self, op):
+        if op in [operator.getitem]:
+            return
+
         if isinstance(op, OpOverload) and not isinstance(op, EdgeOpOverload):
             raise SpecViolationError(
                 "Operator {}.{} is not an Edge operator.".format(
@@ -116,33 +114,23 @@ def check_valid_edge_op(self, op):
     def check_valid_aten_op(self, op) -> None:
         super().check_valid_op(op)

-        op_name = op.name if hasattr(op, "name") else op.__name__
-
-        if not isinstance(op, OpOverload):
-            raise SpecViolationError(
-                "Operator '{}' is not a registered Op".format(op_name),
-            )
-
-        if (
-            torch.Tag.core not in op.tags  # type: ignore[attr-defined]
-            and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
-        ):
-            # NOTE(qihan): whether view_copy operators are marked as canonical is still under
-            # discussion.
-            raise SpecViolationError(
-                "Operator {}.{} is not Aten Canonical.".format(
-                    op.__module__, op.__name__
-                )
-            )
+        if isinstance(op, OpOverload):
+            if (
+                torch.Tag.core not in op.tags  # type: ignore[attr-defined]
+                and torch.Tag.view_copy not in op.tags  # type: ignore[attr-defined]
+            ):
+                # NOTE(qihan): whether view_copy operators are marked as canonical is still under
+                # discussion.
+                raise SpecViolationError(
+                    "Operator {}.{} is not Aten Canonical.".format(
+                        op.__module__, op.__name__
+                    )
+                )

-    def check_valid(self, gm: GraphModule) -> None:
+    def check_additional(self, gm: GraphModule) -> None:
         if self.check_edge_ops:
-            self.check_valid_op = self.check_valid_edge_op
-            super().check_valid(gm)
             _check_tensors_are_contiguous(gm)
-        else:
-            self.check_valid_op = self.check_valid_aten_op
+            _check_tensor_args_matching_op_allowed_dtype(gm)

         # Additionally, edge dialect's operator must have same input dtype
         for n in gm.graph.nodes:
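With these changes, generic graph validation lives in the base torch._export.verifier.Verifier, and check_additional carries only the EXIR-specific checks (tensor contiguity and argument dtypes). A minimal sketch of running the verifier, assuming an edge-dialect graph module edge_gm is in scope (check_valid is the base-class entry point, as the replaced override above indicates):

    verifier = EXIREdgeDialectVerifier()  # check_edge_ops now defaults to True
    verifier.check_valid(edge_gm)         # raises SpecViolationError on a non-edge op,
                                          # non-contiguous tensor, or disallowed arg dtype

Passing check_edge_ops=False instead binds check_valid_op to check_valid_aten_op in __init__, so the same class can check ATen-dialect graphs for core/view_copy canonicality.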