Skip to content

Commit f37bb47

Browse files
zhxchen17facebook-github-bot
authored andcommitted
Improve verifier to not specialize on dialect. (#1528)
Summary: X-link: pytorch/pytorch#116705 Pull Request resolved: #1528 Currently we have a very ugly specialization on edge dialect in verifier like the following: ``` # TODO Remove this branch. if ep.dialect == "EDGE": # !!! Don't change this allowlist. !!! pass else: raise e ``` In this diff we do some additional work to make signature checking also work in exir. We decouple the transformation stack in torch export and exir so that different layers of the stack can evolve in their own fashion and the team can divide and conquer them seperately. Reviewed By: angelayi Differential Revision: D52499225 fbshipit-source-id: d667397b8434d8f6b0d0d097d31a0dd589c329cc
1 parent 03093c2 commit f37bb47

File tree

4 files changed

+110
-14
lines changed

4 files changed

+110
-14
lines changed

backends/xnnpack/passes/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from executorch.exir.pass_base import ExportPass
2626

2727
from executorch.exir.passes.const_prop_pass import ConstPropPass
28+
29+
from executorch.exir.program._program import _transform
2830
from torch._export.pass_base import PassType
2931

3032
from torch.export import ExportedProgram
@@ -77,5 +79,5 @@ def transform(self) -> ExportedProgram:
7779
raise RuntimeError(
7880
f"Expecting ExportPass or ExportPass(), but got pass: {pass_} with type: {type(pass_)}"
7981
)
80-
ep = ep._transform(transform_pass)
82+
ep = _transform(ep, transform_pass)
8183
return ep

exir/capture/_capture.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from executorch.exir.capture._config import CaptureConfig
1515
from executorch.exir.error import ExportError, ExportErrorType, InternalError
1616
from executorch.exir.program import ExirExportedProgram, MultiMethodExirExportedProgram
17-
from executorch.exir.program._program import HackedUpExportedProgramDONOTUSE
17+
from executorch.exir.program._program import _transform, HackedUpExportedProgramDONOTUSE
1818
from executorch.exir.tracer import (
1919
_default_decomposition_table,
2020
dispatch_trace,
@@ -170,7 +170,7 @@ def capture( # noqa: C901
170170

171171
ep = export(f, args, constraints=constraints)
172172
ep = ep.run_decompositions(_default_decomposition_table())
173-
ep = ep._transform(ReplaceViewOpsWithViewCopyOpsPass())
173+
ep = _transform(ep, ReplaceViewOpsWithViewCopyOpsPass())
174174
if not config._unlift:
175175
return ExirExportedProgram(ep, False)
176176
graph_module = ep.module()

exir/lowered_backend_module.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ def buffer(
122122
# TODO(chenlai): re-consider recapture instead of manually constructing the program because
123123
# the meta data construction is done manually.
124124
def program(self, emit_stacktrace: bool = False) -> Program:
125+
from executorch.exir.program._program import (
126+
_get_updated_graph_signature,
127+
_transform,
128+
)
129+
125130
"""
126131
Returns the object that represents the ExecuTorch binary before serialization.
127132
"""
@@ -257,8 +262,11 @@ def program(self, emit_stacktrace: bool = False) -> Program:
257262
exported_program = ExportedProgram(
258263
root=lowered_exported_program.graph_module,
259264
graph=lowered_exported_program.graph,
260-
graph_signature=ExportGraphSignature(
261-
input_specs=input_specs, output_specs=output_specs
265+
graph_signature=_get_updated_graph_signature(
266+
ExportGraphSignature(
267+
input_specs=input_specs, output_specs=output_specs
268+
),
269+
lowered_exported_program.graph_module,
262270
),
263271
# TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None)
264272
# somewhere as we should pass it a list of tensors to the lowered module and output a
@@ -271,8 +279,8 @@ def program(self, emit_stacktrace: bool = False) -> Program:
271279
example_inputs=None,
272280
verifier=lowered_exported_program.verifier,
273281
)
274-
exported_program = exported_program._transform(
275-
SpecPropPass(), MemoryPlanningPass("greedy")
282+
exported_program = _transform(
283+
exported_program, SpecPropPass(), MemoryPlanningPass("greedy")
276284
)
277285
emitted_program = emit_program(
278286
exported_program, emit_stacktrace=emit_stacktrace

exir/program/_program.py

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,98 @@
3636
)
3737
from torch._export import ExportedProgram
3838
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
39-
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
39+
from torch.export.exported_program import (
40+
_get_updated_range_constraints,
41+
ConstantArgument,
42+
ExportGraphSignature,
43+
InputKind,
44+
InputSpec,
45+
OutputSpec,
46+
TensorArgument,
47+
)
4048
from torch.fx import _pytree as fx_pytree
4149
from torch.fx._compatibility import compatibility
50+
from torch.fx.passes.infra.pass_manager import PassManager
4251
from torch.utils import _pytree as pytree
4352

4453
Val = Any
4554

4655

56+
def _get_updated_graph_signature(
57+
old_signature: ExportGraphSignature,
58+
new_gm: torch.fx.GraphModule,
59+
) -> ExportGraphSignature:
60+
"""
61+
Update the graph signature's user_input/user_outputs.
62+
"""
63+
new_input_specs = []
64+
for i, node in enumerate(new_gm.graph.nodes):
65+
if node.op != "placeholder":
66+
break
67+
68+
assert i < len(
69+
old_signature.input_specs
70+
), "Number of inputs changed after transformation"
71+
old_input_spec = old_signature.input_specs[i]
72+
arg = (
73+
old_input_spec.arg
74+
if isinstance(old_input_spec.arg, ConstantArgument)
75+
else type(old_input_spec.arg)(node.name)
76+
)
77+
new_input_specs.append(
78+
InputSpec(old_input_spec.kind, arg, old_input_spec.target)
79+
)
80+
81+
output_node = list(new_gm.graph.nodes)[-1]
82+
assert output_node.op == "output"
83+
84+
new_output_specs = []
85+
for i, node in enumerate(output_node.args[0]):
86+
assert i < len(
87+
old_signature.output_specs
88+
), "Number of outputs changed after transformation"
89+
old_output_spec = old_signature.output_specs[i]
90+
arg = (
91+
old_output_spec.arg
92+
if isinstance(old_output_spec.arg, ConstantArgument)
93+
else type(old_output_spec.arg)(node.name)
94+
)
95+
new_output_specs.append(
96+
OutputSpec(old_output_spec.kind, arg, old_output_spec.target)
97+
)
98+
99+
new_signature = ExportGraphSignature(
100+
input_specs=new_input_specs, output_specs=new_output_specs
101+
)
102+
return new_signature
103+
104+
105+
def _transform(self, *passes: PassType) -> "ExportedProgram":
106+
pm = PassManager(list(passes))
107+
res = pm(self.graph_module)
108+
transformed_gm = res.graph_module if res is not None else self.graph_module
109+
assert transformed_gm is not None
110+
111+
if transformed_gm is self.graph_module and not res.modified:
112+
return self
113+
114+
transformed_ep = ExportedProgram(
115+
transformed_gm,
116+
transformed_gm.graph,
117+
_get_updated_graph_signature(self.graph_signature, transformed_gm),
118+
self.state_dict,
119+
_get_updated_range_constraints(transformed_gm),
120+
copy.deepcopy(self.equality_constraints),
121+
copy.deepcopy(self._module_call_graph),
122+
self.example_inputs,
123+
self.verifier,
124+
self.tensor_constants,
125+
)
126+
transformed_ep.graph_module.meta.update(self.graph_module.meta)
127+
transformed_ep.graph_module.meta.update(res.graph_module.meta)
128+
return transformed_ep
129+
130+
47131
def _copy_module(new_prog, new_gm):
48132
new_prog.meta.update(new_gm.meta)
49133
new_prog.graph = new_gm.graph
@@ -231,7 +315,7 @@ def __init__(
231315
self.after_to_edge_passes = after_to_edge_passes
232316

233317
def transform(self, *passes: PassType) -> "ExirExportedProgram":
234-
self.exported_program = self.exported_program._transform(*passes)
318+
self.exported_program = _transform(self.exported_program, *passes)
235319
return self
236320

237321
def __call__(self, *args: Any) -> Any:
@@ -419,7 +503,7 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
419503
new_ep.exported_program = ExportedProgram(
420504
new_gm,
421505
new_gm.graph,
422-
new_ep.exported_program.graph_signature,
506+
_get_updated_graph_signature(new_ep.exported_program.graph_signature, new_gm),
423507
new_ep.exported_program.state_dict,
424508
new_ep.exported_program.range_constraints,
425509
new_ep.exported_program.equality_constraints,
@@ -755,7 +839,9 @@ def to_edge(
755839
edge_program = ExportedProgram(
756840
root=gm,
757841
graph=gm.graph,
758-
graph_signature=edge_program.graph_signature,
842+
graph_signature=_get_updated_graph_signature(
843+
edge_program.graph_signature, gm
844+
),
759845
state_dict=edge_program.state_dict,
760846
range_constraints=edge_program.range_constraints,
761847
equality_constraints=edge_program.equality_constraints,
@@ -770,7 +856,7 @@ def to_edge(
770856
)
771857
passes = []
772858
passes.extend(aten_to_edge_passes.passes[-2:])
773-
edge_program = edge_program._transform(*passes)
859+
edge_program = _transform(edge_program, *passes)
774860
edge_programs[name] = edge_program
775861
return EdgeProgramManager(edge_programs, constant_methods, config)
776862

@@ -856,7 +942,7 @@ def transform(
856942
if isinstance(passes, dict):
857943
for name, program in self._edge_programs.items():
858944
if name in passes.keys():
859-
new_programs[name] = program._transform(*passes[name])
945+
new_programs[name] = _transform(program, *passes[name])
860946
EXIREdgeDialectVerifier(enable=check_ir_validity)(
861947
new_programs[name].graph_module
862948
)
@@ -865,7 +951,7 @@ def transform(
865951

866952
else: # apply passes to every method
867953
for name, program in self._edge_programs.items():
868-
new_programs[name] = program._transform(*passes)
954+
new_programs[name] = _transform(program, *passes)
869955
EXIREdgeDialectVerifier(enable=check_ir_validity)(
870956
new_programs[name].graph_module
871957
)

0 commit comments

Comments
 (0)