Skip to content

Commit 4af5cf1

Browse files
committed
FIXUP: Handle all unit dims case
1 parent b236558 commit 4af5cf1

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1628,17 +1628,24 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16281628
// future.
16291629
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
16301630
auto inVecShape = inVecTy.getShape();
1631+
auto inVecScalableDims = inVecTy.getScalableDims();
16311632
SmallVector<int64_t> newShape;
16321633
SmallVector<bool> newScalableDims;
1633-
for (auto [dim, isScalable] :
1634-
llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1635-
if (dim == 1 && !isScalable)
1636-
continue;
1634+
if (llvm::all_of(inVecShape, [](int64_t dim) { return dim == 1; }) &&
1635+
llvm::none_of(inVecScalableDims,
1636+
[](bool isScalable) { return isScalable; })) {
1637+
newShape.push_back(1);
1638+
newScalableDims.push_back(false);
1639+
} else {
1640+
for (auto [dim, isScalable] :
1641+
llvm::zip_equal(inVecShape, inVecScalableDims)) {
1642+
if (dim == 1 && !isScalable)
1643+
continue;
16371644

1638-
newShape.push_back(dim);
1639-
newScalableDims.push_back(isScalable);
1645+
newShape.push_back(dim);
1646+
newScalableDims.push_back(isScalable);
1647+
}
16401648
}
1641-
16421649
return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
16431650
}
16441651

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,21 @@ func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
640640

641641
// -----
642642

643+
func.func @fold_all_unit_dims(%arg0: vector<1x1xf32>) -> vector<1xf32> {
644+
%0 = arith.mulf %arg0, %arg0 : vector<1x1xf32>
645+
%res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32>
646+
return %res : vector<1xf32>
647+
}
648+
649+
// CHECK-LABEL: func.func @fold_all_unit_dims(
650+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
651+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
652+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
653+
// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
654+
// CHECK: return %[[VAL_3]] : vector<1xf32>
655+
656+
// -----
657+
643658
func.func @negative_out_of_bound_transfer_read(
644659
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
645660
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)