Skip to content

Commit c2dea71

Browse files
authored
[mlir][ArmSME] Fold extracts from 3D create_masks of SME-like masks (#80148)
When unrolling the reduction dimension of something like a matmul for SME, it is possible to get 3D masks, which are vectors of SME-like masks. The 2D masks for individual operations are then extracted from the 3D masks, i.e.:

```mlir
%mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
%subMask = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
```

ArmSME only supports lowering 2D create_masks, so we must fold the extract into the create_mask. This can be done by checking if the extraction index is within the true region, then using that to select the first dimension of the 2D mask. This is shown below:

```mlir
%extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
%newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
%subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
```
1 parent 84c8d03 commit c2dea71

File tree

2 files changed

+116
-3
lines changed

2 files changed

+116
-3
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This pass legalizes vector operations so they can be lowered to ArmSME.
10-
// Currently, this only implements the decomposition of vector operations that
11-
// use vector sizes larger than an SME tile, into multiple SME-sized operations.
1210
//
1311
// Note: In the context of this pass 'tile' always refers to an SME tile.
1412
//
1513
//===----------------------------------------------------------------------===//
1614

15+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1716
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
1817
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
1918
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
3534

3635
namespace {
3736

37+
//===----------------------------------------------------------------------===//
38+
// Decomposition of vector operations larger than an SME tile
39+
//===----------------------------------------------------------------------===//
40+
3841
// Common match failure reasons.
3942
static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
4043
"op vector size is not multiple of SME tiles");
@@ -338,13 +341,86 @@ struct LegalizeTransferWriteOpsByDecomposition
338341
}
339342
};
340343

