Skip to content

[vector][mlir] Canonicalize to shape_cast where possible #140583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 129 additions & 63 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1691,10 +1691,36 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}

/// vector.splat, and vector.shape_cast that just prepends 1's are
/// special cases of vector.broadcast. This function returns true
/// if \p op is one of these operations.
static bool isBroadcastLike(Operation *op) {

if (isa<vector::BroadcastOp, SplatOp>(op))
return true;

// a shape_cast which just prepends 1's is broadcast-like.
auto shapeCast = dyn_cast<vector::ShapeCastOp>(op);
if (!shapeCast)
return false;

ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
ArrayRef<int64_t> srcShape = shapeCast.getSourceVectorType().getShape();

// A rank-reducing shape_cast cannot be broadcast-like.
if (srcShape.size() > dstShape.size())
return false;

bool isSuffix = (srcShape == dstShape.take_back(srcShape.size()));
return isSuffix;
}

/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
static Value foldExtractFromBroadcastLike(ExtractOp extractOp) {

Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))

if (!defOp || !isBroadcastLike(defOp))
return Value();

Value source = defOp->getOperand(0);
Expand All @@ -1721,14 +1747,22 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
broadcastVecType.getShape().take_back(extractResultRank))
return Value();

auto broadcastOp = cast<vector::BroadcastOp>(defOp);
int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank();
assert(defOp->getNumResults() == 1 && "all broadcast-like ops have 1 result");
auto dstType = dyn_cast<VectorType>(defOp->getResult(0).getType());
assert(dstType && "all broadcast-like ops have vector results");

int64_t broadcastDstRank = dstType.getRank();

// Detect all the positions that come from "dim-1" broadcasting.
// These dimensions correspond to "dim-1" broadcasted dims; set the mathching
// These dimensions correspond to "dim-1" broadcasted dims; set the matching
// extract position to `0` when extracting from the source operand.
llvm::SetVector<int64_t> broadcastedUnitDims =
broadcastOp.computeBroadcastedUnitDims();
auto broadcastedUnitDims = [&]() -> llvm::SetVector<int64_t> {
if (auto broadcastOp = dyn_cast<BroadcastOp>(defOp)) {
return broadcastOp.computeBroadcastedUnitDims();
}
return {};
}();

SmallVector<OpFoldResult> extractPos(extractOp.getMixedPosition());
OpBuilder b(extractOp.getContext());
int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank;
Expand Down Expand Up @@ -2163,7 +2197,7 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return getResult();
if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
return res;
if (auto res = foldExtractFromBroadcast(*this))
if (auto res = foldExtractFromBroadcastLike(*this))
return res;
if (auto res = foldExtractFromShuffle(*this))
return res;
Expand All @@ -2181,15 +2215,16 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {

namespace {

// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
// Pattern to rewrite a ExtractOp(broadcast-like) -> Broadcast.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
Operation *defOp = extractOp.getVector().getDefiningOp();
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))

if (!defOp || !isBroadcastLike(defOp))
return failure();

Value source = defOp->getOperand(0);
Expand Down Expand Up @@ -2351,11 +2386,41 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
return success();
}

/// BEFORE:
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
/// AFTER:
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
VectorType sourceType = extractOp.getSourceVectorType();
VectorType outType = dyn_cast<VectorType>(extractOp.getType());
if (!outType)
return failure();

// Negative values in `position` indicates poison, which cannot be
// represented with a shape_cast
if (llvm::any_of(extractOp.getMixedPosition(),
[](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
return failure();

if (sourceType.getNumElements() != outType.getNumElements())
return failure();

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, outType,
extractOp.getVector());
return success();
}
};

} // namespace

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context);
results
.add<ExtractOpFromBroadcast, ExtractOpFromCreateMask, ExtractToShapeCast>(
context);
results.add(foldExtractFromShapeCastToShapeCast);
results.add(foldExtractFromFromElements);
}
Expand Down Expand Up @@ -2867,13 +2932,36 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
return success();
}
};

/// BEFORE:
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
/// AFTER:
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
struct BroadcastToShapeCast final
: public OpRewritePattern<vector::BroadcastOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::BroadcastOp broadcast,
PatternRewriter &rewriter) const override {
auto sourceType = dyn_cast<VectorType>(broadcast.getSourceType());
if (!sourceType) {
return rewriter.notifyMatchFailure(
broadcast, "source is a scalar, shape_cast doesn't support scalar");
}

VectorType outType = broadcast.getType();
if (sourceType.getNumElements() != outType.getNumElements())
return failure();

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcast, outType,
broadcast.getSource());
return success();
}
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// BroadcastToShapeCast is not a default canonicalization, it is opt-in by
// calling `populateCastAwayVectorLeadingOneDimPatterns`
results.add<BroadcastFolder>(context);
results.add<BroadcastFolder, BroadcastToShapeCast>(context);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -5991,10 +6079,7 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};

/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as either
/// i) Y = ShapeCast(X), or
/// ii) Y = Broadcast(X)
/// If both (i) and (ii) are possible, (i) is chosen.
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
Expand All @@ -6009,22 +6094,6 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
auto srcVectorType = dyn_cast<VectorType>(broadcastOp.getSourceType());
bool srcIsScalar = !srcVectorType;

// Replace Y = ShapeCast(Broadcast(X)) with Y = ShapeCast(X).
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: I've removed this, as now it happens in 2 steps during canonicalization. The first converts the Broadcast to a ShapeCast. The second combines the 2 ShapeCasts.

