21
21
from typing import Dict , Optional , Sequence , Union , NewType
22
22
23
23
24
+ @register_attribute_builder ("ParamOperandIndexAttr" )
25
+ def _paramOperandIndexAttr (x : int , context ) -> Attribute :
26
+ return Attribute .parse (f"#transform.param_operand_index<{ x } >" , context = context )
27
+
28
+
24
29
@_ods_cext .register_operation (_Dialect , replace = True )
25
30
class CastOp (CastOp ):
26
31
def __init__ (
@@ -214,11 +219,6 @@ def __init__(
214
219
super ().__init__ (_get_op_results_or_values (operands ), loc = loc , ip = ip )
215
220
216
221
217
- @register_attribute_builder ("ParamOperandIndexAttr" )
218
- def _paramOperandIndexAttr (x : int , context ) -> Attribute :
219
- return Attribute .parse (f"#transform.param_operand_index<{ x } >" , context = context )
220
-
221
-
222
222
@_ods_cext .register_operation (_Dialect , replace = True )
223
223
class ApplyRegisteredPassOp (ApplyRegisteredPassOp ):
224
224
def __init__ (
@@ -227,10 +227,12 @@ def __init__(
227
227
pass_name : Union [str , StringAttr ],
228
228
target : Union [Operation , Value , OpView ],
229
229
* ,
230
- options : Dict [
231
- Union [str , StringAttr ],
232
- Union [Attribute , Value , Operation , OpView ],
233
- ] = {},
230
+ options : Optional [
231
+ Dict [
232
+ Union [str , StringAttr ],
233
+ Union [Attribute , Value , Operation , OpView ],
234
+ ]
235
+ ] = None ,
234
236
loc = None ,
235
237
ip = None ,
236
238
):
@@ -241,20 +243,16 @@ def __init__(
241
243
context = (loc and loc .context ) or Context .current
242
244
243
245
cur_param_operand_idx = 0
244
- for key , value in options .items ():
246
+ for key , value in options .items () if options is not None else {} :
245
247
if isinstance (key , StringAttr ):
246
248
key = key .value
247
249
248
250
if isinstance (value , (Value , Operation , OpView )):
249
- value = _get_op_result_or_value (value )
250
- # v = Attribute.parse(
251
- # f"#transform.param_operand_index<{cur_param_operand_idx}>",
252
- # context=context,
253
- # )
254
- v = _paramOperandIndexAttr (cur_param_operand_idx , context )
255
- options_dict [key ] = v
251
+ dynamic_options .append (_get_op_result_or_value (value ))
252
+ options_dict [key ] = ParamOperandIndexAttr (
253
+ cur_param_operand_idx , context
254
+ )
256
255
cur_param_operand_idx += 1
257
- dynamic_options .append (value )
258
256
elif isinstance (value , Attribute ):
259
257
options_dict [key ] = value
260
258
elif isinstance (value , str ):
@@ -279,10 +277,12 @@ def apply_registered_pass(
279
277
pass_name : Union [str , StringAttr ],
280
278
target : Union [Operation , Value , OpView ],
281
279
* ,
282
- options : Dict [
283
- Union [str , StringAttr ],
284
- Union [Attribute , Value , Operation , OpView ],
285
- ] = {},
280
+ options : Optional [
281
+ Dict [
282
+ Union [str , StringAttr ],
283
+ Union [Attribute , Value , Operation , OpView ],
284
+ ]
285
+ ] = None ,
286
286
loc = None ,
287
287
ip = None ,
288
288
) -> Value :
0 commit comments