
Commit 8e2f769

Merge branch 'main' into migrate-pt2e-top-level
2 parents e4874ec + 5866c19

9 files changed: +64 −41 lines

CMakePresets.json

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
{
55
"name": "common",
66
"hidden": true,
7-
"binaryDir": "${sourceDir}/cmake-out",
8-
"generator": "Unix Makefiles"
7+
"binaryDir": "${sourceDir}/cmake-out"
98
},
109
{
1110
"name": "macos-arm64",

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ python_library(
211211
typing = True,
212212
deps = [
213213
":pass_utils",
214+
":utils",
214215
"//executorch/backends/cadence/aot:pass_utils",
215216
"//executorch/exir:pass_base",
216217
"//executorch/exir/dialects:lib",

backends/cadence/aot/simplify_ops.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
CadencePassAttribute,
1717
register_cadence_pass,
1818
)
19+
from executorch.backends.cadence.aot.utils import rebind
1920
from executorch.exir.dialects._ops import ops as exir_ops
2021
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2122
from executorch.exir.pass_base import ExportPass, ProxyValue
22-
from torch.fx.operator_schemas import get_signature_for_torch_op
2323

2424

2525
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -117,32 +117,11 @@ class BindOptionalArgsPass(ExportPass):
117117
def call_operator(self, op, args, kwargs, meta):
118118
if not isinstance(op, EdgeOpOverload):
119119
return super().call_operator(op, args, kwargs, meta)
120-
assert callable(op)
121120

122-
torch_op_schemas = get_signature_for_torch_op(op._op)
123-
if len(torch_op_schemas) == 0:
124-
return super().call_operator(op, args, kwargs, meta)
125-
126-
matched_schemas = []
127-
# Iterate through all of the schema until we find one that matches
128-
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
129-
# values. If none matches, `new_args_and_kwargs` will be None
130-
for candidate_signature in torch_op_schemas:
131-
try:
132-
candidate_signature.bind(*args, **kwargs)
133-
matched_schemas.append(candidate_signature)
134-
except TypeError:
135-
continue
136-
137-
if len(matched_schemas) != 1:
138-
# Did not match any schema. Cannot normalize
139-
return super().call_operator(op, args, kwargs, meta)
140-
141-
sig = matched_schemas[0]
142-
bound_args = sig.bind(*args, **kwargs)
143-
bound_args.apply_defaults()
121+
if (updated_args := rebind(op, args, kwargs)) is not None:
122+
args, kwargs = updated_args
144123

145-
return super().call_operator(op, bound_args.args, bound_args.kwargs, meta)
124+
return super().call_operator(op, args, kwargs, meta)
146125

147126

148127
# This class encapsulates all the functions that simplify the op's args

backends/cadence/aot/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from executorch.exir import ExecutorchProgramManager, memory
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
21+
from executorch.exir.pass_base import Argument
2122
from tabulate import tabulate
23+
from torch.fx.operator_schemas import get_signature_for_torch_op
2224

2325
from torch.utils._pytree import tree_flatten
2426

@@ -308,3 +310,30 @@ def get_size(self, exir_id: int) -> int:
308310
# Return default memory config for the backend
309311
def get_default_memory_config() -> MemoryConfig:
310312
return MemoryConfig(memory_sizes=[0x1000000000])
313+
314+
315+
def rebind(
316+
op: EdgeOpOverload, args: tuple[Argument, ...], kwargs: dict[str, Argument]
317+
) -> Optional[tuple[tuple[Argument, ...], dict[str, Argument]]]:
318+
"""Populates optional args and binds args/kwargs based on schema."""
319+
torch_op_schemas = get_signature_for_torch_op(op._op)
320+
321+
matched_schemas = []
322+
# Iterate through all of the schema until we find one that matches
323+
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
324+
# values. If none matches, `new_args_and_kwargs` will be None
325+
for candidate_signature in torch_op_schemas:
326+
try:
327+
candidate_signature.bind(*args, **kwargs)
328+
matched_schemas.append(candidate_signature)
329+
except TypeError:
330+
continue
331+
332+
if len(matched_schemas) != 1:
333+
# Did not match any schema. Cannot normalize
334+
return None
335+
336+
bound_args = matched_schemas[0].bind(*args, **kwargs)
337+
bound_args.apply_defaults()
338+
339+
return bound_args.args, bound_args.kwargs
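
For reference, the mechanism `rebind` relies on is the standard signature-binding machinery that `torch.fx.operator_schemas` exposes as `inspect.Signature` objects. A minimal sketch with a plain function (`avg_pool` here is a hypothetical helper, not part of the codebase) shows how `bind` plus `apply_defaults` fills in optional arguments the caller omitted:

# Minimal sketch, not ExecuTorch code: illustrates the bind/apply_defaults
# mechanism that rebind() uses on torch op schemas.
import inspect


def avg_pool(x, kernel=2, stride=None, ceil_mode=False):
    pass


sig = inspect.signature(avg_pool)
bound = sig.bind("input_tensor", stride=1)  # caller omits optional args
bound.apply_defaults()                      # fills in kernel=2, ceil_mode=False
print(bound.args)    # ('input_tensor', 2, 1, False)
print(bound.kwargs)  # {}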

devtools/inspector/_inspector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1240,7 +1240,7 @@ def find_total_for_module(self, module_name: str) -> float:
12401240
for block in self.event_blocks:
12411241
for event in block.events:
12421242
# Skip OPERATOR_CALL events to avoid double-counting and exclude framework tax
1243-
if event.event_name == "OPERATOR_CALL":
1243+
if event.name == "OPERATOR_CALL":
12441244
continue
12451245

12461246
module_hierarchy = event.module_hierarchy.values()
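
The fix swaps the nonexistent `event.event_name` attribute for `event.name`. To see why those events are skipped at all, here is an illustrative sketch (a hypothetical `Event` stand-in, not the real devtools Inspector API): an OPERATOR_CALL wrapper event covers the same time as the operator call beneath it, so counting both would inflate the module total.

# Illustrative only: toy Event stand-in, not the real Inspector classes.
from dataclasses import dataclass


@dataclass
class Event:
    name: str
    p50_ms: float


events = [
    Event("OPERATOR_CALL", 1.4),          # wrapper; overlaps the op below
    Event("native_call_addmm.out", 1.2),
    Event("native_call_relu.out", 0.3),
]

# Skipping OPERATOR_CALL avoids double-counting and excludes framework tax.
total = sum(e.p50_ms for e in events if e.name != "OPERATOR_CALL")
print(total)  # 1.5 rather than 2.9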
(binary file changed: −256 Bytes, contents not shown)

docs/source/tutorials_source/devtools-integration-tutorial.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -187,24 +187,22 @@ def forward(self, x):
187187

188188
from executorch.devtools import Inspector
189189

190+
190191
# sphinx_gallery_start_ignore
191-
inspector_patch = patch.object(Inspector, "__init__", return_value=None)
192-
inspector_patch_print = patch.object(Inspector, "print_data_tabular", return_value="")
193-
inspector_patch.start()
192+
# inspector_patch = patch.object(Inspector, "__init__", return_value=None)
193+
inspector_patch_print = patch.object(Inspector, "print_data_tabular", return_value=None)
194194
inspector_patch_print.start()
195195
# sphinx_gallery_end_ignore
196196
etrecord_path = "etrecord.bin"
197197
etdump_path = "etdump.etdp"
198198
inspector = Inspector(etdump_path=etdump_path, etrecord=etrecord_path)
199-
# sphinx_gallery_start_ignore
200-
inspector.event_blocks = []
201-
# sphinx_gallery_end_ignore
202199
inspector.print_data_tabular()
203200

204-
# sphinx_gallery_start_ignore
205-
inspector_patch.stop()
206-
inspector_patch_print.stop()
207-
# sphinx_gallery_end_ignore
201+
####################################
202+
#
203+
# Here is an example output:
204+
#
205+
# .. image:: ../_static/img/inspector_tabular_output.png
208206

209207
######################################################################
210208
# Analyzing with an Inspector
@@ -234,11 +232,13 @@ def forward(self, x):
234232
if event.name == "native_call_addmm.out":
235233
print(event.name, event.perf_data.raw if event.perf_data else "")
236234

235+
print()
237236
# Via Dataframe
238237
df = event_block.to_dataframe()
239238
df = df[df.event_name == "native_call_addmm.out"]
240-
print(df[["event_name", "raw"]])
241-
print()
239+
if len(df) > 0:
240+
print(df[["event_name", "raw"]])
241+
print()
242242

243243
######################################################################
244244
# If a user wants to trace an operator back to their model code, they would do
@@ -255,6 +255,20 @@ def forward(self, x):
255255
if slowest is not None:
256256
print(slowest.name)
257257
print()
258+
# sphinx_gallery_start_ignore
259+
slowest_print = patch.object(
260+
slowest,
261+
"stack_traces",
262+
new={
263+
"aten_convolution_default_1_": " File "
264+
'"devtools-integration-tutorial.py", '
265+
"line 82, in forward\n"
266+
" x = F.max_pool2d(F.relu(self.conv2(x)), "
267+
"2)\n"
268+
},
269+
)
270+
slowest_print.start()
271+
# sphinx_gallery_end_ignore
258272
pp.pprint(slowest.stack_traces)
259273
print()
260274
pp.pprint(slowest.module_hierarchy)
@@ -264,7 +278,7 @@ def forward(self, x):
264278
df = df[df.event_name == "native_call_convolution.out"]
265279
if len(df) > 0:
266280
slowest = df.loc[df["p50"].idxmax()]
267-
assert slowest
281+
assert slowest is not None
268282
print(slowest.name)
269283
print()
270284
pp.pprint(slowest.stack_traces if slowest.stack_traces else "")
@@ -281,6 +295,7 @@ def forward(self, x):
281295
######################################################################
282296
# Note: ``find_total_for_module`` is a special first class method of
283297
# `Inspector <../model-inspector.html>`__
298+
#
284299

285300
######################################################################
286301
# Conclusion
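
The sphinx-gallery-only blocks added above rely on `unittest.mock.patch.object` with `new=` to substitute canned data during the docs build. A minimal, self-contained sketch of that pattern (the `Event` class and values here are toy stand-ins):

# Minimal sketch of the patch.object(..., new=...) pattern used above.
from unittest.mock import patch


class Event:
    stack_traces = {"real_op": "real trace"}


event = Event()
patcher = patch.object(Event, "stack_traces", new={"demo_op": "canned trace"})
patcher.start()
print(event.stack_traces)  # {'demo_op': 'canned trace'}
patcher.stop()             # restores the original attribute
print(event.stack_traces)  # {'real_op': 'real trace'}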
(binary file added: 2.27 KB, contents not shown)
