-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Migrate `arm-sme-vector-legalization` to dialect conversion
#121101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSME] Migrate `arm-sme-vector-legalization` to dialect conversion
#121101
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme Author: Matthias Springer (matthias-springer) Changes: Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon. Full diff: https://github.com/llvm/llvm-project/pull/121101.diff 1 File Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 61767f3b21c9c3..12c65a72babcb8 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -17,7 +17,7 @@
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -25,7 +25,8 @@
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "arm-sme-vector-legalization"
@@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
/// Legalize `arith.constant dense<value>` splat operations to fit within SME
/// tiles by decomposing them into tile-sized operations.
struct LegalizeArithConstantOpsByDecomposition
- : public OneToNOpConversionPattern<arith::ConstantOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = dyn_cast<VectorType>(constantOp.getType());
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
if (!vectorType || !denseAttr || !denseAttr.isSplat())
@@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition
auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
auto tileSplat = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
- rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
- adaptor.getResultMapping());
+ SmallVector<Value> repl(tileCount, tileSplat);
+ rewriter.replaceOpWithMultiple(constantOp, {repl});
return success();
}
@@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
- : public OneToNOpConversionPattern<vector::OuterProductOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::OuterProductOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::OuterProductOp outerProductOp,
+ OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = outerProductOp.getResultVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(outerProductOp,
@@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
auto maskOp = outerProductOp.getMaskingOp();
mask = maskOp.getMask();
rootOp = maskOp;
+ rewriter.setInsertionPoint(rootOp);
}
if (!isSupportedMaskOp(mask))
@@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
resultSMETiles.push_back(maskedOuterProduct->getResult(0));
}
- rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+ rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
return success();
}
};
@@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
// (invalid). This pattern matches on `vector.mask` then calls into the
// `vector.outerproduct` pattern to work around this issue.
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
- : public OneToNOpConversionPattern<vector::MaskOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::MaskOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
maskOp.getMaskableOp())) {
LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
@@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
/// Legalize `vector.transfer_read` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferReadOpsByDecomposition
- : public OneToNOpConversionPattern<vector::TransferReadOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::TransferReadOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = readOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(readOp,
@@ -319,7 +322,7 @@ struct LegalizeTransferReadOpsByDecomposition
resultSMETiles.push_back(smeRead);
}
- rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+ rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
return success();
}
};
@@ -327,12 +330,12 @@ struct LegalizeTransferReadOpsByDecomposition
/// Legalize `vector.transfer_write` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeTransferWriteOpsByDecomposition
- : public OneToNOpConversionPattern<vector::TransferWriteOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::TransferWriteOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
auto vectorType = writeOp.getVectorType();
if (!isMultipleOfSMETileVectorType(vectorType))
return rewriter.notifyMatchFailure(writeOp,
@@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition
/// }
/// ```
struct LegalizeMultiTileTransferWriteAsStoreLoop
- : public OneToNOpConversionPattern<vector::TransferWriteOp> {
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
+ : public OpConversionPattern<vector::TransferWriteOp> {
+ using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
- OneToNPatternRewriter &rewriter) const override {
+ matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
if (writeOp.hasPureTensorSemantics())
return rewriter.notifyMatchFailure(
writeOp, "TODO: tensor semantics are unsupported");
@@ -936,10 +939,16 @@ struct VectorLegalizationPass
return success();
});
- patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
- LiftIllegalVectorTransposeToMemory,
- ConvertIllegalShapeCastOpsToTransposes,
- LowerIllegalTransposeStoreViaZA>(context);
+ // Apply preprocessing patterns.
+ RewritePatternSet rewritePatterns(context);
+ rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
+ LiftIllegalVectorTransposeToMemory,
+ ConvertIllegalShapeCastOpsToTransposes,
+ LowerIllegalTransposeStoreViaZA>(context);
+ if (failed(
+ applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
+ return signalPassFailure();
+
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
@@ -950,11 +959,20 @@ struct VectorLegalizationPass
LegalizeVectorOuterProductOpsByDecomposition,
LegalizeTransferReadOpsByDecomposition,
LegalizeTransferWriteOpsByDecomposition>(converter, context);
- populateFuncTypeConversionPatterns(converter, patterns);
- scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
-
- if (failed(applyPartialOneToNConversion(getOperation(), converter,
- std::move(patterns))))
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
+ populateReturnOpTypeConversionPattern(patterns, converter);
+ scf::populateSCFStructuralTypeConversions(converter, patterns);
+
+ ConversionTarget target(getContext());
+ target.markUnknownOpDynamicallyLegal(
+ [&](Operation *op) { return converter.isLegal(op); });
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return converter.isSignatureLegal(op.getFunctionType());
+ });
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
return signalPassFailure();
}
};
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, massive thanks for pushing on this!
Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon.