
[mlir][ArmSME] Lower extract from 2D scalable create_mask to psel #96066


Merged: MacDue merged 3 commits into llvm:main from lower_to_psel on Jun 20, 2024

Conversation

MacDue
Member

@MacDue MacDue commented Jun 19, 2024

Example:

%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
%slice = vector.extract %mask[%index]
           : vector<[8]xi1> from vector<[4]x[8]xi1>

Becomes:

%mask_rows = vector.create_mask %a : vector<[4]xi1>
%mask_cols = vector.create_mask %b : vector<[8]xi1>
%slice = arm_sve.psel %mask_cols, %mask_rows[%index]
           : vector<[8]xi1>, vector<[4]xi1>

Note: While psel is under ArmSVE it requires SME (or SVE 2.1), so this is currently the most logical place for this lowering.
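
The rewrite above relies on row `%index` of the 2-D mask being either the column mask (when `%index < %a`) or all-false. The following is a minimal sketch of that equivalence in plain C++, not MLIR code. It assumes the scalable lengths are already fixed to concrete values, that the mask operands are non-negative, and that `arm_sve.psel %p1, %p2[%index]` yields `%p1` when lane `%index` of `%p2` is set and an all-false predicate otherwise. The names `Mask`, `createMask1D`, `extractRowOf2DMask`, `psel`, and `loweringAgrees` are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain C++ model; std::vector<bool> stands in for an i1 vector whose
// scalable length has already been fixed to a concrete value (assumption).
using Mask = std::vector<bool>;

// Models 1-D `vector.create_mask %n`: lane i is set iff i < n.
Mask createMask1D(std::size_t n, std::size_t size) {
  Mask m(size);
  for (std::size_t i = 0; i < size; ++i)
    m[i] = i < n;
  return m;
}

// Models extracting row `index` from 2-D `vector.create_mask %a, %b`,
// where element (i, j) is set iff i < a and j < b.
Mask extractRowOf2DMask(std::size_t a, std::size_t b, std::size_t cols,
                        std::size_t index) {
  Mask row(cols);
  for (std::size_t j = 0; j < cols; ++j)
    row[j] = index < a && j < b;
  return row;
}

// Models `arm_sve.psel %p1, %p2[%index]` (assumed semantics): p1 if lane
// `index` of p2 is set, otherwise an all-false predicate of p1's length.
Mask psel(const Mask &p1, const Mask &p2, std::size_t index) {
  assert(index < p2.size() && "index must select a lane of p2");
  if (p2[index])
    return p1;
  return Mask(p1.size(), false);
}

// Returns true when the original and lowered forms agree for every row.
bool loweringAgrees(std::size_t a, std::size_t b, std::size_t rows,
                    std::size_t cols) {
  Mask rowMask = createMask1D(a, rows);
  Mask colMask = createMask1D(b, cols);
  for (std::size_t index = 0; index < rows; ++index)
    if (extractRowOf2DMask(a, b, cols, index) != psel(colMask, rowMask, index))
      return false;
  return true;
}
```

Calling `loweringAgrees(3, 5, 4, 8)` compares both forms for a 4x8 instance of the `vector<[4]x[8]xi1>` example with `%a = 3` and `%b = 5`.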

@llvmbot
Member

llvmbot commented Jun 19, 2024

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

Example:

%mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
%slice = vector.extract %mask[%index]
                                  : vector<[8]xi1> from vector<[4]x[8]xi1>

Becomes:

%mask_rows = vector.create_mask %a : vector<[4]xi1>
%mask_cols = vector.create_mask %b : vector<[8]xi1>
%slice = arm_sve.psel %mask_cols, %mask_rows[%index]
                                   : vector<[8]xi1>, vector<[4]xi1>

Note: While psel is under ArmSVE it requires SME (or SVE 2.1), so this is currently the most logical place for this lowering.


Full diff: https://github.com/llvm/llvm-project/pull/96066.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+1-1)
  • (modified) mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+75-3)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp (+1)
  • (modified) mlir/test/Conversion/VectorToArmSME/unsupported.mlir (+51)
  • (modified) mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir (+32)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index db67d6a5ff128..9ab5faf9559a3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1276,7 +1276,7 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
     Pass that converts vector dialect operations into equivalent ArmSME dialect
     operations.
   }];
-  let dependentDialects = ["arm_sme::ArmSMEDialect"];
+  let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
index b062f65e914e8..6a81a09776d37 100644
--- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME
 
   LINK_LIBS PUBLIC
   MLIRArmSMEDialect
+  MLIRArmSVEDialect
   MLIRLLVMCommonConversion
   )
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 56ae46a6098ee..0e8575531d9b0 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -10,6 +10,7 @@
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Casting.h"
@@ -549,6 +550,77 @@ struct VectorExtractToArmSMELowering
   }
 };
 
+/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
+/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
+/// SVE 2.1), so this is currently the most logical place for this lowering.
+///
+/// Example:
+/// ```mlir
+/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+/// %slice = vector.extract %mask[%index]
+///                                   : vector<[8]xi1> from vector<[4]x[8]xi1>
+/// ```
+/// Becomes:
+/// ```
+/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
+/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
+/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
+///                                    : vector<[8]xi1>, vector<[4]xi1>
+/// ```
+struct VectorExtractFromMaskToPselLowering
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    if (extractOp.getNumIndices() != 1)
+      return rewriter.notifyMatchFailure(extractOp, "not single extract index");
+
+    auto resultType = extractOp.getResult().getType();
+    auto resultVectorType = dyn_cast<VectorType>(resultType);
+    if (!resultVectorType)
+      return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
+
+    auto createMaskOp =
+        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
+
+    auto maskType = createMaskOp.getVectorType();
+    if (maskType.getRank() != 2 || !maskType.allDimsScalable())
+      return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
+
+    auto isSVEPredicateSize = [](int64_t size) {
+      return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
+    };
+
+    auto rowsBaseSize = maskType.getDimSize(0);
+    auto colsBaseSize = maskType.getDimSize(1);
+    if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "mask dimensions not SVE predicate-sized");
+
+    auto loc = extractOp.getLoc();
+    VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
+    VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
+
+    // Create the two 1-D masks at the location of the 2-D create_mask (which is
+    // usually outside a loop). This prevents the need for later hoisting.
+    rewriter.setInsertionPoint(createMaskOp);
+    auto rowMask = rewriter.create<vector::CreateMaskOp>(
+        loc, rowMaskType, createMaskOp.getOperand(0));
+    auto colMask = rewriter.create<vector::CreateMaskOp>(
+        loc, colMaskType, createMaskOp.getOperand(1));
+
+    rewriter.setInsertionPoint(extractOp);
+    auto position =
+        vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
+    rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
+                                                 position[0]);
+    return success();
+  }
+};
+
 /// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
 /// `arm_sme.move_tile_slice_to_vector`.
 ///
@@ -728,7 +800,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
            TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
            TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
            VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
-           VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
-           VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
-          &ctx);
+           VectorExtractToArmSMELowering, VectorExtractFromMaskToPselLowering,
+           VectorInsertToArmSMELowering, VectorPrintToArmSMELowering,
+           FoldTransferWriteOfExtractTileSlice>(&ctx);
 }
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
index 2601f31be11a3..cc00bf4ca190a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 8ed52cde784ce..ff7b4bcb5f65a 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -192,3 +192,54 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
+
+// -----
+
+/// Not SVE predicate-sized.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_0
+func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1>
+{
+  // CHECK-NOT: arm_sve.psel
+  %mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
+  %slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
+  return %slice : vector<[32]xi1>
+}
+
+// -----
+
+/// Source not 2-D scalable mask.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_1
+func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1>
+{
+  // CHECK-NOT: arm_sve.psel
+  %mask = vector.create_mask %a, %b : vector<4x[8]xi1>
+  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
+  return %slice : vector<[8]xi1>
+}
+
+// -----
+
+/// Source not vector.create_mask.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_2
+func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1>
+{
+  // CHECK-NOT: arm_sve.psel
+  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+  return %slice : vector<[8]xi1>
+}
+
+// -----
+
+/// Not psel-like extract.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_3
+func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1
+{
+  // CHECK-NOT: arm_sve.psel
+  %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+  %el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1>
+  return %el : i1
+}
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 8aeffb066de90..ff21c70b2aa55 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -1320,3 +1320,35 @@ func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
   %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
   return %el : f64
 }
+
+// -----
+
+// CHECK-LABEL: @dynamic_vector_extract_mask_to_psel(
+// CHECK-SAME:                                       %[[A:[a-z0-9]+]]:  index,
+// CHECK-SAME:                                       %[[B:[a-z0-9]+]]: index,
+// CHECK-SAME:                                       %[[INDEX:[a-z0-9]+]]: index)
+func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: index) -> vector<[8]xi1>
+{
+  // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[4]xi1>
+  // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1>
+  // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1>
+  %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+  %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+  return %slice : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_mask_to_psel(
+// CHECK-SAME:                               %[[A:[a-z0-9]+]]:  index,
+// CHECK-SAME:                               %[[B:[a-z0-9]+]]: index)
+func.func @vector_extract_mask_to_psel(%a: index, %b: index) -> vector<[2]xi1>
+{
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[16]xi1>
+  // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[2]xi1>
+  // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[C1]]] : vector<[2]xi1>, vector<[16]xi1>
+  %mask = vector.create_mask %a, %b : vector<[16]x[2]xi1>
+  %slice = vector.extract %mask[1] : vector<[2]xi1> from vector<[16]x[2]xi1>
+  return %slice : vector<[2]xi1>
+}
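
The match conditions in `VectorExtractFromMaskToPselLowering` and the positive and negative tests above can be summarised in a small standalone model. This is a sketch only: `MaskShape` and `matchesPselLowering` are hypothetical stand-ins that do not use any MLIR types, and the model skips the checks that the source is a `vector.create_mask` and that the result is a `VectorType`. `isSVEPredicateSize` follows the lambda in the patch, with a bit trick in place of `llvm::isPowerOf2_32`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the parts of the mask's vector type that the
// pattern inspects: base dimension sizes and per-dimension scalability.
struct MaskShape {
  std::vector<int64_t> dims;
  std::vector<bool> scalable;
};

// Same condition as the lambda in the patch: a power of two in [1, 16].
bool isSVEPredicateSize(int64_t size) {
  return size > 0 && size <= 16 && (size & (size - 1)) == 0;
}

// Mirrors the shape-related match conditions of the pattern:
// one extract index, a rank-2 mask, all dimensions scalable, and both
// base sizes SVE predicate-sized.
bool matchesPselLowering(const MaskShape &mask, unsigned numIndices) {
  assert(mask.dims.size() == mask.scalable.size());
  if (numIndices != 1)
    return false;
  if (mask.dims.size() != 2)
    return false;
  for (bool isScalable : mask.scalable)
    if (!isScalable)
      return false;
  return isSVEPredicateSize(mask.dims[0]) && isSVEPredicateSize(mask.dims[1]);
}
```

Under this model, `vector<[4]x[8]xi1>` and `vector<[16]x[2]xi1>` with a single index match, while `vector<[4]x[32]xi1>`, `vector<4x[8]xi1>`, and the two-index extract `%mask[2, %index]` do not, which is what the tests in the diff expect.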

Collaborator

@c-rhodes c-rhodes left a comment


A few minor nits; otherwise LGTM, cheers.

Contributor

@banach-space banach-space left a comment


LGTM, modulo a couple of small requests. Thanks!

@MacDue MacDue merged commit e2296d8 into llvm:main Jun 20, 2024
5 of 6 checks passed
@MacDue MacDue deleted the lower_to_psel branch June 20, 2024 09:27
MacDue added a commit to MacDue/llvm-project that referenced this pull request Jul 3, 2024
…vm#96066)

AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…vm#96066)
