Skip to content

Commit e1a7803

Browse files
committed
Fix Python signature
1 parent 9890aee commit e1a7803

File tree

1 file changed

+22
-22
lines changed

1 file changed

+22
-22
lines changed

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@
2121
from typing import Dict, Optional, Sequence, Union, NewType
2222

2323

24+
@register_attribute_builder("ParamOperandIndexAttr")
25+
def _paramOperandIndexAttr(x: int, context) -> Attribute:
26+
return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
27+
28+
2429
@_ods_cext.register_operation(_Dialect, replace=True)
2530
class CastOp(CastOp):
2631
def __init__(
@@ -214,11 +219,6 @@ def __init__(
214219
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
215220

216221

217-
@register_attribute_builder("ParamOperandIndexAttr")
218-
def _paramOperandIndexAttr(x: int, context) -> Attribute:
219-
return Attribute.parse(f"#transform.param_operand_index<{x}>", context=context)
220-
221-
222222
@_ods_cext.register_operation(_Dialect, replace=True)
223223
class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
224224
def __init__(
@@ -227,10 +227,12 @@ def __init__(
227227
pass_name: Union[str, StringAttr],
228228
target: Union[Operation, Value, OpView],
229229
*,
230-
options: Dict[
231-
Union[str, StringAttr],
232-
Union[Attribute, Value, Operation, OpView],
233-
] = {},
230+
options: Optional[
231+
Dict[
232+
Union[str, StringAttr],
233+
Union[Attribute, Value, Operation, OpView],
234+
]
235+
] = None,
234236
loc=None,
235237
ip=None,
236238
):
@@ -241,20 +243,16 @@ def __init__(
241243
context = (loc and loc.context) or Context.current
242244

243245
cur_param_operand_idx = 0
244-
for key, value in options.items():
246+
for key, value in options.items() if options is not None else {}:
245247
if isinstance(key, StringAttr):
246248
key = key.value
247249

248250
if isinstance(value, (Value, Operation, OpView)):
249-
value = _get_op_result_or_value(value)
250-
# v = Attribute.parse(
251-
# f"#transform.param_operand_index<{cur_param_operand_idx}>",
252-
# context=context,
253-
# )
254-
v = _paramOperandIndexAttr(cur_param_operand_idx, context)
255-
options_dict[key] = v
251+
dynamic_options.append(_get_op_result_or_value(value))
252+
options_dict[key] = ParamOperandIndexAttr(
253+
cur_param_operand_idx, context
254+
)
256255
cur_param_operand_idx += 1
257-
dynamic_options.append(value)
258256
elif isinstance(value, Attribute):
259257
options_dict[key] = value
260258
elif isinstance(value, str):
@@ -279,10 +277,12 @@ def apply_registered_pass(
279277
pass_name: Union[str, StringAttr],
280278
target: Union[Operation, Value, OpView],
281279
*,
282-
options: Dict[
283-
Union[str, StringAttr],
284-
Union[Attribute, Value, Operation, OpView],
285-
] = {},
280+
options: Optional[
281+
Dict[
282+
Union[str, StringAttr],
283+
Union[Attribute, Value, Operation, OpView],
284+
]
285+
] = None,
286286
loc=None,
287287
ip=None,
288288
) -> Value:

0 commit comments

Comments
 (0)