// Example:
// %0 = vector.broadcast %in : vector<3x4xf32> to vector<1x3x4xf32>
// %1 = vector.shape_cast %0 : vector<1x3x4xf32> to vector<12xf32>
// to
// %1 = vector.shape_cast %in : vector<3x4xf32> to vector<12xf32>
if (srcVectorType) {
if (srcVectorType.getNumElements() ==
shapeCastOp.getResultVectorType().getNumElements()) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
shapeCastOp, shapeCastOp.getResultVectorType(),
broadcastOp.getSource());
return success();
}
}

// Replace Y = ShapeCast(Broadcast(X)) with Y = Broadcast(X)
// Example
// %0 = vector.broadcast %in : vector<3xf32> to vector<2x4x3xf32>
Expand Down Expand Up @@ -6233,7 +6302,7 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
//
// Example of what NOT to fold:
// Example of what not to fold:
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
//
if (getSourceVectorType() == getResultVectorType() &&
Expand Down Expand Up @@ -6359,32 +6428,6 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
}
};

/// Folds transpose(shape_cast) into a new shape_cast.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: I've removed this, as it now happens in 2 steps during canonicalization. The first (new) step is to rewrite the transpose as a shape_cast. The second step is to fold shape_cast(shape_cast) to shape_cast.

class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
auto shapeCastOp =
transposeOp.getVector().getDefiningOp<vector::ShapeCastOp>();
if (!shapeCastOp)
return failure();
if (!isOrderPreserving(transposeOp))
return failure();

VectorType resultType = transposeOp.getType();

// We don't need to check isValidShapeCast at this point, because it is
// guaranteed that merging the transpose into the the shape_cast is a valid
// shape_cast, because the transpose just inserts/removes ones.

rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transposeOp, resultType,
shapeCastOp.getSource());
return success();
}
};

/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
Expand Down Expand Up @@ -6480,12 +6523,35 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
}
};

/// BEFORE:
/// %0 = vector.transpose %arg0, [0, 2, 1] :
/// vector<2x1x2xf32> to vector<2x2x1xf32>
/// AFTER:
/// %0 = vector.shape_cast %arg0 :
/// vector<2x1x2xf32> to vector<2x2x1xf32>
struct TransposeToShapeCast final
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransposeOp transpose,
PatternRewriter &rewriter) const override {

if (!isOrderPreserving(transpose)) {
return rewriter.notifyMatchFailure(
transpose, "not order preserving, so not semantically a 'copy'");
}
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
transpose, transpose.getType(), transpose.getVector());
return success();
}
};

} // namespace

void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
FoldTransposeSplat, FoldTransposeBroadcast>(context);
results.add<FoldTransposeBroadcast, FoldTransposeCreateMask,
FoldTransposeSplat, TransposeFolder, TransposeToShapeCast>(
context);
}

//===----------------------------------------------------------------------===//
Expand Down
63 changes: 2 additions & 61 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
Expand Down Expand Up @@ -382,64 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
vector::VectorTransposeLowering vectorTransposeLowering;
};

/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Author note: I've removed this pattern, as it is a special case of TransposeToShapeCast

/// to 2D vectors with at least one unit dim. For example:
///
/// Replace:
/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
/// vector<1x4xi32>
/// with:
/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
///
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
/// be fixed. Non-unit dim can be scalable.
///
/// TODO: This pattern was introduced specifically to help lower scalable
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
/// to cancel out) would be preferable:
///
/// BEFORE:
/// %0 = some_op
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
/// AFTER:
/// %0 = some_op
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
///
/// Given the context above, we may want to consider (re-)moving this pattern
/// at some later time. I am leaving it for now in case there are other users
/// that I am not aware of.
class Transpose2DWithUnitDimToShapeCast
: public OpRewritePattern<vector::TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;

Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}

LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
Value input = op.getVector();
VectorType resType = op.getResultVectorType();

// Set up convenience transposition table.
ArrayRef<int64_t> transp = op.getPermutation();

if (resType.getRank() == 2 &&
((resType.getShape().front() == 1 &&
!resType.getScalableDims().front()) ||
(resType.getShape().back() == 1 &&
!resType.getScalableDims().back())) &&
transp == ArrayRef<int64_t>({1, 0})) {
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
return success();
}

return failure();
}
};

/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
/// If the strategy is Shuffle1D, it will be lowered to:
/// vector.shape_cast 2D -> 1D
Expand Down Expand Up @@ -511,8 +452,8 @@ class TransposeOp2DToShuffleLowering
void mlir::vector::populateVectorTransposeLoweringPatterns(
RewritePatternSet &patterns,
VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
benefit);
TransposeOp::getCanonicalizationPatterns(patterns, patterns.getContext());
ShapeCastOp::getCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
vectorTransposeLowering, patterns.getContext(), benefit);
}
8 changes: 4 additions & 4 deletions mlir/test/Dialect/ArmSME/vector-legalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i

// -----

// The pass should do nothing (and not crash).
// CHECK-LABEL: @illegal_transpose_no_defining_source_op
func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
// CHECK-LABEL: @transpose_no_defining_source_op
func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
{
// CHECK: vector.transpose
// CHECK: vector.shape_cast
// CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space I'm getting back to this PR. Peephole question: is this operation ok? i.e. is

vector.shape_cast %a vector<[4]x1xf32> to vector<1x[4]xf32>

an acceptable operation to have after running mlir-opt -arm-sme-vector-legalization -cse -canonicalize ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, yes. But I can't guarantee there's no logic that expects vector<[4]x1xf32> instead of vector<1x[4]xf32> ;-) If that's the case, we will fix it and I will be grateful for uncovering this :)

%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}
Expand Down
Loading