Skip to content

Commit 5866c19

Browse files
authored
Create utility to rebind args/kwargs.
Differential Revision: D75029675 Pull Request resolved: #10987
1 parent 2ede762 commit 5866c19

File tree

3 files changed

+34
-25
lines changed

3 files changed

+34
-25
lines changed

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ python_library(
211211
typing = True,
212212
deps = [
213213
":pass_utils",
214+
":utils",
214215
"//executorch/backends/cadence/aot:pass_utils",
215216
"//executorch/exir:pass_base",
216217
"//executorch/exir/dialects:lib",

backends/cadence/aot/simplify_ops.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
CadencePassAttribute,
1717
register_cadence_pass,
1818
)
19+
from executorch.backends.cadence.aot.utils import rebind
1920
from executorch.exir.dialects._ops import ops as exir_ops
2021
from executorch.exir.dialects.edge._ops import EdgeOpOverload
2122
from executorch.exir.pass_base import ExportPass, ProxyValue
22-
from torch.fx.operator_schemas import get_signature_for_torch_op
2323

2424

2525
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -117,32 +117,11 @@ class BindOptionalArgsPass(ExportPass):
117117
def call_operator(self, op, args, kwargs, meta):
118118
if not isinstance(op, EdgeOpOverload):
119119
return super().call_operator(op, args, kwargs, meta)
120-
assert callable(op)
121120

122-
torch_op_schemas = get_signature_for_torch_op(op._op)
123-
if len(torch_op_schemas) == 0:
124-
return super().call_operator(op, args, kwargs, meta)
125-
126-
matched_schemas = []
127-
# Iterate through all of the schema until we find one that matches
128-
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
129-
# values. If none matches, `new_args_and_kwargs` will be None
130-
for candidate_signature in torch_op_schemas:
131-
try:
132-
candidate_signature.bind(*args, **kwargs)
133-
matched_schemas.append(candidate_signature)
134-
except TypeError:
135-
continue
136-
137-
if len(matched_schemas) != 1:
138-
# Did not match any schema. Cannot normalize
139-
return super().call_operator(op, args, kwargs, meta)
140-
141-
sig = matched_schemas[0]
142-
bound_args = sig.bind(*args, **kwargs)
143-
bound_args.apply_defaults()
121+
if (updated_args := rebind(op, args, kwargs)) is not None:
122+
args, kwargs = updated_args
144123

145-
return super().call_operator(op, bound_args.args, bound_args.kwargs, meta)
124+
return super().call_operator(op, args, kwargs, meta)
146125

147126

148127
# This class encapsulates all the functions that simplify the op's args

backends/cadence/aot/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from executorch.exir import ExecutorchProgramManager, memory
1919
from executorch.exir.dialects._ops import ops as exir_ops
2020
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
21+
from executorch.exir.pass_base import Argument
2122
from tabulate import tabulate
23+
from torch.fx.operator_schemas import get_signature_for_torch_op
2224

2325
from torch.utils._pytree import tree_flatten
2426

@@ -308,3 +310,30 @@ def get_size(self, exir_id: int) -> int:
308310
# Return default memory config for the backend
309311
def get_default_memory_config() -> MemoryConfig:
310312
return MemoryConfig(memory_sizes=[0x1000000000])
313+
314+
315+
def rebind(
316+
op: EdgeOpOverload, args: tuple[Argument, ...], kwargs: dict[str, Argument]
317+
) -> Optional[tuple[tuple[Argument, ...], dict[str, Argument]]]:
318+
"""Populates optional args and binds args/kwargs based on schema."""
319+
torch_op_schemas = get_signature_for_torch_op(op._op)
320+
321+
matched_schemas = []
322+
# Iterate through all of the schema until we find one that matches
323+
# If one matches, populate `new_args_and_kwargs` with the new args/kwargs
324+
# values. If none matches, `new_args_and_kwargs` will be None
325+
for candidate_signature in torch_op_schemas:
326+
try:
327+
candidate_signature.bind(*args, **kwargs)
328+
matched_schemas.append(candidate_signature)
329+
except TypeError:
330+
continue
331+
332+
if len(matched_schemas) != 1:
333+
# Did not match any schema. Cannot normalize
334+
return None
335+
336+
bound_args = matched_schemas[0].bind(*args, **kwargs)
337+
bound_args.apply_defaults()
338+
339+
return bound_args.args, bound_args.kwargs

0 commit comments

Comments
 (0)