@@ -140,6 +140,77 @@ def __init__(
140
140
)
141
141
142
142
143
+ @_ods_cext .register_operation (_Dialect , replace = True )
144
+ class FuseOp (FuseOp ):
145
+ """Specialization for FuseOp class."""
146
+
147
+ @overload
148
+ def __init__ (
149
+ self ,
150
+ loop_types : Union [Type , Sequence [Type ]],
151
+ target : Union [Operation , Value , OpView ],
152
+ * ,
153
+ tile_sizes : Optional [Union [DynamicIndexList , ArrayAttr ]] = None ,
154
+ tile_interchange : OptionalIntList = None ,
155
+ apply_cleanup : Optional [bool ] = False ,
156
+ loc = None ,
157
+ ip = None ,
158
+ ):
159
+ ...
160
+
161
+ @overload
162
+ def __init__ (
163
+ self ,
164
+ target : Union [Operation , Value , OpView ],
165
+ * ,
166
+ tile_sizes : Optional [Union [DynamicIndexList , ArrayAttr ]] = None ,
167
+ tile_interchange : OptionalIntList = None ,
168
+ apply_cleanup : Optional [bool ] = False ,
169
+ loc = None ,
170
+ ip = None ,
171
+ ):
172
+ ...
173
+
174
+ def __init__ (
175
+ self ,
176
+ loop_types_or_target : Union [Type , Sequence [Type ], Operation , OpView , Value ],
177
+ target_or_none : Optional [Union [Operation , Value , OpView ]] = None ,
178
+ * ,
179
+ tile_sizes : Optional [Union [DynamicIndexList , ArrayAttr ]] = None ,
180
+ tile_interchange : OptionalIntList = None ,
181
+ apply_cleanup : Optional [bool ] = False ,
182
+ loc = None ,
183
+ ip = None ,
184
+ ):
185
+ tile_sizes = tile_sizes if tile_sizes else []
186
+ tile_interchange = tile_interchange if tile_interchange else []
187
+ _ , tile_sizes , _ = _dispatch_dynamic_index_list (tile_sizes )
188
+ _ , tile_interchange , _ = _dispatch_dynamic_index_list (tile_interchange )
189
+ num_loops = sum (0 if v == 0 else 1 for v in tile_sizes )
190
+
191
+ if isinstance (loop_types_or_target , (Operation , Value , OpView )):
192
+ loop_types = [transform .AnyOpType .get ()] * num_loops
193
+ target = loop_types_or_target
194
+ assert target_or_none is None , "Cannot construct FuseOp with two targets."
195
+ else :
196
+ loop_types = (
197
+ ([loop_types_or_target ] * num_loops )
198
+ if isinstance (loop_types_or_target , Type )
199
+ else loop_types_or_target
200
+ )
201
+ target = target_or_none
202
+ super ().__init__ (
203
+ target .type ,
204
+ loop_types ,
205
+ target ,
206
+ tile_sizes = tile_sizes ,
207
+ tile_interchange = tile_interchange ,
208
+ apply_cleanup = apply_cleanup ,
209
+ loc = loc ,
210
+ ip = ip ,
211
+ )
212
+
213
+
143
214
@_ods_cext .register_operation (_Dialect , replace = True )
144
215
class GeneralizeOp (GeneralizeOp ):
145
216
"""Specialization for GeneralizeOp class."""
0 commit comments