@@ -499,15 +499,15 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
499
499
500
500
// -----
501
501
502
- func.func @fold_unit_inner_dim (%arg0 : vector <8 x1 x3 xf128 >,
502
+ func.func @fold_inner_unit_dim (%arg0 : vector <8 x1 x3 xf128 >,
503
503
%arg1 : vector <1 x8 x3 xf128 >) -> vector <8 x3 xf128 > {
504
504
%sc_arg1 = vector.shape_cast %arg1 : vector <1 x8 x3 xf128 > to vector <8 x1 x3 xf128 >
505
505
%mul = arith.mulf %arg0 , %sc_arg1 : vector <8 x1 x3 xf128 >
506
506
%res = vector.shape_cast %mul : vector <8 x1 x3 xf128 > to vector <8 x3 xf128 >
507
507
return %res : vector <8 x3 xf128 >
508
508
}
509
509
510
- // CHECK-LABEL: func.func @fold_unit_inner_dim (
510
+ // CHECK-LABEL: func.func @fold_inner_unit_dim (
511
511
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
512
512
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
513
513
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
@@ -517,15 +517,15 @@ func.func @fold_unit_inner_dim(%arg0 : vector<8x1x3xf128>,
517
517
518
518
// -----
519
519
520
- func.func @fold_unit_inner_dim_scalable (%arg0 : vector <8 x1 x[1 ]x3 xf128 >,
520
+ func.func @fold_inner_unit_dim_scalable (%arg0 : vector <8 x1 x[1 ]x3 xf128 >,
521
521
%arg1 : vector <1 x8 x[1 ]x3 xf128 >) -> vector <8 x[1 ]x3 xf128 > {
522
522
%sc_arg1 = vector.shape_cast %arg1 : vector <1 x8 x[1 ]x3 xf128 > to vector <8 x1 x[1 ]x3 xf128 >
523
523
%mul = arith.mulf %arg0 , %sc_arg1 : vector <8 x1 x[1 ]x3 xf128 >
524
524
%res = vector.shape_cast %mul : vector <8 x1 x[1 ]x3 xf128 > to vector <8 x[1 ]x3 xf128 >
525
525
return %res : vector <8 x[1 ]x3 xf128 >
526
526
}
527
527
528
- // CHECK-LABEL: func.func @fold_unit_inner_dim_scalable (
528
+ // CHECK-LABEL: func.func @fold_inner_unit_dim_scalable (
529
529
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
530
530
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
531
531
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
0 commit comments