Skip to content

Commit de61875

Browse files
nujaa and MacDue
authored
[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. (#98455)
Generalizes DropUnitDimFromElementwiseOps to support inner unit dimensions. This change stems from improving the lowering of contraction ops for Arm SME, where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outer products; discussed [here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa). Fix after #97652, which showed an unhandled edge case when all dimensions are one: the generated target VectorType would be `vector<f32>`, which is apparently not supported by the mulf. In case all dimensions are dropped, the target VectorType is now `vector<1xf32>`. --------- Co-authored-by: Benjamin Maxwell <[email protected]>
1 parent c7309da commit de61875

File tree

2 files changed

+85
-23
lines changed

2 files changed

+85
-23
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,7 +1627,33 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16271627
}
16281628
};
16291629

1630-
/// For vectors with either leading or trailing unit dim, replaces:
1630+
// Helper function dropping unit non-scalable dimension from a VectorType
// keeping at least 1 dimension to avoid generating 0-D vectors. Scalable unit
// dimensions are not dropped. Folding such dimensions would require "shifting"
// the scalable flag onto some other fixed-width dim (e.g. vector<[1]x4xf32> ->
// vector<[4]xf32>). This could be implemented in the future.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  auto inVecShape = inVecTy.getShape();
  SmallVector<int64_t> newShape;
  SmallVector<bool> newScalableDims;
  for (auto [dim, isScalable] :
       llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
    // Drop only fixed-width unit dims; scalable "[1]" dims are kept (see the
    // note above about shifting the scalable flag).
    if (dim == 1 && !isScalable)
      continue;

    newShape.push_back(dim);
    newScalableDims.push_back(isScalable);
  }
  // All dims have been dropped, return vector<1xeType>. Per the commit
  // description, a 0-D result type (vector<eType>) is not accepted by
  // elementwise ops such as arith.mulf, so keep one fixed unit dim.
  if (newShape.empty()) {
    newShape.push_back(1);
    newScalableDims.push_back(false);
  }

  return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}
1655+
1656+
/// For vectors with at least one unit dim, replaces:
16311657
/// elementwise(a, b)
16321658
/// with:
16331659
/// sc_a = shape_cast(a)
@@ -1639,20 +1665,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16391665
/// required to be rank > 1.
16401666
///
16411667
/// Ex:
1642-
/// ```
16431668
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
16441669
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1645-
/// ```
16461670
///
16471671
/// gets converted to:
16481672
///
1649-
/// ```
16501673
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
16511674
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
16521675
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
16531676
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
16541677
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1655-
/// ```
16561678
///
16571679
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
16581680
/// `%cast`.
@@ -1677,37 +1699,26 @@ struct DropUnitDimFromElementwiseOps final
16771699
if (sourceVectorType.getRank() < 2)
16781700
return failure();
16791701

1680-
bool hasTrailingDimUnitFixed =
1681-
((sourceVectorType.getShape().back() == 1) &&
1682-
(!sourceVectorType.getScalableDims().back()));
1683-
bool hasLeadingDimUnitFixed =
1684-
((sourceVectorType.getShape().front() == 1) &&
1685-
(!sourceVectorType.getScalableDims().front()));
1686-
if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
1687-
return failure();
1688-
1689-
// Drop leading/trailing unit dim by applying vector.shape_cast to all
1690-
// operands
1691-
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
1692-
16931702
SmallVector<Value> newOperands;
16941703
auto loc = op->getLoc();
16951704
for (auto operand : op->getOperands()) {
16961705
auto opVectorType = cast<VectorType>(operand.getType());
1697-
VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
1706+
auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1707+
if (newVType == opVectorType)
1708+
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1709+
16981710
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
16991711
newOperands.push_back(opSC);
17001712
}
17011713

17021714
VectorType newResultVectorType =
1703-
VectorType::Builder(resultVectorType).dropDim(dim);
1704-
// Create an updated elementwise Op without leading/trailing unit dim
1715+
dropNonScalableUnitDimFromType(resultVectorType);
1716+
// Create an updated elementwise Op without unit dim.
17051717
Operation *elementwiseOp =
17061718
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
17071719
newResultVectorType, op->getAttrs());
17081720

1709-
// Restore the leading/trailing unit dim by applying vector.shape_cast
1710-
// to the result
1721+
// Restore the unit dim by applying vector.shape_cast to the result.
17111722
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
17121723
elementwiseOp->getResult(0));
17131724

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,57 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
604604

605605
// -----
606606

607+
// Checks that an *inner* (non-leading, non-trailing) unit dim is dropped from
// the elementwise op: 8x1x3 operands are shape_cast to 8x3 before the mulf.
func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
                               %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
  %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
  %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
  %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
  return %res : vector<8x3xf128>
}

// CHECK-LABEL: func.func @fold_inner_unit_dim(
// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x3xf128>,
// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
// CHECK:         return %[[VAL_4]] : vector<8x3xf128>
622+
623+
// -----
624+
625+
// Checks that an inner *fixed* unit dim is dropped while the scalable unit dim
// ([1]) is preserved: 8x1x[1]x3 folds to 8x[1]x3, never touching [1].
func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
                                        %arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
  %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
  %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
  %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
  return %res : vector<8x[1]x3xf128>
}

// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
// CHECK-SAME:    %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
// CHECK-SAME:    %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
// CHECK:         %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
// CHECK:         %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
// CHECK:         return %[[VAL_4]] : vector<8x[1]x3xf128>
640+
641+
// -----
642+
643+
// Edge case from the commit description: when *all* dims are unit, the result
// must be vector<1xf32> rather than an unsupported 0-D vector<f32>.
func.func @fold_all_unit_dims(%arg0: vector<1x1xf32>) -> vector<1xf32> {
  %0 = arith.mulf %arg0, %arg0 : vector<1x1xf32>
  %res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32>
  return %res : vector<1xf32>
}

// CHECK-LABEL: func.func @fold_all_unit_dims(
// CHECK-SAME:    %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
// CHECK:         %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
// CHECK:         %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
// CHECK:         %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
// CHECK:         return %[[VAL_3]] : vector<1xf32>
655+
656+
// -----
657+
607658
func.func @negative_out_of_bound_transfer_read(
608659
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
609660
%c0 = arith.constant 0 : index

0 commit comments

Comments (0)