Skip to content

Commit c93faa4

Browse files
nujaa and MacDue
authored and committed
[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. (llvm#92934)
Generalizes `DropUnitDimFromElementwiseOps` to support inner unit dimensions. This change stems from improving the lowering of contraction ops for Arm SME, where we end up with inner unit dimensions on MulOp, BroadcastOp and TransposeOp, preventing the generation of outer products. Discussed [here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa). --------- Co-authored-by: Benjamin Maxwell <[email protected]>
1 parent d9dfde4 commit c93faa4

File tree

2 files changed

+65
-26
lines changed

2 files changed

+65
-26
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 29 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1612,7 +1612,27 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16121612
}
16131613
};
16141614

1615-
/// For vectors with either leading or trailing unit dim, replaces:
1615+
// Scalable unit dimensions are not supported. Folding such dimensions would
1616+
// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
1617+
// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
1618+
// future.
1619+
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
1620+
auto inVecShape = inVecTy.getShape();
1621+
SmallVector<int64_t> newShape;
1622+
SmallVector<bool> newScalableDims;
1623+
for (auto [dim, isScalable] :
1624+
llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
1625+
if (dim == 1 && !isScalable)
1626+
continue;
1627+
1628+
newShape.push_back(dim);
1629+
newScalableDims.push_back(isScalable);
1630+
}
1631+
1632+
return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
1633+
}
1634+
1635+
/// For vectors with at least one non-scalable unit dim, replaces:
16161636
/// elementwise(a, b)
16171637
/// with:
16181638
/// sc_a = shape_cast(a)
@@ -1624,20 +1644,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16241644
/// required to be rank > 1.
16251645
///
16261646
/// Ex:
1627-
/// ```
16281647
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
16291648
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
1630-
/// ```
16311649
///
16321650
/// gets converted to:
16331651
///
1634-
/// ```
16351652
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
16361653
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
16371654
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
16381655
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
16391656
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
1640-
/// ```
16411657
///
16421658
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
16431659
/// `%cast`.
@@ -1657,42 +1673,29 @@ struct DropUnitDimFromElementwiseOps final
16571673
// guaranteed to have identical shapes (with some exceptions such as
16581674
// `arith.select`) and it suffices to only check one of them.
16591675
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
1660-
if (!sourceVectorType)
1661-
return failure();
1662-
if (sourceVectorType.getRank() < 2)
1663-
return failure();
1664-
1665-
bool hasTrailingDimUnitFixed =
1666-
((sourceVectorType.getShape().back() == 1) &&
1667-
(!sourceVectorType.getScalableDims().back()));
1668-
bool hasLeadingDimUnitFixed =
1669-
((sourceVectorType.getShape().front() == 1) &&
1670-
(!sourceVectorType.getScalableDims().front()));
1671-
if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
1676+
if (!sourceVectorType || sourceVectorType.getRank() < 2)
16721677
return failure();
16731678

1674-
// Drop leading/trailing unit dim by applying vector.shape_cast to all
1675-
// operands
1676-
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;
1677-
16781679
SmallVector<Value> newOperands;
16791680
auto loc = op->getLoc();
16801681
for (auto operand : op->getOperands()) {
16811682
auto opVectorType = cast<VectorType>(operand.getType());
1682-
VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
1683+
auto newVType = dropNonScalableUnitDimFromType(opVectorType);
1684+
if (newVType == opVectorType)
1685+
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
1686+
16831687
auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
16841688
newOperands.push_back(opSC);
16851689
}
16861690

16871691
VectorType newResultVectorType =
1688-
VectorType::Builder(resultVectorType).dropDim(dim);
1689-
// Create an updated elementwise Op without leading/trailing unit dim
1692+
dropNonScalableUnitDimFromType(resultVectorType);
1693+
// Create an updated elementwise Op without unit dim.
16901694
Operation *elementwiseOp =
16911695
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
16921696
newResultVectorType, op->getAttrs());
16931697

1694-
// Restore the leading/trailing unit dim by applying vector.shape_cast
1695-
// to the result
1698+
// Restore the unit dim by applying vector.shape_cast to the result.
16961699
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
16971700
elementwiseOp->getResult(0));
16981701

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 36 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -499,6 +499,42 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
499499

500500
// -----
501501

502+
func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
503+
%arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
504+
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
505+
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
506+
%res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
507+
return %res : vector<8x3xf128>
508+
}
509+
510+
// CHECK-LABEL: func.func @fold_inner_unit_dim(
511+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
512+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
513+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
514+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
515+
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
516+
// CHECK: return %[[VAL_4]] : vector<8x3xf128>
517+
518+
// -----
519+
520+
func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
521+
%arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
522+
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
523+
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
524+
%res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
525+
return %res : vector<8x[1]x3xf128>
526+
}
527+
528+
// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
529+
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
530+
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
531+
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
532+
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
533+
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
534+
// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>
535+
536+
// -----
537+
502538
func.func @negative_out_of_bound_transfer_read(
503539
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
504540
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)