
Commit 0d0d77e

angelayi authored and facebook-github-bot committed
Remove ExirMetadata
Summary: exir_meta.in_spec/out_spec now live on the top-level exported_program.call_spec.in_spec/out_spec.

Reviewed By: ydwu4
Differential Revision: D47392208
fbshipit-source-id: d510b7d45565138dac08ad8463b6680083f8b95b
1 parent 20ea601 commit 0d0d77e

18 files changed (+33, -623 lines)
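
For code that previously read the specs off graph-module metadata, the migration this commit implies is roughly the following (a hedged sketch, not code from the commit; module and example_inputs are illustrative placeholders):

import executorch.exir as exir

# Before (sketch): specs lived on the ExirMetadata attached to the graph module.
#   meta = get_exir_meta(prog.graph_module)
#   in_spec, out_spec = meta.in_spec, meta.out_spec

# After: specs live on the top-level exported program's CallSpec,
# per the commit summary above.
prog = exir.capture(module, example_inputs)  # module / example_inputs are placeholders
in_spec = prog.call_spec.in_spec    # pytree TreeSpec for the flattened inputs
out_spec = prog.call_spec.out_spec  # pytree TreeSpec for the flattened outputs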

backends/backend_api.py

Lines changed: 1 addition & 11 deletions

@@ -23,11 +23,7 @@
     get_lowered_module_name,
     LoweredBackendModule,
 )
-from executorch.exir.graph_module import (
-    attach_export_graph_metadata,
-    ExirMetadata,
-    get_control_flow_submodules,
-)
+from executorch.exir.graph_module import get_control_flow_submodules
 from executorch.exir.pass_base import ExportPass
 from torch._export.exported_program import ExportedProgram
 
@@ -166,12 +162,6 @@ def _partition_and_lower(
             {},
             [],
         )
-        meta = ExirMetadata(
-            in_spec=None,
-            out_spec=None,
-            update_spec=0,
-        )
-        attach_export_graph_metadata(submodule_program.graph_module, meta)
 
         lowered_submodule = to_backend(
             delegation_spec.backend_id,
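
With the dummy metadata attachment gone, a partitioned submodule program now flows straight into the backend lowering call. A hedged sketch of the remaining call shape (field names follow the diff; the compile-specs argument is an assumption about the surrounding code):

# No ExirMetadata is attached to submodule_program.graph_module anymore;
# the partitioned program is handed to to_backend as-is.
lowered_submodule = to_backend(
    delegation_spec.backend_id,     # backend identifier from the partitioner
    submodule_program,              # the partitioned exported program
    delegation_spec.compile_specs,  # assumed: backend-specific CompileSpec list
)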

backends/xnnpack/test/TARGETS

Lines changed: 0 additions & 3 deletions

@@ -23,7 +23,6 @@ python_unittest(
         "//executorch/bundled_program:config",
         "//executorch/bundled_program:core",
         "//executorch/bundled_program/serialize:lib",
-        "//executorch/exir:graph_module",
         "//executorch/exir:lib",
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/serialize:lib",
@@ -54,7 +53,6 @@ python_unittest(
         "//executorch/bundled_program:config",
         "//executorch/bundled_program:core",
         "//executorch/bundled_program/serialize:lib",
-        "//executorch/exir:graph_module",
         "//executorch/exir:lib",
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:spec_prop_pass",
@@ -86,7 +84,6 @@ python_unittest(
         "//executorch/bundled_program:config",
         "//executorch/bundled_program:core",
         "//executorch/bundled_program/serialize:lib",
-        "//executorch/exir:graph_module",
         "//executorch/exir:lib",
         "//executorch/exir/passes:spec_prop_pass",
         "//executorch/exir/serialize:lib",

backends/xnnpack/test/test_xnnpack_utils.py

Lines changed: 1 addition & 4 deletions

@@ -27,8 +27,6 @@
     serialize_from_bundled_program_to_flatbuffer,
 )
 
-from executorch.exir.graph_module import attach_export_graph_metadata, get_exir_meta
-
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from executorch.exir.serialize import serialize_to_flatbuffer
 
@@ -306,14 +304,13 @@ def quantize_and_test_model_with_quantizer(
         )
         captured_program = exir.capture(module, example_inputs, config=capture_config)
         m = captured_program.graph_module
-        exir_meta = get_exir_meta(m)
 
         quantizer = QNNPackQuantizer()
         quantization_config = get_symmetric_quantization_config()
         quantizer.set_global(quantization_config)
         prepared = prepare_pt2e_quantizer(captured_program.graph_module, quantizer)
         converted = convert_pt2e(prepared)
-        attach_export_graph_metadata(converted, exir_meta)
+
         captured_program.graph_module = converted
         edge_program = captured_program.to_edge(get_xnnpack_edge_compile_config())
         delegated_module = self.lower_module_and_test_output(
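
Taken together, the test flow after this commit reads roughly as follows (a sketch assembled from the hunks above; every name appears in the diff):

captured_program = exir.capture(module, example_inputs, config=capture_config)

quantizer = QNNPackQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e_quantizer(captured_program.graph_module, quantizer)
converted = convert_pt2e(prepared)

# The converted graph module is swapped in directly; there is no longer
# any ExirMetadata to save and restore around quantization.
captured_program.graph_module = converted
edge_program = captured_program.to_edge(get_xnnpack_edge_compile_config())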

exir/TARGETS

Lines changed: 0 additions & 6 deletions

@@ -34,9 +34,6 @@ python_library(
         "graph_module.py",
     ],
     deps = [
-        ":common",
-        ":error",
-        ":graph",
         "//caffe2:torch",
     ],
 )
@@ -129,7 +126,6 @@ python_library(
         ":control_flow",  # @manual
         ":dynamic_shape",
         ":error",
-        ":graph_module",
         ":pass_base",  # @manual
         ":pass_manager",
         ":schema",
@@ -156,7+152,6 @@ python_library(
         ":control_flow",
         ":delegate",
         ":error",
-        ":graph_module",
         ":memory",
         ":schema",
         ":tensor",
@@ -234,7 +229,6 @@ python_library(
     deps = [
         ":delegate",
         ":error",
-        ":graph_module",
         ":memory",
         "//caffe2:torch",
         "//caffe2/functorch:functorch_src",

exir/__init__.py

Lines changed: 12 additions & 113 deletions

@@ -11,13 +11,6 @@
 from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
 from executorch.exir.emit import emit_program, EmitterOutput
 from executorch.exir.error import ExportError, ExportErrorType, InternalError
-from executorch.exir.graph_module import (
-    attach_export_graph_metadata,
-    EXIR_METADATA,
-    get_exir_meta,
-    make_export_graph_module,
-    reduce_graph_module,
-)
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import (
     aten_to_edge_passes,
@@ -119,35 +112,17 @@ class ExecutorchBackendConfig:
 
 
 # TODO(ycao): set up "__all__" to limit symbol exposure
-def _to_edge(expo_prog, config: EdgeCompileConfig) -> "ExirExportedProgram":
-    meta = get_exir_meta(expo_prog.graph_module)
+def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
     if config._check_ir_validity:
         try:
-            EXIRATenDialectVerifier()(expo_prog.graph_module)
+            EXIRATenDialectVerifier()(ep.graph_module)
         except ExportError:
             logging.info(
                 "If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, "
                 "like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))."
             )
             raise
-    mutation = meta.mutation
-    num_mutated = len(mutation) if mutation is not None else 0
-    output_node = next(iter(reversed(expo_prog.graph_module.graph.nodes)))
-    assert output_node.op == "output"
-    output_node.args = (output_node.args[0][num_mutated:],)
-    expo_prog.graph_module.graph.eliminate_dead_code()
 
-    ep = ExirExportedProgram(
-        expo_prog.graph_module,
-        expo_prog.graph_module.graph,
-        ExportGraphSignature([], [], [], [], {}, {}, {}, None),
-        CallSpec(meta.in_spec, meta.out_spec),
-        {},
-        {},
-        [],
-        False,
-    )
-    attach_export_graph_metadata(ep.graph_module, meta)
     op_replace_pass = [OpReplacePass()] if config._use_edge_ops else []
     passes = aten_to_edge_passes.passes + op_replace_pass + config.passes
     new_ep = ep.transform(*passes)
@@ -198,7 +173,6 @@ def transform(self, *passes: PassType) -> "ExirExportedProgram":
             ep.equality_constraints,
             self.after_to_edge_passes,
         )
-
         transformed_ep.graph_module.meta.update(ep.graph_module.meta)
         transformed_ep.graph_module.meta.update(self.graph_module.meta)
         return transformed_ep
@@ -227,24 +201,6 @@ def _to_server(
         # return graph_module now.
         return res.graph_module
 
-    @property
-    def meta(self):
-        return self.graph_module.meta
-
-    @property
-    def in_spec(self):
-        meta = get_exir_meta(self.graph_module)
-        return meta.in_spec
-
-    @property
-    def out_spec(self):
-        meta = get_exir_meta(self.graph_module)
-        return meta.out_spec
-
-    @property
-    def graph(self):
-        return self.graph_module.graph
-
     @property
     def code(self):
         return self.graph_module.code
@@ -257,7 +213,6 @@ def to_executorch(
             raise RuntimeError("Must run to_edge before to_executorch.")
         config = config or ExecutorchBackendConfig()
         new_prog = self.transform(*edge_to_executorch_passes(config))
-        meta = get_exir_meta(new_prog.graph_module)
         executorch_prog = ExecutorchProgram(
             new_prog,
             emit_stacktrace=config.emit_stacktrace,
@@ -266,20 +221,10 @@
             constant_tensor_alignment=config.constant_tensor_alignment,
             delegate_alignment=config.delegate_alignment,
         )
-        attach_export_graph_metadata(executorch_prog.graph_module, meta)
-        # We only need to update the meta of the root graph module since it is
-        # reconstructed in ExecutorchProgram. The submodules are exactly the
-        # original submodules in new_prog.
         executorch_prog.graph_module.meta.update(new_prog.graph_module.meta)
         executorch_prog.graph_module.meta.update(self.graph_module.meta)
         return executorch_prog
 
-    def __reduce__(
-        self,
-    ) -> Tuple[Callable[..., "ExirExportedProgram"], Tuple[bytes]]:
-        _, (pickled_states,) = self.graph_module.__reduce__()
-        return (edge_gm_deserializer, (pickled_states,))
-
     def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "ExportedProgram":
         gm = self.graph_module.__deepcopy__(memo)
         new_ep = ExirExportedProgram(
@@ -292,9 +237,6 @@ def __deepcopy__(self, memo: Optional[Dict[int, Any]] = None) -> "ExportedProgra
             copy.deepcopy(self.equality_constraints),
             self.after_to_edge_passes,
         )
-        attach_export_graph_metadata(
-            new_ep.graph_module, get_exir_meta(self.graph_module)
-        )
         new_ep.graph_module.meta.update(self.graph_module.meta)
         return new_ep
 
@@ -418,27 +360,6 @@ def get_multi_method_graph_module(self) -> "MultiMethodExirExportedProgram":
         return self._executorch_dialect_ir_program
 
 
-def edge_gm_deserializer(pickled_states: bytes) -> "ExirExportedProgram":
-    loaded_gm = reduce_graph_module(pickled_states)
-    # restore node.meta["val"], which is deleted before pickling
-    annotated_gm = aten_to_edge_passes(loaded_gm).graph_module
-    meta = get_exir_meta(annotated_gm)
-
-    ep = ExirExportedProgram(
-        annotated_gm.module,
-        annotated_gm.graph,
-        ExportGraphSignature([], [], [], [], {}, {}, {}, None),
-        CallSpec(meta.in_spec, meta.out_spec),
-        {},
-        {},
-        [],
-        after_to_edge_passes=True,
-    )
-    attach_export_graph_metadata(ep.graph_module, meta)
-    ep.graph_module.meta.update(annotated_gm.graph_module.meta)
-    return ep
-
-
 # This is to bootstrap the missing meta["val"] when 1. ph consists of scalar
 # 2. meta["val"] is not properly set in dispatch_trace.
 def _instantiate_missing_placeholder_val_with_real_inputs(gm, args):
@@ -468,7 +389,6 @@ def capture(
         )
 
     config = config or CaptureConfig()
-    mutation = None
    out_spec = None
     # TODO (zhxchen17) Always functionalize in a second pass no matter which path is taken.
     flat_args = tuple(pytree.tree_flatten(args)[0])
@@ -578,15 +498,8 @@ def convert_to_fake(x):
             tracing_mode=tracing_mode,
             _allow_non_fake_inputs=True,
         )(*args)
-        module = make_export_graph_module(
-            graph_module,
-            graph_module.graph,
-        )
 
-        flatten_output(module)
-        meta = get_exir_meta(module)
-        meta.in_spec = in_spec
-        meta.out_spec = out_spec
+        flatten_output(graph_module)
 
    else:
         warnings.warn(
@@ -604,24 +517,22 @@ def convert_to_fake(x):
             raise InternalError(
                 "Using AOT mode is not supported for leagacy capture mode, please use pt2_mode=True instead."
             )
-        module = dispatch_trace(f, args)
-        # TODO(shunting) move this into ExportModuleState
-        meta = module.meta[EXIR_METADATA]
-        _instantiate_missing_placeholder_val_with_real_inputs(module, flat_args)
-        module._apply(torch.Tensor.contiguous)
-        meta = get_exir_meta(module)
-        meta.mutation = mutation  # pyre-ignore
+        graph_module = dispatch_trace(f, args)
+        in_spec, out_spec = graph_module.in_spec, graph_module.out_spec
+
+        _instantiate_missing_placeholder_val_with_real_inputs(graph_module, flat_args)
+        graph_module._apply(torch.Tensor.contiguous)
+
     ep = ExirExportedProgram(
-        module,
-        module.graph,
+        graph_module,
+        graph_module.graph,
         ExportGraphSignature([], [], [], [], {}, {}, {}, None),
-        CallSpec(meta.in_spec, meta.out_spec),
+        CallSpec(in_spec, out_spec),
         {},
         {},
         [],
         False,
     )
-    attach_export_graph_metadata(ep.graph_module, meta)
     return ep
 
 
@@ -720,18 +631,6 @@ def access_property_of_default_method(self, property_name: str):
         "to access property: {property_name}."""
         return getattr(default_program.graph_module, property_name)
 
-    @property
-    def meta(self):
-        return self.access_property_of_default_method("meta")
-
-    @property
-    def in_spec(self):
-        return self.meta[EXIR_METADATA].in_spec
-
-    @property
-    def out_spec(self):
-        return self.meta[EXIR_METADATA].out_spec
-
     @property
     def graph(self):
         return self.access_property_of_default_method("graph")
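
The in_spec/out_spec now carried by CallSpec are ordinary pytree TreeSpecs, produced by tree_flatten during tracing as the capture hunks above show. A minimal, self-contained illustration of what such a spec encodes (plain torch.utils._pytree, not ExecuTorch-specific):

import torch.utils._pytree as pytree

args = ({"x": 1, "y": (2, 3)},)
flat, in_spec = pytree.tree_flatten(args)            # flat == [1, 2, 3]
assert pytree.tree_unflatten(flat, in_spec) == args  # the spec restores the structure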

exir/delegate.py

Lines changed: 3 additions & 27 deletions

@@ -7,11 +7,7 @@
 import torch.utils._pytree as pytree
 
 from executorch.backends.compile_spec_schema import CompileSpec
-from executorch.exir.graph_module import (
-    _get_submodule,
-    EXIR_METADATA,
-    make_export_graph_module,
-)
+from executorch.exir.graph_module import _get_submodule
 from executorch.exir.tracer import Value
 from torch._export.exported_program import ExportedProgram
 from torch._functorch.eager_transforms import (
@@ -281,22 +277,6 @@ def _fixup_output_node(gm: torch.fx.GraphModule) -> None:
             return
 
 
-def generate_in_spec_out_spec(gm: torch.fx.GraphModule) -> None:
-    output_nodes = []
-    for node in gm.graph.nodes:
-        if node.op == "output":
-            output_nodes = node.args[0]
-
-    all_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"]
-
-    (_, pytree_in) = tree_flatten(tuple(all_placeholders))
-    (_, pytree_out) = tree_flatten(tuple(output_nodes))
-
-    meta = gm.meta[EXIR_METADATA]
-    meta.in_spec = pytree_in
-    meta.out_spec = pytree_out
-
-
 def create_submodule_from_nodes(
     gm: torch.fx.GraphModule,
     node_list: NodeList,
@@ -319,14 +299,11 @@ def create_submodule_from_nodes(
     sorted_nodes = topo_sort(node_list)
 
     submodule_name = "fused_" + tag
-    initial_sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
+    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
         gm, sorted_nodes, submodule_name
     )
 
-    _fixup_output_node(initial_sub_gm)
-    sub_gm = make_export_graph_module(
-        initial_sub_gm, initial_sub_gm.graph, submodule_name
-    )
+    _fixup_output_node(sub_gm)
 
     gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
     if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor):
@@ -365,7 +342,6 @@ def create_submodule_from_nodes(
         submodule_node is not None
     ), f"No submodule was created with the nodes {node_list} in the graph {gm.graph}"
 
-    generate_in_spec_out_spec(sub_gm)
     return sub_gm, submodule_node
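Because the fused submodule now stays a plain torch.fx.GraphModule, the per-submodule spec generation removed above is unnecessary. A hedged sketch of the simplified fusion path using the upstream torch.fx fuser utilities named in the diff (the wrapper function itself is illustrative):

import torch
from torch.fx.passes.utils.fuser_utils import fuse_as_graphmodule, insert_subgm

def fuse_nodes(gm: torch.fx.GraphModule, sorted_nodes, tag: str):
    # Fuse the topologically sorted nodes into a plain fx.GraphModule ...
    sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(
        gm, sorted_nodes, "fused_" + tag
    )
    # ... and splice it back into the parent graph. No export wrapper or
    # ExirMetadata is attached to sub_gm after this commit.
    gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs)
    return gm, sub_gm
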
0 commit comments
