Skip to content

Commit fad986b

Browse files
committed
Fix naming and test
1 parent be94fb0 commit fad986b

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,20 +1607,20 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16071607
}
16081608
};
16091609

1610-
FailureOr<VectorType> dropNonScalableUnitDimType(VectorType VT) {
1611-
int removed = 0;
1612-
auto shape = VT.getShape();
1613-
auto builder = VectorType::Builder(VT);
1614-
for (unsigned i = 0; i < shape.size(); i++) {
1615-
if (shape[i] == 1 && !VT.getScalableDims()[i]) {
1616-
builder.dropDim(i - removed);
1617-
removed++;
1610+
FailureOr<VectorType> dropNonScalableUnitDimType(VectorType inVecTy) {
1611+
int numUnitDimsDropped = 0;
1612+
auto inVecShape = inVecTy.getShape();
1613+
auto newVecBuilder = VectorType::Builder(inVecTy);
1614+
for (unsigned i = 0; i < inVecShape.size(); i++) {
1615+
if (inVecShape[i] == 1 && !inVecTy.getScalableDims()[i]) {
1616+
newVecBuilder.dropDim(i - numUnitDimsDropped);
1617+
numUnitDimsDropped++;
16181618
}
16191619
}
16201620

1621-
if (removed == 0)
1621+
if (numUnitDimsDropped == 0)
16221622
return failure();
1623-
return VectorType(builder);
1623+
return VectorType(newVecBuilder);
16241624
}
16251625

16261626
/// For vectors with at least an unit dim, replaces:

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ func.func @fold_unit_inner_dim(%arg0 : vector<8x1x3xf128>,
509509

510510
// CHECK-LABEL: func.func @fold_unit_inner_dim(
511511
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
512-
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8xxf128> {
512+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
513513
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
514514
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
515515
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
@@ -521,8 +521,8 @@ func.func @fold_unit_inner_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
521521
%arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
522522
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
523523
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
524-
%res = vector.shape_cast %add : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
525-
return %mul : vector<8x[1]x3xf128>
524+
%res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
525+
return %res : vector<8x[1]x3xf128>
526526
}
527527

528528
// CHECK-LABEL: func.func @fold_unit_inner_dim_scalable(

0 commit comments

Comments
 (0)