@@ -180,6 +180,63 @@ module attributes { transform.with_named_sequence } {
180
180
181
181
// -----
182
182
183
// Exercises the `elementwise` predicate of `transform.match.structured.body`.
// @__transform_main runs the matcher below over the payload and emits an
// "elementwise" remark on every op the matcher accepts; the remark checks
// placed in @payload name the ops that are required to be reported.
module attributes { transform.with_named_sequence } {
  // Action sequence: emits the "elementwise" remark on the op it receives.
  transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) {
    transform.debug.emit_remark_at %arg0, "elementwise" : !transform.any_op
    transform.yield
  }

  // Matcher sequence: forwards the candidate op only when the structured
  // matcher with the `elementwise` body predicate accepts it. Failures of the
  // nested predicate are propagated to the caller (`failures(propagate)`).
  transform.named_sequence @match_structured_body_elementwise(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
    ^bb0(%arg1: !transform.any_op):
      transform.match.structured.body %arg1 { elementwise } : !transform.any_op
      transform.match.structured.yield %arg1 : !transform.any_op
    }
    transform.yield %0 : !transform.any_op
  }

  // Entry point: pairs the matcher with the remark-emitting action. The root
  // handle is marked consumed; the handle produced by foreach_match is unused.
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
    transform.foreach_match in %arg0
        @match_structured_body_elementwise -> @print_elementwise
        : (!transform.any_op) -> !transform.any_op
    transform.yield
  }

  // Payload with three ops that carry a remark check (fill, map, and a
  // broadcasting add written as linalg.generic) and one op that carries none.
  // NOTE(review): the "start_here" tag is presumably referenced by the RUN
  // line, which is not visible in this chunk -- confirm the spelling there.
  func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
    %cst0 = arith.constant 0.0 : f32
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // expected-remark @below {{elementwise}}
    %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32>
    // expected-remark @below {{elementwise}}
    %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>)
    // No remark check precedes this op. Its body uses %add, %c0 and %c1, which
    // are defined outside the region, through tensor.dim / tensor.extract.
    // NOTE(review): presumably that is what makes the matcher reject it, and a
    // remark here would be flagged as unexpected if the test runs with
    // diagnostic verification -- confirm against the RUN line and matcher.
    %non_elementwise = linalg.generic
        {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]}
        ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
      ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
        %0 = arith.addf %arg0, %arg1 : f32
        %1 = tensor.dim %add, %c0 : tensor<2xf32>
        %2 = arith.subi %1, %c1 : index
        %3 = tensor.extract %add[%2] : tensor<2xf32>
        %4 = arith.mulf %0, %3 : f32
        linalg.yield %4 : f32
    } -> tensor<2x3xf32>
    // Same maps and iterator types as the op above, but the body is a single
    // arith.addf on the block arguments; the first input is indexed by d0 only.
    // expected-remark @below {{elementwise}}
    %add_bcast = linalg.generic
        {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
         iterator_types = ["parallel", "parallel"]}
        ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
      ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
        %0 = arith.addf %arg0, %arg1 : f32
        linalg.yield %0 : f32
    } -> tensor<2x3xf32>
    return %add, %add_bcast, %non_elementwise : tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>
  }
}
237
+
238
+ // -----
239
+
183
240
module attributes { transform.with_named_sequence } {
184
241
transform.named_sequence @print_reduction (%arg0: !transform.any_op {transform.readonly }) {
185
242
transform.debug.emit_remark_at %arg0 , " reduction" : !transform.any_op
0 commit comments