Skip to content

Commit 3909145

Browse files
committed
FIXUP: Handle all unit dims case
1 parent b236558 commit 3909145

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1622,7 +1622,8 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16221622
}
16231623
};
16241624

1625-
// Scalable unit dimensions are not supported. Folding such dimensions would
1625+
// Helper function dropping unit non-scalable dimension from a VectorType.
1626+
// Scalable unit dimensions are not dropped. Folding such dimensions would
16261627
// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
16271628
// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
16281629
// future.
@@ -1638,6 +1639,11 @@ static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
16381639
newShape.push_back(dim);
16391640
newScalableDims.push_back(isScalable);
16401641
}
1642+
// All dims have been dropped, we need to return a legal shape for VectorType.
1643+
if (newShape.empty()) {
1644+
newShape.push_back(1);
1645+
newScalableDims.push_back(false);
1646+
}
16411647

16421648
return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
16431649
}

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,21 @@ func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
640640

641641
// -----
642642

643+
func.func @fold_all_unit_dims(%arg0: vector<1x1xf32>) -> vector<1xf32> {
644+
%0 = arith.mulf %arg0, %arg0 : vector<1x1xf32>
645+
%res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32>
646+
return %res : vector<1xf32>
647+
}
648+
649+
// CHECK-LABEL: func.func @fold_all_unit_dims(
650+
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
651+
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
652+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
653+
// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
654+
// CHECK: return %[[VAL_3]] : vector<1xf32>
655+
656+
// -----
657+
643658
func.func @negative_out_of_bound_transfer_read(
644659
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
645660
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)