Skip to content

Commit 3a379e9

Browse files
tugsbayasgalanfacebook-github-bot
authored andcommitted
Add unlifting pass under private config (#4)
Summary: X-link: pytorch/pytorch#104897 Pull Request resolved: #4 We wanna do this little by little. For now, I tried only on DissectedPartsModel which needs to use aot_export version. Reviewed By: JacobSzwejbka Differential Revision: D46785735 fbshipit-source-id: a89cc5090e558ba050cb63b7ffe97b1a95bc8820
1 parent cf124e0 commit 3a379e9

File tree

6 files changed

+359
-47
lines changed

6 files changed

+359
-47
lines changed

backends/test/test_backends.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -618,11 +618,7 @@ def forward(self, x_raw, h, c):
618618
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
619619

620620
program_without_delegates = (
621-
exir.capture(
622-
composite_m,
623-
(input_x, input_h, input_c),
624-
exir.CaptureConfig(pt2_mode=True),
625-
)
621+
exir.capture(CompositeModel(3), inputs)
626622
.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
627623
.to_executorch(
628624
config=exir.ExecutorchBackendConfig(extract_segments=extract_segments),
@@ -726,7 +722,7 @@ def forward(self, x_raw, h, c):
726722

727723
program_without_delegates = (
728724
exir.capture(
729-
composite_m,
725+
CompositeModel(3),
730726
(input_x, input_h, input_c),
731727
exir.CaptureConfig(pt2_mode=True),
732728
)
@@ -962,7 +958,8 @@ def test_quantized_with_delegate(self) -> None:
962958
example_inputs,
963959
exir.CaptureConfig(
964960
pt2_mode=True,
965-
enable_functionalization=False,
961+
enable_aot=True,
962+
_unlift=True,
966963
),
967964
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
968965
FileCheck().check_count("quantize_per_tensor.default", 3).check("addmm").run(

exir/__init__.py

Lines changed: 158 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from collections import namedtuple
66
from dataclasses import dataclass, field
77
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
8+
from unittest.mock import patch
89

910
import sympy
1011
import torch
12+
import torch._export
1113
from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode
1214
from executorch.exir.emit import emit_program, EmitterOutput
1315
from executorch.exir.error import ExportError, ExportErrorType, InternalError
@@ -25,6 +27,7 @@
2527
from executorch.exir.schema import Program
2628
from executorch.exir.serialize import serialize_to_flatbuffer
2729
from executorch.exir.tracer import (
30+
_default_decomposition_table,
2831
dispatch_trace,
2932
dynamo_trace,
3033
ExirDynamoConfig,
@@ -41,6 +44,7 @@
4144
from torch._dynamo.eval_frame import Constraint
4245
from torch._export import CallSpec, export, ExportGraphSignature
4346
from torch._export.exported_program import ExportedProgram
47+
from torch._export.passes import ReplaceViewOpsWithViewCopyOpsPass
4448
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
4549
InputDim,
4650
RangeConstraint,
@@ -49,12 +53,156 @@
4953
from torch.fx._compatibility import compatibility
5054
from torch.fx.experimental.proxy_tensor import make_fx
5155
from torch.fx.experimental.symbolic_shapes import ShapeEnv
56+
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
5257
from torch.utils import _pytree as pytree
5358

5459

5560
Val = Any
5661

5762

63+
def _unlift(gm, inp_pos_to_param_buffer_name, in_spec, out_spec, state_dict):
64+
count = 0
65+
# Step 1: make lifted params as get_attr
66+
for node in gm.graph.nodes:
67+
if node.op == "placeholder":
68+
if count in inp_pos_to_param_buffer_name:
69+
with gm.graph.inserting_after(node):
70+
getattr_node = gm.graph.get_attr(
71+
inp_pos_to_param_buffer_name[count]
72+
)
73+
node.replace_all_uses_with(getattr_node)
74+
metadata = node.meta
75+
gm.graph.erase_node(node)
76+
getattr_node.meta = metadata
77+
count += 1
78+
79+
# Step 2: Fix the input/output of the graph now that we deleted
80+
# some args.
81+
gm.graph.lint()
82+
names = [f"arg_{i}" for i in range(len(in_spec.children_specs))]
83+
gm.graph._codegen = _PyTreeCodeGen(
84+
_PyTreeInfo(
85+
names,
86+
in_spec,
87+
out_spec,
88+
)
89+
)
90+
gm.recompile()
91+
92+
# Step 3: Find state references in HigherOrderOps and recursively
93+
# fix them.
94+
for node in gm.graph.nodes:
95+
if node.op == "call_function" and node.target == torch.ops.cond:
96+
pred, true_graph, false_graph, operands = node.args
97+
true_gm = getattr(gm, true_graph.name)
98+
false_gm = getattr(gm, false_graph.name)
99+
inp_pos_to_param_buffer_name_for_submod = {}
100+
real_operands = []
101+
for ix, operand in enumerate(operands):
102+
if operand.target in inp_pos_to_param_buffer_name.values():
103+
inp_pos_to_param_buffer_name_for_submod[ix] = operand.target
104+
true_gm.register_buffer(operand.target, state_dict[operand.target])
105+
false_gm.register_buffer(operand.target, state_dict[operand.target])
106+
else:
107+
real_operands.append(operand)
108+
node.args = (pred, true_graph, false_graph, real_operands)
109+
110+
_, in_spec = pytree.tree_flatten(real_operands)
111+
112+
_unlift(
113+
true_gm,
114+
inp_pos_to_param_buffer_name_for_submod,
115+
in_spec,
116+
None,
117+
state_dict,
118+
)
119+
_unlift(
120+
false_gm,
121+
inp_pos_to_param_buffer_name_for_submod,
122+
in_spec,
123+
None,
124+
state_dict,
125+
)
126+
if node.op == "call_function" and node.target.__name__ == "map_impl":
127+
body_graph, num_mapped, *operands = node.args
128+
body_gm = getattr(gm, body_graph.name)
129+
inp_pos_to_buffer_name_for_submod = {}
130+
real_operands = []
131+
for ix, operand in enumerate(operands):
132+
if operand.target in inp_pos_to_param_buffer_name.values():
133+
inp_pos_to_buffer_name_for_submod[ix] = operand.target
134+
body_gm.register_buffer(operand.target, state_dict[operand.target])
135+
else:
136+
real_operands.append(operand)
137+
node.args = (body_graph, num_mapped, *real_operands)
138+
139+
_, in_spec = pytree.tree_flatten(real_operands)
140+
141+
_unlift(
142+
body_gm, inp_pos_to_buffer_name_for_submod, in_spec, None, state_dict
143+
)
144+
gm.graph.lint()
145+
gm.graph.eliminate_dead_code()
146+
gm.recompile()
147+
return gm
148+
149+
150+
def unlift_exported_program_lifted_states(
151+
ep: torch._export.exported_program.ExportedProgram,
152+
):
153+
new_gm = copy.deepcopy(ep.graph_module)
154+
155+
# TODO Fix the period in params/buffers names later
156+
# maybe a pass to replace graph signature with fixed names
157+
param_buffer_name_to_corrected_name = {}
158+
159+
for name, stuff in ep.state_dict.items():
160+
if name in ep.graph_signature.buffers:
161+
if "." in name:
162+
new_gm.register_buffer(name.replace(".", "_"), stuff)
163+
param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
164+
else:
165+
new_gm.register_buffer(name, stuff)
166+
elif name in ep.graph_signature.parameters:
167+
if "." in name:
168+
new_gm.register_parameter(name.replace(".", "_"), stuff)
169+
param_buffer_name_to_corrected_name[name] = name.replace(".", "_")
170+
else:
171+
new_gm.register_parameter(name, stuff)
172+
else:
173+
raise AssertionError("encountered not registered param/buffer")
174+
175+
count = 0
176+
inp_pos_to_param_buffer_name = {}
177+
for node in new_gm.graph.nodes:
178+
if node.op == "placeholder":
179+
if node.name in ep.graph_signature.inputs_to_buffers:
180+
buffer_name = ep.graph_signature.inputs_to_buffers[node.name]
181+
if buffer_name in param_buffer_name_to_corrected_name:
182+
inp_pos_to_param_buffer_name[
183+
count
184+
] = param_buffer_name_to_corrected_name[buffer_name]
185+
else:
186+
inp_pos_to_param_buffer_name[count] = buffer_name
187+
if node.name in ep.graph_signature.inputs_to_parameters:
188+
param_name = ep.graph_signature.inputs_to_parameters[node.name]
189+
if param_name in param_buffer_name_to_corrected_name:
190+
inp_pos_to_param_buffer_name[
191+
count
192+
] = param_buffer_name_to_corrected_name[param_name]
193+
else:
194+
inp_pos_to_param_buffer_name[count] = param_name
195+
count += 1
196+
new_gm = _unlift(
197+
new_gm,
198+
inp_pos_to_param_buffer_name,
199+
ep.call_spec.in_spec,
200+
ep.call_spec.out_spec,
201+
ep.state_dict,
202+
)
203+
return new_gm
204+
205+
58206
@compatibility(is_backward_compatible=False)
59207
@dataclass
60208
class CaptureConfig:
@@ -63,6 +211,7 @@ class CaptureConfig:
63211
enable_dynamic_shape: bool = False
64212
enable_aot: bool = False
65213
_dynamo_config: "ExirDynamoConfig" = ExirDynamoConfig()
214+
_unlift: bool = False
66215

67216

68217
@compatibility(is_backward_compatible=False)
@@ -400,8 +549,15 @@ def capture(
400549
"Functionalization is required for enable_aot.",
401550
)
402551

403-
ep = export(f, args, _add_runtime_assertions=False, constraints=constraints)
404-
return ep # pyre-ignore
552+
# TODO remove this later
553+
with patch("torch._export.DECOMP_TABLE", _default_decomposition_table()):
554+
ep = export(
555+
f, args, _add_runtime_assertions=False, constraints=constraints
556+
)
557+
ep = ep.transform(ReplaceViewOpsWithViewCopyOpsPass())
558+
if not config._unlift:
559+
return ep # pyre-ignore
560+
graph_module = unlift_exported_program_lifted_states(ep)
405561

406562
elif config.enable_dynamic_shape:
407563
if not config._dynamo_config.dynamic_shapes:

exir/dialects/edge/edge.yaml

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@
8989
mat2: T0
9090
__ret_0: T0
9191

92+
- func: aten::arange.start_step
93+
namespace: edge
94+
inherits: aten::arange.start_step
95+
type_alias:
96+
T0: [Byte, Char, Double, Float, Int, Long, Short]
97+
type_constraint:
98+
- __ret_0: T0
99+
92100
- func: aten::bmm
93101
namespace: edge
94102
inherits: aten::bmm
@@ -198,14 +206,43 @@
198206
- self: T0
199207
__ret_0: T0
200208

201-
- func: aten::lift_fresh_copy
209+
- func: aten::index_select
202210
namespace: edge
203-
inherits: aten::lift_fresh_copy
211+
inherits: aten::index_select
204212
type_alias:
205-
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
213+
T0: [Bool]
214+
T1: [Byte]
215+
T2: [Char]
216+
T3: [Double]
217+
T4: [Float]
218+
T5: [Int]
219+
T6: [Long]
220+
T7: [Short]
206221
type_constraint:
207222
- self: T0
223+
index: T6
208224
__ret_0: T0
225+
- self: T1
226+
index: T6
227+
__ret_0: T1
228+
- self: T2
229+
index: T6
230+
__ret_0: T2
231+
- self: T3
232+
index: T6
233+
__ret_0: T3
234+
- self: T4
235+
index: T6
236+
__ret_0: T4
237+
- self: T5
238+
index: T6
239+
__ret_0: T5
240+
- self: T6
241+
index: T6
242+
__ret_0: T6
243+
- self: T7
244+
index: T6
245+
__ret_0: T7
209246

210247
- func: aten::masked_fill.Scalar
211248
namespace: edge
@@ -245,16 +282,6 @@
245282
mask: T0
246283
__ret_0: T7
247284

248-
- func: aten::minimum
249-
namespace: edge
250-
inherits: aten::minimum
251-
type_alias:
252-
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
253-
type_constraint:
254-
- self: T0
255-
other: T0
256-
__ret_0: T0
257-
258285
- func: aten::mm
259286
namespace: edge
260287
inherits: aten::mm
@@ -324,15 +351,6 @@
324351
- self: T0
325352
__ret_0: T0
326353

327-
- func: aten::select_copy.int
328-
namespace: edge
329-
inherits: aten::select_copy.int
330-
type_alias:
331-
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
332-
type_constraint:
333-
- self: T0
334-
__ret_0: T0
335-
336354
- func: aten::sigmoid
337355
namespace: edge
338356
inherits: aten::sigmoid
@@ -383,9 +401,25 @@
383401
other: T0
384402
__ret_0: T0
385403

386-
- func: aten::t
404+
- func: aten::sym_numel
405+
namespace: edge
406+
inherits: aten::sym_numel
407+
type_alias:
408+
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
409+
type_constraint:
410+
- self: T0
411+
412+
- func: aten::sym_size.int
413+
namespace: edge
414+
inherits: aten::sym_size.int
415+
type_alias:
416+
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
417+
type_constraint:
418+
- self: T0
419+
420+
- func: aten::t_copy
387421
namespace: edge
388-
inherits: aten::t
422+
inherits: aten::t_copy
389423
type_alias:
390424
T0: [Bool, Byte, Char, Double, Float, Int, Long, Short]
391425
type_constraint:

exir/dialects/edge/yaml_generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ def get_test_gen_key(op_name: str) -> str:
143143
opdb_key = opdb_key[:-5]
144144
elif opdb_key == "sym_size":
145145
opdb_key = "resize_"
146+
elif opdb_key == "sym_numel":
147+
opdb_key = "abs"
146148
elif opdb_key == "convolution":
147149
opdb_key = "conv_transpose2d"
148150
elif opdb_key == "embedding":

0 commit comments

Comments
 (0)