344+
//===----------------------------------------------------------------------===//
345+
// ArmSME-specific fixup canonicalizations/folds
346+
//===----------------------------------------------------------------------===//
347+
348+
/// Folds an extract from a 3D `vector.create_mask` (which is a vector of
349+
/// SME-like masks), into a compare and a 2D `vector.create_mask`. This is
350+
/// necessary for the mask to be lowered to ArmSME.
351+
///
352+
/// Example:
353+
///
354+
/// BEFORE:
355+
/// ```mlir
356+
/// %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
357+
/// %subMask = vector.extract %mask[2]
358+
/// : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
359+
/// ```
360+
///
361+
/// AFTER:
362+
/// ```mlir
363+
/// %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
364+
/// %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
365+
/// %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
366+
/// ```
367+
struct FoldExtractFromVectorOfSMELikeCreateMasks
368+
: public OpRewritePattern<vector::ExtractOp> {
369+
using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
370+
371+
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
372+
PatternRewriter &rewriter) const override {
373+
auto loc = extractOp.getLoc();
374+
auto createMaskOp =
375+
extractOp.getVector().getDefiningOp<vector::CreateMaskOp>();
376+
if (!createMaskOp)
377+
return rewriter.notifyMatchFailure(
378+
extractOp, "extract not from vector.create_mask op");
379+
380+
VectorType extractedMaskType =
381+
llvm::dyn_cast<VectorType>(extractOp.getResult().getType());
382+
if (!extractedMaskType)
383+
return rewriter.notifyMatchFailure(extractOp,
384+
"extracted type is not a vector type");
385+
386+
auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true);
387+
if (numScalable != 2)
388+
return rewriter.notifyMatchFailure(
389+
extractOp, "expected extracted type to be an SME-like mask");
390+
391+
// TODO: Support multiple extraction indices.
392+
if (extractOp.getStaticPosition().size() != 1)
393+
return rewriter.notifyMatchFailure(
394+
extractOp, "only a single extraction index is supported");
395+
396+
auto frontMaskDim = createMaskOp.getOperand(0);
397+
if (frontMaskDim.getDefiningOp<arith::ConstantOp>())
398+
return rewriter.notifyMatchFailure(
399+
extractOp,
400+
"constant vector.create_masks dims should be folded elsewhere");
401+
402+
auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
403+
auto extractionIndex = getValueOrCreateConstantIndexOp(
404+
rewriter, loc, extractOp.getMixedPosition()[0]);
405+
auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>(
406+
loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex,
407+
frontMaskDim);
408+
auto newMaskFrontDim = rewriter.create<arith::SelectOp>(
409+
loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero);
410+
411+
rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
412+
extractOp, extractedMaskType,
413+
ValueRange{newMaskFrontDim, createMaskOp.getOperand(2)});
414+
return success();
415+
}
416+
};
417+
341418
struct VectorLegalizationPass
342419
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
343420
void runOnOperation() override {
344421
auto *context = &getContext();
345422
OneToNTypeConverter converter;
346423
RewritePatternSet patterns(context);
347-
348424
converter.addConversion([](Type type) { return type; });
349425
converter.addConversion(
350426
[](VectorType vectorType,
@@ -358,6 +434,7 @@ struct VectorLegalizationPass
358434
return success();
359435
});
360436

437+
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks>(context);
361438
// Note: High benefit to ensure masked outer products are lowered first.
362439
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
363440
converter, context, 1024);

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,39 @@ func.func @transpose_f32_scalable_4x16_via_write(%src: memref<?x?xf32>, %dest: m
266266
vector.transfer_write %0, %dest[%c0, %c0] {permutation_map = #transpose, in_bounds = [true, true]} : vector<[4]x[16]xf32>, memref<?x?xf32>
267267
return
268268
}
269+
270+
// -----
271+
272+
// Extracts a 2D SME-like mask at the constant index 2 from a 3D create_mask whose dims are all function arguments. The expected output has no vector.extract: it compares %dim0 against 2 (written as `%dim0 sgt 2`), selects between %dim1 and 0, and builds a 2D create_mask from the selected value and %dim2.
// CHECK-LABEL: @extract_from_vector_create_mask_non_constant_dim(
273+
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
274+
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
275+
// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
276+
func.func @extract_from_vector_create_mask_non_constant_dim(%dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
277+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
278+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
279+
// CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi sgt, %[[DIM0]], %[[C2]] : index
280+
// CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
281+
// CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
282+
// CHECK-NEXT: return %[[EXTRACT]]
283+
%mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
284+
%extract = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
285+
return %extract : vector<[4]x[4]xi1>
286+
}
287+
288+
// -----
289+
290+
// Extracts a 2D SME-like mask at a dynamic index (%index) from a 3D create_mask whose dims are all function arguments. The expected output has no vector.extract: it compares `%index slt %dim0`, selects between %dim1 and 0, and builds a 2D create_mask from the selected value and %dim2.
// CHECK-LABEL: @non_constant_extract_from_vector_create_mask_non_constant(
291+
// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: index,
292+
// CHECK-SAME: %[[DIM0:[a-z0-9]+]]: index,
293+
// CHECK-SAME: %[[DIM1:[a-z0-9]+]]: index,
294+
// CHECK-SAME: %[[DIM2:[a-z0-9]+]]: index)
295+
func.func @non_constant_extract_from_vector_create_mask_non_constant(%index: index, %dim0: index, %dim1: index, %dim2: index) -> vector<[4]x[4]xi1> {
296+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
297+
// CHECK-NEXT: %[[DIM0_CMP:.*]] = arith.cmpi slt, %[[INDEX]], %[[DIM0]] : index
298+
// CHECK-NEXT: %[[NEW_DIM0:.*]] = arith.select %[[DIM0_CMP]], %[[DIM1]], %[[C0]] : index
299+
// CHECK-NEXT: %[[EXTRACT:.*]] = vector.create_mask %[[NEW_DIM0]], %[[DIM2]] : vector<[4]x[4]xi1>
300+
// CHECK-NEXT: return %[[EXTRACT]]
301+
%mask = vector.create_mask %dim0, %dim1, %dim2 : vector<4x[4]x[4]xi1>
302+
%extract = vector.extract %mask[%index] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
303+
return %extract : vector<[4]x[4]xi1>
304+
}

0 commit comments

Comments
 (0)