Skip to content

[mlir][ArmSME] Fold extracts from 3D create_masks of SME-like masks #80148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 80 additions & 3 deletions mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
//===----------------------------------------------------------------------===//
//
// This pass legalizes vector operations so they can be lowered to ArmSME.
// It implements the decomposition of vector operations that use vector sizes
// larger than an SME tile into multiple SME-sized operations, as well as
// ArmSME-specific fixup canonicalizations/folds.
//
// Note: In the context of this pass 'tile' always refers to an SME tile.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
Expand All @@ -35,6 +34,10 @@ using namespace mlir::arm_sme;

namespace {

//===----------------------------------------------------------------------===//
// Decomposition of vector operations larger than an SME tile
//===----------------------------------------------------------------------===//

// Common match failure reasons.
// Diagnostic text for an op whose vector size is not a multiple of SME tiles.
// NOTE(review): presumably passed to `notifyMatchFailure` by the decomposition
// patterns in this section — confirm at the use sites.
static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
"op vector size is not multiple of SME tiles");
Expand Down Expand Up @@ -338,13 +341,86 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//

/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
/// necessary for the mask to be lowered to ArmSME.
///
/// The pattern only applies when:
///   - the extracted value is a rank-2 vector with both dimensions scalable
///     (an SME-like mask),
///   - the extract uses exactly one (static or dynamic) index, and
///   - the leading `vector.create_mask` dimension is not an `arith.constant`
///     (constant dims are expected to be folded elsewhere).
///
/// Example:
///
///  BEFORE:
///  ```mlir
///  %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
///  %subMask = vector.extract %mask[2]
///          : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
///  ```
///
///  AFTER:
///  ```mlir
///  %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
///  %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
///  %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
///  ```
struct FoldExtractFromVectorOfSMELikeCreateMasks
    : public OpRewritePattern<vector::ExtractOp> {
  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    auto loc = extractOp.getLoc();
    auto createMaskOp =
        extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
    if (!createMaskOp)
      return rewriter.notifyMatchFailure(
          extractOp, "extract not from vector.create_mask op");

    VectorType extractedMaskType =
        llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
    if (!extractedMaskType)
      return rewriter.notifyMatchFailure(extractOp,
                                         "extracted type is not a vector type");

    // The rewrite below builds a two-operand `vector.create_mask`, so the
    // extracted type must be exactly rank 2 with both dims scalable. Checking
    // only the scalable-dim count would also accept e.g. `2x[4]x[4]`, for
    // which the rewrite would create a mask op with too few operands.
    auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true);
    if (extractedMaskType.getRank() != 2 || numScalable != 2)
      return rewriter.notifyMatchFailure(
          extractOp, "expected extracted type to be an SME-like mask");

    // TODO: Support multiple extraction indices.
    if (extractOp.getStaticPosition().size() != 1)
      return rewriter.notifyMatchFailure(
          extractOp, "only a single extraction index is supported");

    auto frontMaskDim = createMaskOp.getOperand(0);
    if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
      return rewriter.notifyMatchFailure(
          extractOp,
          "constant vector.create_masks dims should be folded elsewhere");

    auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    // The extraction index may be static or dynamic; materialize it as an
    // index value either way.
    auto extractionIndex = getValueOrCreateConstantIndexOp(
        rewriter, loc, extractOp.getMixedPosition()[0]);
    // The extracted slice lies inside the masked-on region only if its index
    // is below the leading mask dimension.
    auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
        loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
        frontMaskDim);
    // Inside the region keep the original second dimension; outside use 0 as
    // the new leading dimension so that the resulting 2D mask is all-false.
    auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
        loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);

    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
        extractOp, extractedMaskType,
        ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
    return success();
  }
};

struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
auto *context = &getContext();
OneToNTypeConverter converter;
RewritePatternSet patterns(context);

converter.addConversion([](Type type) { return type; });
converter.addConversion(
[](VectorType vectorType,
Expand All @@ -358,6 +434,7 @@ struct VectorLegalizationPass
return success();
});

patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks>(context);
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/ArmSME/vector-legalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,39 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
return
}

// -----

// Extract at a constant index (2) from a 3D mask whose dimensions are all
// function arguments. The expected output compares %dim0 against the index,
// selects between %dim1 and 0 for the new leading dimension, and builds a 2D
// mask from that value and %dim2.
// NOTE(review): the expected predicate is `sgt` with the constant on the
// right-hand side — presumably the compare is canonicalized after the fold;
// confirm against the RUN line of this file.
// CHECK-LABEL: @extract_from_vector_create_mask_non_constant_dim(
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
func.func @extract_from_vector_create_mask_non_constant_dim(%dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi sgt, %[[DIM0]], %[[C2]] : index
// CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
// CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
// CHECK-NEXT: return %[[EXTRACT]]
%mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
%extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
return %extract : vector<[4]x[4]xi1>
}

// -----

// Extract at a dynamic index (%index) from a 3D mask whose dimensions are all
// function arguments. The expected output compares `%index slt %dim0`,
// selects between %dim1 and 0 for the new leading dimension, and builds a 2D
// mask from that value and %dim2.
// CHECK-LABEL: @non_constant_extract_from_vector_create_mask_non_constant(
// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: index, %dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi slt, %[[INDEX]], %[[DIM0]] : index
// CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
// CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
// CHECK-NEXT: return %[[EXTRACT]]
%mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
%extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
return %extract : vector<[4]x[4]xi1>
}