Skip to content

[mlir][ArmSME] Fold extracts from 3D create_masks of SME-like masks #80148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 2, 2024

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Jan 31, 2024

When unrolling the reduction dimension of something like a matmul for SME, it is possible to get 3D masks, which are vectors of SME-like masks. The 2D masks for individual operations are then extracted from the 3D masks.

i.e.:

%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2]
        : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>

ArmSME only supports lowering 2D create_masks, so we must fold the extract into the create_mask. This can be done by checking if the extraction index is within the true region, then using that select the first dimension of the 2D mask. This is shown below.

%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>

When unrolling the reduction dimension of something like a matmul for
SME, it is possible to get 3D masks, which are vectors of SME-like
masks. The 2D masks for individual operations are then extracted from
the 3D masks.

i.e.:

```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2]
        : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```

ArmSME only supports lowering 2D create_masks, so we must fold the
extract into the create_mask. This can be done by checking if the
extraction index is within the true region, then using that select the
first dimension of the 2D mask. This is shown below.

```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```
@llvmbot
Copy link
Member

llvmbot commented Jan 31, 2024

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

When unrolling the reduction dimension of something like a matmul for SME, it is possible to get 3D masks, which are vectors of SME-like masks. The 2D masks for individual operations are then extracted from the 3D masks.

i.e.:

%mask = vector.create_mask %nonConstantDim, %a, %b : vector&lt;4x[4]x[4]xi1&gt;
%subMask = vector.extract %mask[2]
        : vector&lt;[4]x[4]xi1&gt; from vector&lt;4x[4]x[4]xi1&gt;

ArmSME only supports lowering 2D create_masks, so we must fold the extract into the create_mask. This can be done by checking if the extraction index is within the true region, then using that select the first dimension of the 2D mask. This is shown below.

%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector&lt;[4]x[4]xi1&gt;

Full diff: https://github.com/llvm/llvm-project/pull/80148.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+80-3)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+36)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 85ec53c2618aa..14b9d8e34da65 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -7,13 +7,12 @@
 //===----------------------------------------------------------------------===//
 //
 // This pass legalizes vector operations so they can be lowered to ArmSME.
-// Currently, this only implements the decomposition of vector operations that
-// use vector sizes larger than an SME tile, into multiple SME-sized operations.
 //
 // Note: In the context of this pass 'tile' always refers to an SME tile.
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// Decomposition of vector operations larger than an SME tile
+//===----------------------------------------------------------------------===//
+
 // Common match failure reasons.
 static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
     "op vector size is not multiple of SME tiles");
@@ -338,13 +341,86 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
+/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
+/// necessary for the mask to be lowered to ArmSME.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
+///  %subMask = vector.extract %mask[2]
+///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
+///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
+///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
+///  ```
+struct FoldExtractFromVectorOfSMELikeCreateMasks
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = extractOp.getLoc();
+    auto createMaskOp =
+        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          extractOp, "extract not from vector.create_mask op");
+
+    VectorType extractedMaskType =
+        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+    if (!extractedMaskType)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "extracted type is not a vector type");
+
+    auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true);
+    if (numScalable != 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "expected extracted type to be an SME-like mask");
+
+    // TODO: Support multiple extraction indices.
+    if (extractOp.getStaticPosition().size() != 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "only a single extraction index is supported");
+
+    auto frontMaskDim = createMaskOp.getOperand(0);
+    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
+      return rewriter.notifyMatchFailure(
+          extractOp,
+          "constant vector.create_masks dims should be folded elsewhere");
+
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto extractionIndex = getValueOrCreateConstantIndexOp(
+        rewriter, loc, extractOp.getMixedPosition()[0]);
+    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
+        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
+        frontMaskDim);
+    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
+        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
+
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        extractOp, extractedMaskType,
+        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
     auto *context = &getContext();
     OneToNTypeConverter converter;
     RewritePatternSet patterns(context);
-
     converter.addConversion([](Type type) { return type; });
     converter.addConversion(
         [](VectorType vectorType,
@@ -358,6 +434,7 @@ struct VectorLegalizationPass
           return success();
         });
 
