Commit f1f18d6

MacDue authored and AlexisPerry committed
[mlir][ArmSME] Lower multi-tile stores to a single loop (llvm#96187)
This adds a new pattern that can legalize a multi-tile transfer_write as a single store loop. This is done as part of type decomposition, since at this level we know each tile write is disjoint, but that information is lost after decomposition (without analysis to reconstruct it).

Example (pseudo-MLIR):

```
vector.transfer_write %vector, %dest[%y, %x], %mask
  : vector<[16]x[8]xi16>, memref<?x?xi16>
```

Is rewritten to:

```
scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
  %upper_slice_mask = vector.extract %mask[%slice_idx]        ─┐
    : vector<[8]xi1> from vector<[16]x[8]xi1>                  |
  %upper_slice = vector.extract %upper_tile[%slice_idx]        |- Store upper tile
    : vector<[8]xi16> from vector<[8]x[8]xi16>                 |
  vector.transfer_write %upper_slice,                          |
    %dest[%slice_idx + %y, %x], %upper_slice_mask              |
    : vector<[8]xi16>, memref<?x?xi16>                        ─┘
  %lower_slice_idx = %slice_idx + %c8_vscale                  ─┐
  %lower_slice_mask = vector.extract %mask[%lower_slice_idx]   |
    : vector<[8]xi1> from vector<[16]x[8]xi1>                  |
  %lower_slice = vector.extract %lower_tile[%slice_idx]        |- Store lower
    : vector<[8]xi16> from vector<[8]x[8]xi16>                 |  tile
  vector.transfer_write %lower_slice,                          |
    %dest[%lower_slice_idx + %y, %x], %lower_slice_mask        |
    : vector<[8]xi16>, memref<?x?xi16>                        ─┘
}
```
1 parent 432dd6e commit f1f18d6
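To see the rewrite in action, the pattern can be exercised on its own through the ArmSME vector legalization pass. A minimal sketch (assuming the pass is exposed as `-arm-sme-vector-legalization`, the flag used by the vector-legalization.mlir test updated below; the function itself is illustrative):

```
// Run with (assumed flag): mlir-opt input.mlir -arm-sme-vector-legalization
// A vector<[16]x[8]xf16> write covers two [8]x[8]xf16 SME tiles, so the new
// pattern should turn it into a single scf.for store loop over tile slices.
func.func @multi_tile_store(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>) {
  %c0 = arith.constant 0 : index
  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]}
    : vector<[16]x[8]xf16>, memref<?x?xf16>
  return
}
```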

File tree

3 files changed (+248, -10 lines)


mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 140 additions & 3 deletions
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Transforms/OneToNTypeConversion.h"
@@ -373,6 +374,139 @@ struct LegalizeTransferWriteOpsByDecomposition
   }
 };
 
+/// Legalize a multi-tile transfer_write as a single store loop. This is done as
+/// part of type decomposition as at this level we know each tile write is
+/// disjoint, but that information is lost after decomposition (without analysis
+/// to reconstruct it).
+///
+/// Example (pseudo-MLIR):
+///
+/// ```
+/// vector.transfer_write %vector, %dest[%y, %x], %mask
+///   : vector<[16]x[8]xi16>, memref<?x?xi16>
+/// ```
+/// Is rewritten to:
+/// ```
+/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
+///   %upper_slice_mask = vector.extract %mask[%slice_idx]        ─┐
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>                  |
+///   %upper_slice = vector.extract %upper_tile[%slice_idx]        |- Store upper tile
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>                 |
+///   vector.transfer_write %upper_slice,                          |
+///     %dest[%slice_idx + %y, %x], %upper_slice_mask              |
+///     : vector<[8]xi16>, memref<?x?xi16>                        ─┘
+///   %lower_slice_idx = %slice_idx + %c8_vscale                  ─┐
+///   %lower_slice_mask = vector.extract %mask[%lower_slice_idx]   |
+///     : vector<[8]xi1> from vector<[16]x[8]xi1>                  |
+///   %lower_slice = vector.extract %lower_tile[%slice_idx]        |- Store lower
+///     : vector<[8]xi16> from vector<[8]x[8]xi16>                 |  tile
+///   vector.transfer_write %lower_slice,                          |
+///     %dest[%lower_slice_idx + %y, %x], %lower_slice_mask        |
+///     : vector<[8]xi16>, memref<?x?xi16>                        ─┘
+/// }
+/// ```
+struct LegalizeMultiTileTransferWriteAsStoreLoop
+    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    if (writeOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          writeOp, "TODO: tensor semantics are unsupported");
+
+    auto permutationMap = writeOp.getPermutationMap();
+    if (!permutationMap.isPermutation())
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNonPermutationMap);
+
+    bool transposed = !permutationMap.isIdentity();
+    if (transposed)
+      return rewriter.notifyMatchFailure(writeOp,
+                                         "TODO: transpose unsupported");
+
+    auto vectorType = writeOp.getVectorType();
+    if (!isMultipleOfSMETileVectorType(vectorType))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureNotSMETileTypeMultiple);
+
+    // Note: We also disallow masks where any dimension is > 16 because that
+    // prevents the masking from being lowered to use arm_sve.psel.
+    auto mask = writeOp.getMask();
+    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
+                                              vectorType.getDimSize(1) > 16)))
+      return rewriter.notifyMatchFailure(writeOp,
+                                         kMatchFailureUnsupportedMaskOp);
+
+    auto loc = writeOp.getLoc();
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto createVscaleMultiple = [&](int64_t multiplier) {
+      return rewriter.create<arith::MulIOp>(
+          loc, vscale,
+          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+    };
+
+    // Get SME tile and slice types.
+    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+    auto minTileSlices = smeTileType.getDimSize(0);
+    VectorType sliceMaskType =
+        VectorType::get(minTileSlices, rewriter.getI1Type(), true);
+
+    // Create loop over all tile slices.
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = createVscaleMultiple(minTileSlices);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto storeLoop =
+        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    rewriter.setInsertionPointToStart(storeLoop.getBody());
+
+    // For each sub-tile of the multi-tile `vectorType`.
+    auto inputSMETiles = adaptor.getVector();
+    auto tileSliceIndex = storeLoop.getInductionVar();
+    for (auto [index, smeTile] : llvm::enumerate(
+             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
+      // The coordinates of the tile within `vectorType`.
+      auto tileRow = createVscaleMultiple(smeTile.row);
+      auto tileCol = createVscaleMultiple(smeTile.col);
+
+      // The current slice of `vectorType` we are processing.
+      auto sliceIndex =
+          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);
+
+      // Where in the destination memref the current slice will be stored.
+      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
+                                                     writeOp.getIndices()[0]);
+      auto storeCol =
+          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);
+
+      // Extract the mask for the current slice.
+      Value sliceMask = nullptr;
+      if (mask) {
+        sliceMask = rewriter.create<vector::ExtractOp>(
+            loc, mask, OpFoldResult(sliceIndex));
+        if (sliceMaskType != sliceMask.getType())
+          sliceMask = rewriter.create<vector::ScalableExtractOp>(
+              loc, sliceMaskType, sliceMask, smeTile.col);
+      }
+
+      // Extract and store the current slice.
+      Value tile = inputSMETiles[index];
+      auto slice =
+          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
+      rewriter.create<vector::TransferWriteOp>(
+          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
+          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
+          sliceMask,
+          rewriter.getBoolArrayAttr(
+              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
+    }
+
+    rewriter.eraseOp(writeOp);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ArmSME-specific fixup canonicalizations/folds
 //===----------------------------------------------------------------------===//
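A note on the mask handling above: when an SME tile is narrower than a full mask row, the extracted slice mask is further split with `vector.scalable.extract` at the tile's column offset. Roughly, for the masked `vector<[8]x[8]xf32>` write in the new test below, one loop iteration computes (a sketch; value names are illustrative):

```
%slice_mask = vector.extract %mask[%slice_idx]
  : vector<[8]xi1> from vector<[8]x[8]xi1>
// Sub-mask for the left [4]x[4] tile (columns starting at 0)...
%tile_0_mask = vector.scalable.extract %slice_mask[0]
  : vector<[4]xi1> from vector<[8]xi1>
// ...and for the right tile (columns starting at 4 x vscale).
%tile_1_mask = vector.scalable.extract %slice_mask[4]
  : vector<[4]xi1> from vector<[8]xi1>
```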
@@ -663,9 +797,12 @@ struct VectorLegalizationPass
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                  LiftIllegalVectorTransposeToMemory,
                  ConvertIllegalShapeCastOpsToTransposes>(context);
-    // Note: High benefit to ensure masked outer products are lowered first.
-    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
-        converter, context, 1024);
+    // Note: These two patterns are added with a high benefit to ensure:
+    //  - Masked outer products are handled before unmasked ones
+    //  - Multi-tile writes are lowered as a store loop (if possible)
+    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
+                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
+                                                            /*benefit=*/1024);
     patterns.add<LegalizeArithConstantOpsByDecomposition,
                  LegalizeVectorOuterProductOpsByDecomposition,
                  LegalizeTransferReadOpsByDecomposition,
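The high benefit matters because `LegalizeTransferWriteOpsByDecomposition` would otherwise match first and split a multi-tile write into independent per-tile writes, at which point the store-loop pattern can no longer apply. For reference, this is roughly what plain decomposition produces for the `vector<[16]x[8]xf16>` case, reconstructed from the CHECK lines deleted in the test below (`%top`/`%bottom` stand for the two decomposed tiles):

```
vector.transfer_write %top, %dest[%c0, %c0] {in_bounds = [true, true]}
  : vector<[8]x[8]xf16>, memref<?x?xf16>
vector.transfer_write %bottom, %dest[%c8_vscale, %c0] {in_bounds = [true, true]}
  : vector<[8]x[8]xf16>, memref<?x?xf16>
```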

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 106 additions & 6 deletions
@@ -174,11 +174,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
 func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
   // CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-  // CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
-  // CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT:   %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
+  // CHECK-NEXT:   vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
@@ -201,6 +207,90 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 
 // -----
 
+// CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
+// CHECK-SAME:    %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
+// CHECK-SAME:    %[[DIM_0:[a-z0-9]+]]: index,
+// CHECK-SAME:    %[[DIM_1:[a-z0-9]+]]: index,
+// CHECK-SAME:    %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:    %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:    %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME:    %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
+func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+  // CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+  // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
+// Tensor semantics are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %vec: vector<[8]x[8]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
+  return
+}
+
+// -----
+
+#transpose = affine_map<(d0, d1) -> (d1, d0)>
+
+// Transposes are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
+  return
+}
+
+// -----
+
+// Masked writes where any dimension of the mask is > 16 are not supported for the store loop lowering.
+
+// CHECK-LABEL: @negative_transfer_write_f32_scalable_32x32
+// CHECK-NOT: scf.for
+func.func @negative_transfer_write_f32_scalable_32x32(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[32]x[32]xf32>)
+{
+  %c0 = arith.constant 0 : index
+  %mask = vector.create_mask %dim0, %dim1 : vector<[32]x[32]xi1>
+  vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[32]x[32]xf32>, memref<?x?xf32>
+  return
+}
+
+// -----
+
 #transpose = affine_map<(d0, d1) -> (d1, d0)>
 
 // CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +299,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
 func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
 {
   // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
   // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
   // CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +312,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
   // CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
   // CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-  // CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
+  // CHECK-NEXT:   %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT:   %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
+  // CHECK-NEXT:   %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+  // CHECK-NEXT:   vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
+  // CHECK-NEXT: }
   // CHECK-NEXT: return
   %c0 = arith.constant 0 : index
   %pad = arith.constant 0.0 : f32

mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/multi-tile-matmul.mlir

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 // RUN: mlir-opt %s \
 // RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
 // RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN:   -test-lower-to-arm-sme -test-lower-to-llvm | \
+// RUN:   -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
+// RUN:   -test-lower-to-llvm | \
 // RUN: %mcr_aarch64_cmd \
 // RUN:   -e=main -entry-point-result=void \
 // RUN:   -march=aarch64 -mattr="+sve,+sme" \
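The added `-convert-vector-to-llvm="enable-arm-sve"` stage is needed here, presumably because the new store-loop lowering routes sub-tile masks through ArmSVE ops (see the `arm_sve.psel` note in the pattern above), which `-test-lower-to-llvm` alone does not convert. For reference, the full post-patch RUN header reads:

```
// RUN: mlir-opt %s \
// RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
// RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
// RUN:   -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
// RUN:   -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN:   -e=main -entry-point-result=void \
// RUN:   -march=aarch64 -mattr="+sve,+sme" \
// RUN:   ... (remaining arguments unchanged)
```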
