Skip to content

[mlir][ArmSME][NFC] Check early for unsupported mask ops #135955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

This is to avoid rollbacks in the dialect conversion, which are expensive.

Note: This is in preparation of the One-Shot Dialect Conversion refactoring.

This is to avoid rollbacks in the dialect conversion, which are expensive.

Note: This is in preparation of the One-Shot Dialect Conversion refactoring.
@llvmbot
Copy link
Member

llvmbot commented Apr 16, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

This is to avoid rollbacks in the dialect conversion, which are expensive.

Note: This is in preparation of the One-Shot Dialect Conversion refactoring.


Full diff: https://github.com/llvm/llvm-project/pull/135955.diff

1 Files Affected:

  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+12-5)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 6ed29903ea407..630414030d98b 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -77,11 +77,6 @@ FailureOr<scf::ForOp> createLoadStoreForOverTileSlices(
   Value upperBound;
   if (mask) {
     auto createMaskOp = mask.getDefiningOp<vector::CreateMaskOp>();
-    if (!createMaskOp)
-      return rewriter.notifyMatchFailure(
-          loc, "unsupported mask op, only 'vector.create_mask' is "
-               "currently supported");
-
     auto maskDim0 = createMaskOp.getOperands()[0];
     auto maskDim1 = createMaskOp.getOperands()[1];
 
@@ -184,6 +179,10 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
 
     Value initTile;
     if (mask) {
+      if (!mask.getDefiningOp<vector::CreateMaskOp>())
+        return rewriter.notifyMatchFailure(
+            loc, "unsupported mask op, only 'vector.create_mask' is "
+                 "currently supported");
       auto padOp = tileLoadOp.getPadding();
       assert(padOp && "expected padding when masking!");
 
@@ -373,6 +372,14 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
 
   LogicalResult matchAndRewrite(arm_sme::TileStoreOp tileStoreOp,
                                 PatternRewriter &rewriter) const override {
+    if (Value mask = tileStoreOp.getMask()) {
+      if (!mask.getDefiningOp<vector::CreateMaskOp>())
+        return rewriter.notifyMatchFailure(
+            tileStoreOp.getLoc(),
+            "unsupported mask op, only 'vector.create_mask' is "
+            "currently supported");
+    }
+
     // Create a loop that stores each active ZA tile slice from memory.
     return createLoadStoreForOverTileSlices(
         rewriter, tileStoreOp.getLoc(), tileStoreOp.getVectorType(),

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM!

@matthias-springer matthias-springer merged commit 1a09ffe into main Apr 17, 2025
13 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/arm_tile_no_rollback branch April 17, 2025 06:48
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
This is to avoid rollbacks in the dialect conversion, which are
expensive.

Note: This is in preparation of the One-Shot Dialect Conversion
refactoring.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants