
Commit 3bee1a1

[mlir][ArmSME] Lower multi-tile stores to a single loop
This adds a new pattern that can legalize a multi-tile transfer_write as a single store loop. This is done as part of type decomposition, since at this level we know each tile write is disjoint; that information is lost after decomposition (without analysis to reconstruct it).

Example (in pseudo-MLIR):

```
vector.transfer_write vector, dest[x, y], mask
  : vector<[16]x[4]xf32>, memref<?x?xf32>
```

Is rewritten to:

```
for i in range (0, 4 * vscale) {
  let sliceRow = i + tile_n.row * vscale;              ─┐
  let sliceCol = tile_n.col * vscale;                   |
  slice = vector.extract tile_n[i]                      |
    : vector<[4]xf32> from vector<[16]x[4]xf32>         |
  slice_mask = vector.extract mask[sliceRow]            |- Repeated 4x for
    : vector<[4]xi1> from vector<[16]x[4]xi1>           |  all tiles in
  vector.transfer_write                                 |  [16]x[4]
    slice, dest[x + sliceRow, y + sliceCol], slice_mask |
    : vector<[4]xf32>, memref<?x?xf32>                 ─┘
}
```
1 parent 80f8814 commit 3bee1a1
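
For context, a minimal reproducer (adapted from the masked test added in this commit) that triggers the new pattern is sketched below. The `-arm-sme-vector-legalization` pass name is an assumption about how this legalization is exposed in `mlir-opt`; check the test file's RUN line for the authoritative invocation.

```mlir
// Hypothetical standalone input; the pass flag below is assumed, e.g.:
//   mlir-opt -arm-sme-vector-legalization input.mlir
func.func @multi_tile_masked_store(%dest: memref<?x?xf32>, %dim0: index,
                                   %dim1: index, %vec: vector<[8]x[8]xf32>) {
  %c0 = arith.constant 0 : index
  // A vector<[8]x[8]xf32> covers four [4]x[4] SME tiles, so this single
  // masked write is legalized into one scf.for store loop over tile slices.
  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]}
    : vector<[8]x[8]xf32>, memref<?x?xf32>
  return
}
```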

File tree: 3 files changed (+196, −10 lines)

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 131 additions & 3 deletions
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -373,6 +374,130 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+/// Legalize a multi-tile transfer_write as a single store loop. This is done
+/// as part of type decomposition, since at this level we know each tile write
+/// is disjoint; that information is lost after decomposition (without static
+/// analysis).
+///
+/// Example (in pseudo-MLIR):
+///
+/// ```
+/// vector.transfer_write vector, dest[x, y], mask
+///   : vector<[16]x[4]xf32>, memref<?x?xf32>
+/// ```
+/// Is rewritten to:
+/// ```
+/// for i in range (0, 4 * vscale) {
+///   let sliceRow = i + tile_n.row * vscale;              ─┐
+///   let sliceCol = tile_n.col * vscale;                   |
+///   slice = vector.extract tile_n[i]                      |
+///     : vector<[4]xf32> from vector<[16]x[4]xf32>         |
+///   slice_mask = vector.extract mask[sliceRow]            |- Repeated 4x for
+///     : vector<[4]xi1> from vector<[16]x[4]xi1>           |  all tiles in
+///   vector.transfer_write                                 |  [16]x[4]
+///     slice, dest[x + sliceRow, y + sliceCol], slice_mask |
+///     : vector<[4]xf32>, memref<?x?xf32>                 ─┘
+/// }
+/// ```
+struct LegalizeMultiTileTransferWriteAsStoreLoop
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (writeOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          writeOp, "TODO: tensor semantics are unsupported");
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    bool transposed = !permutationMap.isIdentity();
+    if (transposed)
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "TODO: transpose unsupported");
+
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
+                                              vectorType.getDimSize(1) > 16)))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto loc = writeOp.getLoc();
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto createVscaleMultiple = [&](int64_t multiplier) {
+      return rewriter.create<arith::MulIOp>(
+          loc, vscale,
+          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+    };
+
+    // Get SME tile and slice types.
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto minTileSlices = smeTileType.getDimSize(0);
+    VectorType sliceMaskType =
+        VectorType::get(minTileSlices, rewriter.getI1Type(), true);
+
+    // Create loop over all tile slices.
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = createVscaleMultiple(minTileSlices);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto storeLoop =
+        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    rewriter.setInsertionPointToStart(storeLoop.getBody());
+
+    // For each sub-tile of the multi-tile `vectorType`.
+    auto inputSMETiles = adaptor.getVector();
+    auto inductionVar = storeLoop.getInductionVar();
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
+      // The coordinates of the tile within `vectorType`.
+      auto tileRow = createVscaleMultiple(smeTile.row);
+      auto tileCol = createVscaleMultiple(smeTile.col);
+
+      // The current slice of `vectorType` we are processing.
+      auto sliceIndex =
+          rewriter.create<arith::AddIOp>(loc, tileRow, inductionVar);
+
+      // Where in the destination memref the current slice will be stored.
+      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
+                                                     writeOp.getIndices()[0]);
+      auto storeCol =
+          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
+
+      // Extract the mask for the current slice.
+      Value sliceMask = nullptr;
+      if (mask) {
+        sliceMask = rewriter.create<vector::ExtractOp>(
+            loc, mask, OpFoldResult(sliceIndex));
+        if (sliceMaskType != sliceMask.getType())
+          sliceMask = rewriter.create<vector::ScalableExtractOp>(
+              loc, sliceMaskType, sliceMask, smeTile.col);
+      }
+
+      // Extract and store the current slice.
+      Value tile = inputSMETiles[index];
+      auto slice = rewriter.create<vector::ExtractOp>(loc, tile, inductionVar);
+      rewriter.create<vector::TransferWriteOp>(
+          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
+          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
+          sliceMask,
+          rewriter.getBoolArrayAttr(
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
+    }
+
+    rewriter.eraseOp(writeOp);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ArmSME-specific fixup canonicalizations/folds
 //===----------------------------------------------------------------------===//
@@ -663,9 +788,12 @@ struct VectorLegalizationPass
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
                  ConvertIllegalShapeCastOpsToTransposes>(context);
-    // Note: High benefit to ensure masked outer products are lowered first.
-    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
-        converter, context, 1024);
+    // Note: These two patterns are added with a high benefit to ensure:
+    //  - Masked outer products are handled before unmasked ones
+    //  - Multi-tile writes are lowered as a store loop (if possible)
+    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
+                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
+                                                            /*benefit=*/1024);
     patterns.add<LegalizeArithConstantOpsByDecomposition,
                  LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
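
To make the match conditions above concrete, here is a sketch (not part of this commit) of a multi-tile write the new pattern deliberately rejects: a transposed permutation map hits the "TODO: transpose unsupported" bail-out, so such writes are left to the other legalization patterns rather than becoming a store loop. The function name and shapes are illustrative only.

```mlir
// Hypothetical example of a write the store-loop pattern skips: the
// permutation map is a transpose, so matchAndRewrite returns a match failure
// and the op is handled by the remaining patterns instead.
#transpose = affine_map<(d0, d1) -> (d1, d0)>
func.func @transposed_multi_tile_write(%dest: memref<?x?xf32>, %vec: vector<[8]x[8]xf32>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %dest[%c0, %c0]
    {permutation_map = #transpose, in_bounds = [true, true]}
    : vector<[8]x[8]xf32>, memref<?x?xf32>
  return
}
```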

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 63 additions & 6 deletions
@@ -174,11 +174,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
 func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
   // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
-  // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT:   %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
@@ -201,6 +207,47 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 
 // -----
 
+// CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
+// CHECK-SAME:    %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:    %[[DIM_0:[a-z0-9]+]]: index,
+// CHECK-SAME:    %[[DIM_1:[a-z0-9]+]]: index,
+// CHECK-SAME:    %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:    %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:    %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:    %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
+func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
 #transpose = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +256,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +269,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
   // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32
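
These CHECK lines are verified with lit and FileCheck in the usual way; the file's actual RUN line is not part of this diff, so the invocation below is only an assumption of what it typically looks like for this pass.

```mlir
// Assumed RUN line (illustrative only; the real one lives at the top of
// vector-legalization.mlir and may use different flags):
// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \
// RUN:   -split-input-file | FileCheck %s
```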

mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN:   -test-lower-to-arm-sme -test-lower-to-llvm | \
+// RUN:   -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
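
The added `-convert-vector-to-llvm="enable-arm-sve"` step is presumably needed because the new store loop emits ordinary rank-1 scalable vector writes (one per tile slice), as sketched below with placeholder names, and those are lowered through the SVE-enabled vector-to-LLVM path rather than the ArmSME lowering.

```mlir
// Illustrative slice store emitted inside the generated scf.for loop (the
// value names are placeholders); it is a plain scalable-vector write, hence
// the need for the SVE-enabled vector-to-LLVM conversion in the pipeline.
vector.transfer_write %slice, %dest[%store_row, %store_col], %slice_mask
  {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
```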
