@@ -249,11 +249,11 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
249
249
// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
250
250
// CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
251
251
252
// Add-reduces the trailing fixed-size dimension (index 2, size 2) of %A into
// the accumulator %B; the scalable dimension [4] is not reduced and remains in
// the result type vector<8x[4]xf32>.
- func.func private @scalable_dims (%A : vector <8 x[4 ]x2 xf32 >, %B: vector <8 x[4 ]xf32 >) -> vector <8 x[4 ]xf32 > {
252
+ func.func private @vector_multi_reduction_non_scalable_dim (%A : vector <8 x[4 ]x2 xf32 >, %B: vector <8 x[4 ]xf32 >) -> vector <8 x[4 ]xf32 > {
253
253
%0 = vector.multi_reduction <add >, %A , %B [2 ] : vector <8 x[4 ]x2 xf32 > to vector <8 x[4 ]xf32 >
254
254
return %0 : vector <8 x[4 ]xf32 >
255
255
}
256
- // CHECK-LABEL: func.func private @scalable_dims (
256
+ // CHECK-LABEL: func.func private @vector_multi_reduction_non_scalable_dim (
257
257
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
258
258
// CHECK-SAME: %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
259
259
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<0.000000e+00> : vector<[32]xf32>
@@ -282,12 +282,12 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
282
282
// CHECK: return %[[VAL_163]] : vector<8x[4]xf32>
283
283
284
284
// Check that OneDimMultiReductionToTwoDim handles a masked 1-D add-reduction
// over a scalable dimension (vector<[4]xf32> reduced along dim 0 to f32, with
// %B as the accumulator and %C as the mask).
285
- func.func @scalable_dim_1d (%A: vector <[4 ]xf32 >, %B: f32 , %C: vector <[4 ]xi1 >) -> f32 {
285
+ func.func @vector_multi_reduction_scalable_dim_1d (%A: vector <[4 ]xf32 >, %B: f32 , %C: vector <[4 ]xi1 >) -> f32 {
286
286
%0 = vector.mask %C { vector.multi_reduction <add >, %A , %B [0 ] : vector <[4 ]xf32 > to f32 } : vector <[4 ]xi1 > -> f32
287
287
return %0 : f32
288
288
}
289
289
290
- // CHECK-LABEL: func.func @scalable_dim_1d (
290
+ // CHECK-LABEL: func.func @vector_multi_reduction_scalable_dim_1d (
291
291
// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
292
292
// CHECK-SAME: %[[ARG_1:.*]]: f32,
293
293
// CHECK-SAME: %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
@@ -298,6 +298,30 @@ func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) ->
298
298
// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
299
299
// CHECK: return %[[VAL_4]] : f32
300
300
301
// Masked add-reduction over the scalable inner dimension [4] of a 2-D vector
// (vector<2x[4]xf32> reduced along dim 1 to vector<2xf32>). The CHECK lines
// that follow expect one masked vector.reduction per row of the fixed leading
// dimension, with the results inserted into a vector<2xf32>.
+ func.func @vector_multi_reduction_scalable_dim_2d (%A: vector <2 x[4 ]xf32 >, %B: vector <2 xf32 >, %C: vector <2 x[4 ]xi1 >) -> vector <2 xf32 > {
302
+ %0 = vector.mask %C { vector.multi_reduction <add >, %A , %B [1 ] : vector <2 x[4 ]xf32 > to vector <2 xf32 > } : vector <2 x[4 ]xi1 > -> vector <2 xf32 >
303
+ return %0 : vector <2 xf32 >
304
+ }
305
+
306
+ // CHECK-LABEL: func.func @vector_multi_reduction_scalable_dim_2d(
307
+ // CHECK-SAME: %[[ARG_0:.*]]: vector<2x[4]xf32>,
308
+ // CHECK-SAME: %[[ARG_1:.*]]: vector<2xf32>,
309
+ // CHECK-SAME: %[[ARG_2:.*]]: vector<2x[4]xi1>) -> vector<2xf32> {
310
+ // CHECK-DAG: %[[C1_idx:.*]] = arith.constant 1 : index
311
+ // CHECK-DAG: %[[C0_idx:.*]] = arith.constant 0 : index
312
+ // CHECK-DAG: %[[C0_2xf32:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
313
+ // CHECK: %[[ARG0_0:.*]] = vector.extract %[[ARG_0]][0] : vector<[4]xf32> from vector<2x[4]xf32>
314
+ // CHECK: %[[ARG1_0:.*]] = vector.extract %[[ARG_1]][0] : f32 from vector<2xf32>
315
+ // CHECK: %[[ARG2_0:.*]] = vector.extract %[[ARG_2]][0] : vector<[4]xi1> from vector<2x[4]xi1>
316
+ // CHECK: %[[REDUCE_0:.*]] = vector.mask %[[ARG2_0]] { vector.reduction <add>, %[[ARG0_0]], %[[ARG1_0]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
317
+ // CHECK: %[[INSERT_0:.*]] = vector.insertelement %[[REDUCE_0]], %[[C0_2xf32]][%[[C0_idx]] : index] : vector<2xf32>
318
+ // CHECK: %[[ARG0_1:.*]] = vector.extract %[[ARG_0]][1] : vector<[4]xf32> from vector<2x[4]xf32>
319
+ // CHECK: %[[ARG1_1:.*]] = vector.extract %[[ARG_1]][1] : f32 from vector<2xf32>
320
+ // CHECK: %[[ARG2_1:.*]] = vector.extract %[[ARG_2]][1] : vector<[4]xi1> from vector<2x[4]xi1>
321
+ // CHECK: %[[REDUCE_1:.*]] = vector.mask %[[ARG2_1]] { vector.reduction <add>, %[[ARG0_1]], %[[ARG1_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
322
+ // CHECK: %[[INSERT_1:.*]] = vector.insertelement %[[REDUCE_1]], %[[INSERT_0]][%[[C1_idx]] : index] : vector<2xf32>
323
+ // CHECK: return %[[INSERT_1]] : vector<2xf32>
324
+
301
325
module attributes {transform.with_named_sequence } {
302
326
transform.named_sequence @__transform_main (%root : !transform.any_op {transform.readonly }) {
303
327
%func_op = transform.structured.match ops {[" func.func" ]} in %root : (!transform.any_op ) -> !transform.op <" func.func" >
0 commit comments