Skip to content

[mlir][ArmSME] Migrate arm-sme-vector-legalization to dialect conversion #121101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 31, 2024

Conversation

matthias-springer
Copy link
Member

Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon.

@llvmbot
Copy link
Member

llvmbot commented Dec 25, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Matthias Springer (matthias-springer)

Changes

Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon.


Full diff: https://github.com/llvm/llvm-project/pull/121101.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+56-38)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 61767f3b21c9c3..12c65a72babcb8 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -17,7 +17,7 @@
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
 #include "mlir/Dialect/Index/IR/IndexOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -25,7 +25,8 @@
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/Transforms/OneToNTypeConversion.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "arm-sme-vector-legalization"
 
@@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
 /// Legalize `arith.constant dense<value>` splat operations to fit within SME
 /// tiles by decomposing them into tile-sized operations.
 struct LegalizeArithConstantOpsByDecomposition
-    : public OneToNOpConversionPattern<arith::ConstantOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<arith::ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+                  ConversionPatternRewriter &rewriter) const override {
     auto vectorType = dyn_cast<VectorType>(constantOp.getType());
     auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
     if (!vectorType || !denseAttr || !denseAttr.isSplat())
@@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition
     auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
     auto tileSplat = rewriter.create<arith::ConstantOp>(
         constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
-    rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
-                       adaptor.getResultMapping());
+    SmallVector<Value> repl(tileCount, tileSplat);
+    rewriter.replaceOpWithMultiple(constantOp, {repl});
 
     return success();
   }
@@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition
 /// Legalize `vector.outerproduct` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeVectorOuterProductOpsByDecomposition
-    : public OneToNOpConversionPattern<vector::OuterProductOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::OuterProductOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::OuterProductOp outerProductOp,
+                  OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto vectorType = outerProductOp.getResultVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(outerProductOp,
@@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
       auto maskOp = outerProductOp.getMaskingOp();
       mask = maskOp.getMask();
       rootOp = maskOp;
+      rewriter.setInsertionPoint(rootOp);
     }
 
     if (!isSupportedMaskOp(mask))
@@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
       resultSMETiles.push_back(maskedOuterProduct->getResult(0));
     }
 
-    rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
+    rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
     return success();
   }
 };
@@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
 // (invalid). This pattern matches on `vector.mask` then calls into the
 // `vector.outerproduct` pattern to work around this issue.
 struct LegalizeMaskedVectorOuterProductOpsByDecomposition
-    : public OneToNOpConversionPattern<vector::MaskOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::MaskOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
             maskOp.getMaskableOp())) {
       LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
@@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
 /// Legalize `vector.transfer_read` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeTransferReadOpsByDecomposition
-    : public OneToNOpConversionPattern<vector::TransferReadOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::TransferReadOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto vectorType = readOp.getVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(readOp,
@@ -319,7 +322,7 @@ struct LegalizeTransferReadOpsByDecomposition
       resultSMETiles.push_back(smeRead);
     }
 
-    rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
+    rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
     return success();
   }
 };
@@ -327,12 +330,12 @@ struct LegalizeTransferReadOpsByDecomposition
 /// Legalize `vector.transfer_write` operations to fit within SME tiles by
 /// decomposing them into tile-sized operations.
 struct LegalizeTransferWriteOpsByDecomposition
-    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::TransferWriteOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto vectorType = writeOp.getVectorType();
     if (!isMultipleOfSMETileVectorType(vectorType))
       return rewriter.notifyMatchFailure(writeOp,
@@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition
 /// }
 /// ```
 struct LegalizeMultiTileTransferWriteAsStoreLoop
-    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+    : public OpConversionPattern<vector::TransferWriteOp> {
+  using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
+  matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     if (writeOp.hasPureTensorSemantics())
       return rewriter.notifyMatchFailure(
           writeOp, "TODO: tensor semantics are unsupported");
@@ -936,10 +939,16 @@ struct VectorLegalizationPass
           return success();
         });
 
-    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
-                 LiftIllegalVectorTransposeToMemory,
-                 ConvertIllegalShapeCastOpsToTransposes,
-                 LowerIllegalTransposeStoreViaZA>(context);
+    // Apply preprocessing patterns.
+    RewritePatternSet rewritePatterns(context);
+    rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
+                        LiftIllegalVectorTransposeToMemory,
+                        ConvertIllegalShapeCastOpsToTransposes,
+                        LowerIllegalTransposeStoreViaZA>(context);
+    if (failed(
+            applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
+      return signalPassFailure();
+
     // Note: These two patterns are added with a high benefit to ensure:
     //  - Masked outer products are handled before unmasked ones
     //  - Multi-tile writes are lowered as a store loop (if possible)
@@ -950,11 +959,20 @@ struct VectorLegalizationPass
                  LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
                  LegalizeTransferWriteOpsByDecomposition>(converter, context);
-    populateFuncTypeConversionPatterns(converter, patterns);
-    scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
-
-    if (failed(applyPartialOneToNConversion(getOperation(), converter,
-                                            std::move(patterns))))
+    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+                                                                   converter);
+    populateCallOpTypeConversionPattern(patterns, converter);
+    populateReturnOpTypeConversionPattern(patterns, converter);
+    scf::populateSCFStructuralTypeConversions(converter, patterns);
+
+    ConversionTarget target(getContext());
+    target.markUnknownOpDynamicallyLegal(
+        [&](Operation *op) { return converter.isLegal(op); });
+    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+      return converter.isSignatureLegal(op.getFunctionType());
+    });
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
       return signalPassFailure();
   }
 };

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, massive thanks for pushing on this!

@matthias-springer matthias-springer merged commit 31613de into main Dec 31, 2024
11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/arm_dialect_conv branch December 31, 2024 11:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants