|
16 | 16 | CadencePassAttribute,
|
17 | 17 | register_cadence_pass,
|
18 | 18 | )
|
| 19 | +from executorch.backends.cadence.aot.utils import rebind |
19 | 20 | from executorch.exir.dialects._ops import ops as exir_ops
|
20 | 21 | from executorch.exir.dialects.edge._ops import EdgeOpOverload
|
21 | 22 | from executorch.exir.pass_base import ExportPass, ProxyValue
|
22 |
| -from torch.fx.operator_schemas import get_signature_for_torch_op |
23 | 23 |
|
24 | 24 |
|
25 | 25 | @register_cadence_pass(CadencePassAttribute(opt_level=0))
|
@@ -117,32 +117,11 @@ class BindOptionalArgsPass(ExportPass):
|
117 | 117 | def call_operator(self, op, args, kwargs, meta):
|
118 | 118 | if not isinstance(op, EdgeOpOverload):
|
119 | 119 | return super().call_operator(op, args, kwargs, meta)
|
120 |
| - assert callable(op) |
121 | 120 |
|
122 |
| - torch_op_schemas = get_signature_for_torch_op(op._op) |
123 |
| - if len(torch_op_schemas) == 0: |
124 |
| - return super().call_operator(op, args, kwargs, meta) |
125 |
| - |
126 |
| - matched_schemas = [] |
127 |
| - # Iterate through all of the schema until we find one that matches |
128 |
| - # If one matches, populate `new_args_and_kwargs` with the new args/kwargs |
129 |
| - # values. If none matches, `new_args_and_kwargs` will be None |
130 |
| - for candidate_signature in torch_op_schemas: |
131 |
| - try: |
132 |
| - candidate_signature.bind(*args, **kwargs) |
133 |
| - matched_schemas.append(candidate_signature) |
134 |
| - except TypeError: |
135 |
| - continue |
136 |
| - |
137 |
| - if len(matched_schemas) != 1: |
138 |
| - # Did not match any schema. Cannot normalize |
139 |
| - return super().call_operator(op, args, kwargs, meta) |
140 |
| - |
141 |
| - sig = matched_schemas[0] |
142 |
| - bound_args = sig.bind(*args, **kwargs) |
143 |
| - bound_args.apply_defaults() |
| 121 | + if (updated_args := rebind(op, args, kwargs)) is not None: |
| 122 | + args, kwargs = updated_args |
144 | 123 |
|
145 |
| - return super().call_operator(op, bound_args.args, bound_args.kwargs, meta) |
| 124 | + return super().call_operator(op, args, kwargs, meta) |
146 | 125 |
|
147 | 126 |
|
148 | 127 | # This class encapsulates all the functions that simplify the op's args
|
|
0 commit comments