-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Lower extract from 2D scalable create_mask to psel #96066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Example: ```mlir %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> ``` Becomes: ``` %mask_rows = vector.create_mask %a : vector<[4]xi1> %mask_cols = vector.create_mask %b : vector<[8]xi1> %slice = arm_sve.psel %mask_cols, %mask_rows[%index] : vector<[8]xi1>, vector<[4]xi1> ``` Note: While psel is under ArmSVE it requires SME (or SVE 2.1), so this is currently the most logical place for this lowering.
@llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesExample: %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
%slice = vector.extract %mask[%index]
: vector<[8]xi1> from vector<[4]x[8]xi1> Becomes: %mask_rows = vector.create_mask %a : vector<[4]xi1>
%mask_cols = vector.create_mask %b : vector<[8]xi1>
%slice = arm_sve.psel %mask_cols, %mask_rows[%index]
: vector<[8]xi1>, vector<[4]xi1> Note: While psel is under ArmSVE it requires SME (or SVE 2.1), so this is currently the most logical place for this lowering. Full diff: https://github.com/llvm/llvm-project/pull/96066.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index db67d6a5ff128..9ab5faf9559a3 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1276,7 +1276,7 @@ def ConvertVectorToArmSME : Pass<"convert-vector-to-arm-sme"> {
Pass that converts vector dialect operations into equivalent ArmSME dialect
operations.
}];
- let dependentDialects = ["arm_sme::ArmSMEDialect"];
+ let dependentDialects = ["arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
index b062f65e914e8..6a81a09776d37 100644
--- a/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToArmSME/CMakeLists.txt
@@ -10,5 +10,6 @@ add_mlir_conversion_library(MLIRVectorToArmSME
LINK_LIBS PUBLIC
MLIRArmSMEDialect
+ MLIRArmSVEDialect
MLIRLLVMCommonConversion
)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 56ae46a6098ee..0e8575531d9b0 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/Casting.h"
@@ -549,6 +550,77 @@ struct VectorExtractToArmSMELowering
}
};
+/// Lower a `vector.extract` from a 2-D scalable `vector.create_mask` to
+/// `arm_sve.psel`. Note: While psel is under ArmSVE it requires SME (or
+/// SVE 2.1), so this is currently the most logical place for this lowering.
+///
+/// Example:
+/// ```mlir
+/// %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+/// %slice = vector.extract %mask[%index]
+/// : vector<[8]xi1> from vector<[4]x[8]xi1>
+/// ```
+/// Becomes:
+/// ```
+/// %mask_rows = vector.create_mask %a : vector<[4]xi1>
+/// %mask_cols = vector.create_mask %b : vector<[8]xi1>
+/// %slice = arm_sve.psel %mask_cols, %mask_rows[%index]
+/// : vector<[8]xi1>, vector<[4]xi1>
+/// ```
+struct VectorExtractFromMaskToPselLowering
+ : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ if (extractOp.getNumIndices() != 1)
+ return rewriter.notifyMatchFailure(extractOp, "not single extract index");
+
+ auto resultType = extractOp.getResult().getType();
+ auto resultVectorType = dyn_cast<VectorType>(resultType);
+ if (!resultVectorType)
+ return rewriter.notifyMatchFailure(extractOp, "result not VectorType");
+
+ auto createMaskOp =
+ extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(extractOp, "source not CreateMaskOp");
+
+ auto maskType = createMaskOp.getVectorType();
+ if (maskType.getRank() != 2 || !maskType.allDimsScalable())
+ return rewriter.notifyMatchFailure(createMaskOp, "not 2-D scalable mask");
+
+ auto isSVEPredicateSize = [](int64_t size) {
+ return size > 0 && size <= 16 && llvm::isPowerOf2_32(uint32_t(size));
+ };
+
+ auto rowsBaseSize = maskType.getDimSize(0);
+ auto colsBaseSize = maskType.getDimSize(1);
+ if (!isSVEPredicateSize(rowsBaseSize) || !isSVEPredicateSize(colsBaseSize))
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "mask dimensions not SVE predicate-sized");
+
+ auto loc = extractOp.getLoc();
+ VectorType rowMaskType = VectorType::Builder(maskType).dropDim(1);
+ VectorType colMaskType = VectorType::Builder(maskType).dropDim(0);
+
+ // Create the two 1-D masks at the location of the 2-D create_mask (which is
+ // usually outside a loop). This prevents the need for later hoisting.
+ rewriter.setInsertionPoint(createMaskOp);
+ auto rowMask = rewriter.create<vector::CreateMaskOp>(
+ loc, rowMaskType, createMaskOp.getOperand(0));
+ auto colMask = rewriter.create<vector::CreateMaskOp>(
+ loc, colMaskType, createMaskOp.getOperand(1));
+
+ rewriter.setInsertionPoint(extractOp);
+ auto position =
+ vector::getAsValues(rewriter, loc, extractOp.getMixedPosition());
+ rewriter.replaceOpWithNewOp<arm_sve::PselOp>(extractOp, colMask, rowMask,
+ position[0]);
+ return success();
+ }
+};
+
/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
/// `arm_sme.move_tile_slice_to_vector`.
///
@@ -728,7 +800,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
TransferReadToArmSMELowering, TransferWriteToArmSMELowering,
TransposeOpToArmSMELowering, VectorLoadToArmSMELowering,
VectorStoreToArmSMELowering, VectorOuterProductToArmSMELowering,
- VectorExtractToArmSMELowering, VectorInsertToArmSMELowering,
- VectorPrintToArmSMELowering, FoldTransferWriteOfExtractTileSlice>(
- &ctx);
+ VectorExtractToArmSMELowering, VectorExtractFromMaskToPselLowering,
+ VectorInsertToArmSMELowering, VectorPrintToArmSMELowering,
+ FoldTransferWriteOfExtractTileSlice>(&ctx);
}
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
index 2601f31be11a3..cc00bf4ca190a 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp
@@ -9,6 +9,7 @@
#include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
index 8ed52cde784ce..ff7b4bcb5f65a 100644
--- a/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/unsupported.mlir
@@ -192,3 +192,54 @@ func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vecto
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}
+
+// -----
+
+/// Not SVE predicate-sized.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_0
+func.func @negative_vector_extract_to_psel_0(%a: index, %b: index, %index: index) -> vector<[32]xi1>
+{
+ // CHECK-NOT: arm_sve.psel
+ %mask = vector.create_mask %a, %b : vector<[4]x[32]xi1>
+ %slice = vector.extract %mask[%index] : vector<[32]xi1> from vector<[4]x[32]xi1>
+ return %slice : vector<[32]xi1>
+}
+
+// -----
+
+/// Source not 2-D scalable mask.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_1
+func.func @negative_vector_extract_to_psel_1(%a: index, %b: index, %index: index) -> vector<[8]xi1>
+{
+ // CHECK-NOT: arm_sve.psel
+ %mask = vector.create_mask %a, %b : vector<4x[8]xi1>
+ %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<4x[8]xi1>
+ return %slice : vector<[8]xi1>
+}
+
+// -----
+
+/// Source not vector.create_mask.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_2
+func.func @negative_vector_extract_to_psel_2(%mask: vector<[4]x[8]xi1>, %index: index) -> vector<[8]xi1>
+{
+ // CHECK-NOT: arm_sve.psel
+ %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+ return %slice : vector<[8]xi1>
+}
+
+// -----
+
+/// Not psel-like extract.
+
+// CHECK-LABEL: @negative_vector_extract_to_psel_3
+func.func @negative_vector_extract_to_psel_3(%a: index, %b: index, %index: index) -> i1
+{
+ // CHECK-NOT: arm_sve.psel
+ %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+ %el = vector.extract %mask[2, %index] : i1 from vector<[4]x[8]xi1>
+ return %el : i1
+}
diff --git a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
index 8aeffb066de90..ff21c70b2aa55 100644
--- a/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
+++ b/mlir/test/Conversion/VectorToArmSME/vector-to-arm-sme.mlir
@@ -1320,3 +1320,35 @@ func.func @vector_extract_element_f64(%row: index, %col: index) -> f64 {
%el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64>
return %el : f64
}
+
+// -----
+
+// CHECK-LABEL: @dynamic_vector_extract_mask_to_psel(
+// CHECK-SAME: %[[A:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[B:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: index)
+func.func @dynamic_vector_extract_mask_to_psel(%a: index, %b: index, %index: index) -> vector<[8]xi1>
+{
+ // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[4]xi1>
+ // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[8]xi1>
+ // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[INDEX]]] : vector<[8]xi1>, vector<[4]xi1>
+ %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1>
+ %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1>
+ return %slice : vector<[8]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_mask_to_psel(
+// CHECK-SAME: %[[A:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[B:[a-z0-9]+]]: index)
+func.func @vector_extract_mask_to_psel(%a: index, %b: index) -> vector<[2]xi1>
+{
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
+ // CHECK: %[[MASK_ROWS:.*]] = vector.create_mask %[[A]] : vector<[16]xi1>
+ // CHECK: %[[MASK_COLS:.*]] = vector.create_mask %[[B]] : vector<[2]xi1>
+ // CHECK: arm_sve.psel %[[MASK_COLS]], %[[MASK_ROWS]][%[[C1]]] : vector<[2]xi1>, vector<[16]xi1>
+ %mask = vector.create_mask %a, %b : vector<[16]x[2]xi1>
+ %slice = vector.extract %mask[1] : vector<[2]xi1> from vector<[16]x[2]xi1>
+ return %slice : vector<[2]xi1>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
few minor nits, otherwise LGTM cheers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, modulo a couple of small requests. Thanks!
…vm#96066) Example: ```mlir %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> ``` Becomes: ```mlir %mask_rows = vector.create_mask %a : vector<[4]xi1> %mask_cols = vector.create_mask %b : vector<[8]xi1> %slice = arm_sve.psel %mask_cols, %mask_rows[%index] : vector<[8]xi1>, vector<[4]xi1> ``` Note: While psel is under ArmSVE it requires SME (or SVE 2.1), so this is currently the most logical place for this lowering.
…vm#96066) Example: ```mlir %mask = vector.create_mask %a, %b : vector<[4]x[8]xi1> %slice = vector.extract %mask[%index] : vector<[8]xi1> from vector<[4]x[8]xi1> ``` Becomes: ```mlir %mask_rows = vector.create_mask %a : vector<[4]xi1> %mask_cols = vector.create_mask %b : vector<[8]xi1> %slice = arm_sve.psel %mask_cols, %mask_rows[%index] : vector<[8]xi1>, vector<[4]xi1> ``` Note: While psel is under ArmSVE it requires SME (or SVE 2.1), so this is currently the most logical place for this lowering.
Example:
Becomes:
Note: While psel is under ArmSVE it requires SME (or SVE 2.1), so this is currently the most logical place for this lowering.