+    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks>(context);
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a20abeefedcfd..a2526db9b4831 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -266,3 +266,39 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
   vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
   return
 }
+
+// -----
+
+// CHECK-LABEL: @extract_from_vector_create_mask_non_constant_dim(
+// CHECK-SAME:                                                    %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                    %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                    %[[DIM2:[a-z0-9]+]]: index)
+func.func @extract_from_vector_create_mask_non_constant_dim(%dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi sgt, %[[DIM0]], %[[C2]] : index
+  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
+  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
+  // CHECK-NEXT: return %[[EXTRACT]]
+  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
+  %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @non_constant_extract_from_vector_create_mask_non_constant(
+// CHECK-SAME:                                                             %[[INDEX:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                             %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                             %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME:                                                             %[[DIM2:[a-z0-9]+]]: index)
+func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: index, %dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi slt, %[[INDEX]], %[[DIM0]] : index
+  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
+  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
+  // CHECK-NEXT: return %[[EXTRACT]]
+  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
+  %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Functionally the rewrite looks good to me, but I am wondering if it should be moved alongside Vector dialect transforms. I know this is only relevant for ArmSME right now but semantically these operations and types are expressible in the Vector dialect and if we pretended there were another target where 2-D scalable vector types were relevant then this would apply equally as well I think?

@MacDue
Copy link
Member Author

MacDue commented Jan 31, 2024

Functionally the rewrite looks good to me, but I am wondering if it should be moved alongside Vector dialect transforms. I know this is only relevant for ArmSME right now but semantically these operations and types are expressible in the Vector dialect and if we pretended there were another target where 2-D scalable vector types were relevant then this would apply equally as well I think?

I don't want to muddy the vector dialect with rewrites like this, -arm-sme-vector-legalization was always intended to include rewrites like this (it purely works at the vector dialect level already). The reasons for this rewrite are tied to SME and how the current lowering decomposes masks, and I feel it's unlikely to become generally useful.

It also makes things simpler to have everything in one place when it comes to integrating stuff into IREE.

@c-rhodes
Copy link
Collaborator

c-rhodes commented Feb 1, 2024

Functionally the rewrite looks good to me, but I am wondering if it should be moved alongside Vector dialect transforms. I know this is only relevant for ArmSME right now but semantically these operations and types are expressible in the Vector dialect and if we pretended there were another target where 2-D scalable vector types were relevant then this would apply equally as well I think?

I don't want to muddy the vector dialect with rewrites like this, -arm-sme-vector-legalization was always intended to include rewrites like this (it purely works at the vector dialect level already). The reasons for this rewrite are tied to SME and how the current lowering decomposes masks, and I feel it's unlikely to become generally useful.

I thought the same for code I specifically wrote for SVE in the past but it turned out to not be the case :)

But I do understand it's not obvious this will be generally useful, and probably more so than previous examples I'm referring to. There is a ton of common infrastructure we're benefiting from here and it's good to pay it forward where we can, but finding that line isn't always easy.

It also makes things simpler to have everything in one place when it comes to integrating stuff into IREE.

Hadn't considered IREE integration.

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, cheers

@MacDue MacDue merged commit c2dea71 into llvm:main Feb 2, 2024
@MacDue MacDue deleted the create_mask_sme_fold branch February 2, 2024 10:06
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
…lvm#80148)

When unrolling the reduction dimension of something like a matmul for
SME, it is possible to get 3D masks, which are vectors of SME-like
masks. The 2D masks for individual operations are then extracted from
the 3D masks.

i.e.:

```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2]
        : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```

ArmSME only supports lowering 2D create_masks, so we must fold the
extract into the create_mask. This can be done by checking if the
extraction index is within the true region, then using that select the
first dimension of the 2D mask. This is shown below.

```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```
MacDue added a commit to MacDue/llvm-project that referenced this pull request Feb 8, 2024
…mlir

This tests both llvm#80148 and llvm#80170 work together to allow unrolling the
reduction dimension of a matmul.
MacDue added a commit that referenced this pull request Feb 12, 2024
…mlir (#81160)

This tests both #80148 and #80170 work together to allow unrolling the
reduction dimension of a matmul.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants