Skip to content

Commit b236558

Browse files
nujaaMacDue
andcommitted
[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. (#92934)
Generalizes `DropUnitDimFromElementwiseOps` to support inner unit dimensions. This change stems from improving lowering of contractionOps for Arm SME. Where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outerproducts. discussed [here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa). --------- Co-authored-by: Benjamin Maxwell <[email protected]>
1 parent 08ce147 commit b236558

File tree

2 files changed

+65
-26
lines changed

2 files changed

+65
-26
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,7 +1622,27 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16221622
}
16231623
};
16241624

1625-
/// For vectors with either leading or trailing unit dim, replaces:
1625+
// Scalable unit dimensions are not supported. Folding such dimensions would
1626+
// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
1627+
// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
1628+
// future.
1629+
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1630+
auto inVecShape = inVecTy.getShape();
1631+
SmallVector<int64_t> newShape;
1632+
SmallVector<bool> newScalableDims;
1633+
for (auto [dim, isScalable] :
1634+
llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1635+
if (dim == 1 && !isScalable)
1636+
continue;
1637+
1638+
newShape.push_back(dim);
1639+
newScalableDims.push_back(isScalable);
1640+
}
1641+
1642+
return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1643+
}
1644+
1645+
/// For vectors with at least an unit dim, replaces:
16261646
/// elementwise(a, b)
16271647
/// with:
16281648
/// sc_a = shape_cast(a)
@@ -1634,20 +1654,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16341654
/// required to be rank > 1.
16351655
///
16361656
/// Ex:
1637-
/// ```
16381657
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
16391658
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1640-
/// ```
16411659
///
16421660
/// gets converted to:
16431661
///
1644-
/// ```
16451662
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
16461663
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
16471664
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
16481665
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
16491666
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1650-
/// ```
16511667
///
16521668
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
16531669
/// `%cast`.
@@ -1667,42 +1683,29 @@ struct DropUnitDimFromElementwiseOps final
16671683
// guaranteed to have identical shapes (with some exceptions such as
16681684
// `arith.select`) and it suffices to only check one of them.
16691685
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1670-
if (!sourceVectorType)
1671-
return failure();
1672-
if (sourceVectorType.getRank() < 2)
1673-
return failure();
1674-
1675-
bool hasTrailingDimUnitFixed =
1676-
((sourceVectorType.getShape().back() == 1) &&
1677-
(!sourceVectorType.getScalableDims().back()));
1678-
bool hasLeadingDimUnitFixed =
1679-
((sourceVectorType.getShape().front() == 1) &&
1680-
(!sourceVectorType.getScalableDims().front()));
1681-
if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
1686+
if (!sourceVectorType || sourceVectorType.getRank() < 2)
16821687
return failure();
16831688

1684-
// Drop leading/trailing unit dim by applying vector.shape_cast to all
1685-
// operands
1686-
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
1687-
16881689
SmallVector<Value> newOperands;
16891690
auto loc = op->getLoc();
16901691
for (auto operand : op->getOperands()) {
16911692
auto opVectorType = cast<VectorType>(operand.getType());
1692-
VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
1693+
auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1694+
if (newVType == opVectorType)
1695+
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1696+
16931697
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
16941698
newOperands.push_back(opSC);
16951699
}
16961700

16971701
VectorType newResultVectorType =
1698-
VectorType::Builder(resultVectorType).dropDim(dim);
1699-
// Create an updated elementwise Op without leading/trailing unit dim
1702+
dropNonScalableUnitDimFromType(resultVectorType);
1703+
// Create an updated elementwise Op without unit dim.
17001704
Operation *elementwiseOp =
17011705
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
17021706
newResultVectorType, op->getAttrs());
17031707

1704-
// Restore the leading/trailing unit dim by applying vector.shape_cast
1705-
// to the result
1708+
// Restore the unit dim by applying vector.shape_cast to the result.
17061709
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
17071710
elementwiseOp->getResult(0));
17081711

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,42 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
604604

605605
// -----
606606

607+
func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
608+
%arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
609+
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
610+
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
611+
%res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
612+
return %res : vector<8x3xf128>
613+
}
614+
615+
// CHECK-LABEL: func.func @fold_inner_unit_dim(
616+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
617+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
618+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
619+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
620+
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
621+
// CHECK: return %[[VAL_4]] : vector<8x3xf128>
622+
623+
// -----
624+
625+
func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
626+
%arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
627+
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
628+
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
629+
%res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
630+
return %res : vector<8x[1]x3xf128>
631+
}
632+
633+
// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
634+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
635+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
636+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
637+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
638+
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
639+
// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>
640+
641+
// -----
642+
607643
func.func @negative_out_of_bound_transfer_read(
608644
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
609645
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)