[mlir][ArmSME] Fold extracts from 3D create_masks of SME-like masks #80148
When unrolling the reduction dimension of something like a matmul for SME, it is possible to get 3D masks, which are vectors of SME-like masks. The 2D masks for individual operations are then extracted from the 3D masks, e.g.:

```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```

ArmSME only supports lowering 2D create_masks, so we must fold the extract into the create_mask. This can be done by checking whether the extraction index is within the true region, then using that to select the first dimension of the 2D mask. This is shown below.

```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```
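To see why the fold preserves the mask, here is a minimal standalone C++ sketch that models the mask semantics with plain loops. It does not use any MLIR API. The names `createMask2D`, `extractFrom3DMask`, `foldedExtract` and `foldMatchesExtract` are hypothetical, and `rows`/`cols` stand in for the runtime sizes of the two scalable dimensions. It assumes `vector.create_mask` sets element (i, j, k) exactly when each index is below its bound, so negative bounds give an empty mask.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical model, not MLIR code: a 2D mask where true marks an active lane.
using Mask2D = std::vector<std::vector<bool>>;

// Models `vector.create_mask %d0, %d1` with a rows x cols result:
// element (j, k) is set iff j < d0 and k < d1.
Mask2D createMask2D(int64_t d0, int64_t d1, int64_t rows, int64_t cols) {
  Mask2D mask(rows, std::vector<bool>(cols, false));
  for (int64_t j = 0; j < rows; ++j)
    for (int64_t k = 0; k < cols; ++k)
      mask[j][k] = j < d0 && k < d1;
  return mask;
}

// Models extracting slice `index` from `vector.create_mask %d0, %d1, %d2`:
// element (index, j, k) is set iff index < d0, j < d1 and k < d2.
Mask2D extractFrom3DMask(int64_t index, int64_t d0, int64_t d1, int64_t d2,
                         int64_t rows, int64_t cols) {
  Mask2D mask(rows, std::vector<bool>(cols, false));
  for (int64_t j = 0; j < rows; ++j)
    for (int64_t k = 0; k < cols; ++k)
      mask[j][k] = index < d0 && j < d1 && k < d2;
  return mask;
}

// Models the folded form: compare, select, then a 2D create_mask.
Mask2D foldedExtract(int64_t index, int64_t d0, int64_t d1, int64_t d2,
                     int64_t rows, int64_t cols) {
  bool extractionInTrueRegion = index < d0;                // arith.cmpi slt
  int64_t newMaskFrontDim = extractionInTrueRegion ? d1 : 0; // arith.select
  return createMask2D(newMaskFrontDim, d2, rows, cols);
}

// Exhaustively compares both forms, including out-of-range bounds.
bool foldMatchesExtract(int64_t frontSize, int64_t rows, int64_t cols) {
  for (int64_t index = 0; index < frontSize; ++index)
    for (int64_t d0 = -1; d0 <= frontSize + 1; ++d0)
      for (int64_t d1 = -1; d1 <= rows + 1; ++d1)
        for (int64_t d2 = -1; d2 <= cols + 1; ++d2)
          if (extractFrom3DMask(index, d0, d1, d2, rows, cols) !=
              foldedExtract(index, d0, d1, d2, rows, cols))
            return false;
  return true;
}
```

When the extraction index lies outside the true region, the selected front dimension is zero, so the 2D mask is all-false, which is what the extracted slice would have been.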
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
Full diff: https://github.com/llvm/llvm-project/pull/80148.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 85ec53c2618aa..14b9d8e34da65 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -7,13 +7,12 @@
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
-// Currently, this only implements the decomposition of vector operations that
-// use vector sizes larger than an SME tile, into multiple SME-sized operations.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
namespace {
+//===----------------------------------------------------------------------===//
+// Decomposition of vector operations larger than an SME tile
+//===----------------------------------------------------------------------===//
+
// Common match failure reasons.
static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
"op vector size is not multiple of SME tiles");
@@ -338,13 +341,86 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};
+//===----------------------------------------------------------------------===//
+// ArmSME-specific fixup canonicalizations/folds
+//===----------------------------------------------------------------------===//
+
+/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
+/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
+/// necessary for the mask to be lowered to ArmSME.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
+/// %subMask = vector.extract %mask[2]
+/// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
+/// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
+/// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
+/// ```
+struct FoldExtractFromVectorOfSMELikeCreateMasks
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto loc = extractOp.getLoc();
+    auto createMaskOp =
+        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return rewriter.notifyMatchFailure(
+          extractOp, "extract not from vector.create_mask op");
+
+    VectorType extractedMaskType =
+        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
+    if (!extractedMaskType)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "extracted type is not a vector type");
+
+    auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true);
+    if (numScalable != 2)
+      return rewriter.notifyMatchFailure(
+          extractOp, "expected extracted type to be an SME-like mask");
+
+    // TODO: Support multiple extraction indices.
+    if (extractOp.getStaticPosition().size() != 1)
+      return rewriter.notifyMatchFailure(
+          extractOp, "only a single extraction index is supported");
+
+    auto frontMaskDim = createMaskOp.getOperand(0);
+    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
+      return rewriter.notifyMatchFailure(
+          extractOp,
+          "constant vector.create_masks dims should be folded elsewhere");
+
+    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto extractionIndex = getValueOrCreateConstantIndexOp(
+        rewriter, loc, extractOp.getMixedPosition()[0]);
+    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
+        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
+        frontMaskDim);
+    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
+        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
+
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        extractOp, extractedMaskType,
+        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
+    return success();
+  }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
OneToNTypeConverter converter;
RewritePatternSet patterns(context);
-
converter.addConversion([](Type type) { return type; });
converter.addConversion(
[](VectorType vectorType,
@@ -358,6 +434,7 @@ struct VectorLegalizationPass
return success();
});
+ patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks>(context);
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index a20abeefedcfd..a2526db9b4831 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -266,3 +266,39 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
return
}
+
+// -----
+
+// CHECK-LABEL: @extract_from_vector_create_mask_non_constant_dim(
+// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
+func.func @extract_from_vector_create_mask_non_constant_dim(%dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi sgt, %[[DIM0]], %[[C2]] : index
+  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
+  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
+  // CHECK-NEXT: return %[[EXTRACT]]
+  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
+  %extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @non_constant_extract_from_vector_create_mask_non_constant(
+// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
+func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: index, %dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi slt, %[[INDEX]], %[[DIM0]] : index
+  // CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
+  // CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
+  // CHECK-NEXT: return %[[EXTRACT]]
+  %mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
+  %extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
+  return %extract : vector<[4]x[4]xi1>
+}
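
A note on the first test: the pattern emits `arith.cmpi slt` with the extraction index on the left, while the CHECK line matches `arith.cmpi sgt, %[[DIM0]], %[[C2]]`. These are the same signed comparison with the operands swapped. Presumably a later canonicalization in the test pipeline moves the constant to the right-hand side; that is an assumption, since the RUN line is not shown here. The second test, where the index is not a constant, keeps the `slt` form. The small C++ sketch below (hypothetical names, no MLIR dependency) checks that the two spellings agree:

```cpp
#include <cassert>
#include <cstdint>

// Form emitted by the pattern: arith.cmpi slt, %c2, %dim0.
bool emittedCompare(int64_t extractionIndex, int64_t dim0) {
  return extractionIndex < dim0;
}

// Form matched by the CHECK line: arith.cmpi sgt, %dim0, %c2.
bool checkedCompare(int64_t extractionIndex, int64_t dim0) {
  return dim0 > extractionIndex;
}

// Hypothetical helper: true if both forms agree for every pair in [lo, hi].
bool comparesAgree(int64_t lo, int64_t hi) {
  for (int64_t index = lo; index <= hi; ++index)
    for (int64_t dim0 = lo; dim0 <= hi; ++dim0)
      if (emittedCompare(index, dim0) != checkedCompare(index, dim0))
        return false;
  return true;
}
```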
Functionally the rewrite looks good to me, but I am wondering if it should be moved alongside the Vector dialect transforms. I know this is only relevant for ArmSME right now, but semantically these operations and types are expressible in the Vector dialect, and if we pretended there were another target where 2-D scalable vector types were relevant, then this would apply equally well, I think?
I don't want to muddy the vector dialect with rewrites like this. It also makes things simpler to have everything in one place when it comes to integrating stuff into IREE.
I thought the same for code I specifically wrote for SVE in the past, but it turned out to not be the case :) But I do understand it's not obvious this will be generally useful, and probably more so than the previous examples I'm referring to. There is a ton of common infrastructure we're benefiting from here and it's good to pay it forward where we can, but finding that line isn't always easy.
Hadn't considered IREE integration.
LGTM, cheers
…mlir This tests that llvm#80148 and llvm#80170 work together to allow unrolling the reduction dimension of a matmul.