[mlir][ArmSME] Lower multi-tile stores to a single loop #96187

Merged · 5 commits · Jun 25, 2024
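
This change teaches the ArmSME VectorLegalization pass to lower a multi-tile `vector.transfer_write` (masked or unmasked) to a single `scf.for` store loop over tile slices, instead of decomposing it into independent per-tile writes. The rewrite happens during type decomposition, while the per-tile writes are still known to be disjoint.

A condensed sketch of the rewrite (pseudo-MLIR, adapted from the new pattern's doc comment below; `%i + %y` is shorthand for an `arith.addi`):

```mlir
// Before:
vector.transfer_write %vec, %dest[%y, %x], %mask
  : vector<[16]x[8]xi16>, memref<?x?xi16>

// After: %upper_tile and %lower_tile are the two vector<[8]x[8]xi16>
// halves of %vec produced by type decomposition.
scf.for %i = %c0 to %c8_vscale step %c1 {
  %m0 = vector.extract %mask[%i] : vector<[8]xi1> from vector<[16]x[8]xi1>
  %s0 = vector.extract %upper_tile[%i] : vector<[8]xi16> from vector<[8]x[8]xi16>
  vector.transfer_write %s0, %dest[%i + %y, %x], %m0 : vector<[8]xi16>, memref<?x?xi16>
  %j = arith.addi %i, %c8_vscale : index
  %m1 = vector.extract %mask[%j] : vector<[8]xi1> from vector<[16]x[8]xi1>
  %s1 = vector.extract %lower_tile[%i] : vector<[8]xi16> from vector<[8]x[8]xi16>
  vector.transfer_write %s1, %dest[%j + %y, %x], %m1 : vector<[8]xi16>, memref<?x?xi16>
}
```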
143 changes: 140 additions & 3 deletions mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
@@ -373,6 +374,139 @@ struct LegalizeTransferWriteOpsByDecomposition
  }
};

/// Legalize a multi-tile transfer_write as a single store loop. This is done
/// as part of type decomposition, since at this level we still know that each
/// tile write is disjoint, but that information is lost after decomposition
/// (without analysis to reconstruct it).
///
/// Example (pseudo-MLIR):
///
/// ```
/// vector.transfer_write %vector, %dest[%y, %x], %mask
/// : vector<[16]x[8]xi16>, memref<?x?xi16>
/// ```
/// Is rewritten to:
/// ```
/// scf.for %slice_idx = %c0 to %c8_vscale step %c1 {
/// %upper_slice_mask = vector.extract %mask[%slice_idx] ─┐
/// : vector<[8]xi1> from vector<[16]x[8]xi1> |
/// %upper_slice = vector.extract %upper_tile[%slice_idx] |- Store upper tile
/// : vector<[8]xi16> from vector<[8]x[8]xi16> |
/// vector.transfer_write %upper_slice, |
/// %dest[%slice_idx + %y, %x], %upper_slice_mask |
/// : vector<[8]xi16>, memref<?x?xi16> ┘
/// %lower_slice_idx = %slice_idx + %c8_vscale ─┐
/// %lower_slice_mask = vector.extract %mask[%lower_slice_idx] |
/// : vector<[8]xi1> from vector<[16]x[8]xi1> |
/// %lower_slice = vector.extract %lower_tile[%slice_idx] |- Store lower
/// : vector<[8]xi16> from vector<[8]x[8]xi16> | tile
/// vector.transfer_write %lower_slice, |
/// %dest[%lower_slice_idx + %y, %x], %lower_slice_mask |
/// : vector<[8]xi16>, memref<?x?xi16> ┘
/// }
/// ```
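///
/// Note: In this example, vector<[16]x[8]xi16> decomposes into two
/// vector<[8]x[8]xi16> sub-tiles (%upper_tile and %lower_tile). The loop
/// iterates over the 8 x vscale slices of a single sub-tile and stores the
/// matching slice of every sub-tile on each iteration.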
struct LegalizeMultiTileTransferWriteAsStoreLoop
    : public OneToNOpConversionPattern<vector::TransferWriteOp> {
  using OneToNOpConversionPattern::OneToNOpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
                  OneToNPatternRewriter &rewriter) const override {
    if (writeOp.hasPureTensorSemantics())
      return rewriter.notifyMatchFailure(
          writeOp, "TODO: tensor semantics are unsupported");

    auto permutationMap = writeOp.getPermutationMap();
    if (!permutationMap.isPermutation())
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNonPermutationMap);

    bool transposed = !permutationMap.isIdentity();
    if (transposed)
      return rewriter.notifyMatchFailure(writeOp,
                                         "TODO: transpose unsupported");

    auto vectorType = writeOp.getVectorType();
    if (!isMultipleOfSMETileVectorType(vectorType))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureNotSMETileTypeMultiple);

    // Note: We also disallow masks where any dimension is > 16 because that
    // prevents the masking from being lowered to use arm_sve.psel.
    auto mask = writeOp.getMask();
    if (!isSupportedMaskOp(mask) || (mask && (vectorType.getDimSize(0) > 16 ||
                                              vectorType.getDimSize(1) > 16)))
      return rewriter.notifyMatchFailure(writeOp,
                                         kMatchFailureUnsupportedMaskOp);

    auto loc = writeOp.getLoc();
    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
    auto createVscaleMultiple = [&](int64_t multiplier) {
      return rewriter.create<arith::MulIOp>(
          loc, vscale,
          rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
    };

    // Get SME tile and slice types.
    auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
    auto minTileSlices = smeTileType.getDimSize(0);
    VectorType sliceMaskType =
        VectorType::get(minTileSlices, rewriter.getI1Type(), true);

    // Create loop over all tile slices.
    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto upperBound = createVscaleMultiple(minTileSlices);
    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    auto storeLoop =
        rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
    rewriter.setInsertionPointToStart(storeLoop.getBody());

    // For each sub-tile of the multi-tile `vectorType`.
    auto inputSMETiles = adaptor.getVector();
    auto tileSliceIndex = storeLoop.getInductionVar();
    for (auto [index, smeTile] : llvm::enumerate(
             decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
      // The coordinates of the tile within `vectorType`.
      auto tileRow = createVscaleMultiple(smeTile.row);
      auto tileCol = createVscaleMultiple(smeTile.col);

      // The current slice of `vectorType` we are processing.
      auto sliceIndex =
          rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex);

      // Where in the destination memref the current slice will be stored.
      auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex,
                                                     writeOp.getIndices()[0]);
      auto storeCol =
          rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]);

      // Extract the mask for the current slice.
      Value sliceMask = nullptr;
      if (mask) {
        sliceMask = rewriter.create<vector::ExtractOp>(
            loc, mask, OpFoldResult(sliceIndex));
        if (sliceMaskType != sliceMask.getType())
          sliceMask = rewriter.create<vector::ScalableExtractOp>(
              loc, sliceMaskType, sliceMask, smeTile.col);
      }

      // Extract and store the current slice.
      Value tile = inputSMETiles[index];
      auto slice =
          rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex);
      rewriter.create<vector::TransferWriteOp>(
          loc, slice, writeOp.getSource(), ValueRange{storeRow, storeCol},
          AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)),
          sliceMask,
          rewriter.getBoolArrayAttr(
              ArrayRef<bool>(writeOp.getInBoundsValues()).drop_front()));
    }

    rewriter.eraseOp(writeOp);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ArmSME-specific fixup canonicalizations/folds
//===----------------------------------------------------------------------===//
@@ -663,9 +797,12 @@ struct VectorLegalizationPass
    patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                 LiftIllegalVectorTransposeToMemory,
                 ConvertIllegalShapeCastOpsToTransposes>(context);
-    // Note: High benefit to ensure masked outer products are lowered first.
-    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
-        converter, context, 1024);
    // Note: These two patterns are added with a high benefit to ensure:
    //  - Masked outer products are handled before unmasked ones
    //  - Multi-tile writes are lowered as a store loop (if possible)
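    // (Without the high benefit, the generic per-tile decomposition patterns
    // below could fire first, splitting a multi-tile write into independent
    // writes and losing the disjointness that makes a single store loop
    // possible.)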
    patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition,
                 LegalizeMultiTileTransferWriteAsStoreLoop>(converter, context,
                                                            /*benefit=*/1024);
    patterns.add<LegalizeArithConstantOpsByDecomposition,
                 LegalizeVectorOuterProductOpsByDecomposition,
                 LegalizeTransferReadOpsByDecomposition,
112 changes: 106 additions & 6 deletions mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -174,11 +174,17 @@ func.func @transfer_read_i16_scalable_8x16_masked(%src: memref<?x?xi16>, %dim0:
func.func @transfer_write_f16_scalable_16x8(%dest: memref<?x?xf16>, %vec: vector<[16]x[8]xf16>)
{
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG: %[[C8_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C8]] : index
-// CHECK-DAG: vector.transfer_write %[[TOP]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
-// CHECK-DAG: vector.transfer_write %[[BOTTOM]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref<?x?xf16>
// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C8_VSCALE]] step %[[C1]] {
// CHECK-NEXT: %[[TOP_SLICE:.*]] = vector.extract %[[TOP]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
// CHECK-NEXT: vector.transfer_write %[[TOP_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
// CHECK-NEXT: %[[BOTTOM_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
// CHECK-NEXT: %[[BOTTOM_SLICE:.*]] = vector.extract %[[BOTTOM]][%[[I]]] : vector<[8]xf16> from vector<[8]x[8]xf16>
// CHECK-NEXT: vector.transfer_write %[[BOTTOM_SLICE]], %[[DEST]][%[[BOTTOM_I]], %[[C0]]] {in_bounds = [true]} : vector<[8]xf16>, memref<?x?xf16>
// CHECK-NEXT: }
// CHECK-NEXT: return
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[8]xf16>, memref<?x?xf16>
@@ -201,6 +207,90 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:

// -----

// CHECK-LABEL: @transfer_write_f32_scalable_8x8_masked(
// CHECK-SAME: %[[DEST:[a-z0-9]+]]: memref<?x?xf32>,
// CHECK-SAME: %[[DIM_0:[a-z0-9]+]]: index,
// CHECK-SAME: %[[DIM_1:[a-z0-9]+]]: index,
// CHECK-SAME: %[[TILE_0:[a-z0-9]+]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[TILE_1:[a-z0-9]+]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[TILE_2:[a-z0-9]+]]: vector<[4]x[4]xf32>,
// CHECK-SAME: %[[TILE_3:[a-z0-9]+]]: vector<[4]x[4]xf32>)
func.func @transfer_write_f32_scalable_8x8_masked(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
{
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
// CHECK-DAG: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
// CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<[8]x[8]xi1>
// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
// CHECK-NEXT: %[[UPPER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
// CHECK-NEXT: %[[TILE_0_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
// CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]], %[[TILE_0_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
// CHECK-NEXT: %[[TILE_1_SLICE_MASK:.*]] = vector.scalable.extract %[[UPPER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
// CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[I]], %[[C4_VSCALE]]], %[[TILE_1_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
// CHECK-NEXT: %[[LOWER_SLICE_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
// CHECK-NEXT: %[[LOWER_SLICE_MASK:.*]] = vector.extract %[[MASK]][%[[LOWER_SLICE_I]]] : vector<[8]xi1> from vector<[8]x[8]xi1>
// CHECK-NEXT: %[[TILE_2_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][0] : vector<[4]xi1> from vector<[8]xi1>
// CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C0]]], %[[TILE_2_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
// CHECK-NEXT: %[[TILE_3_SLICE_MASK:.*]] = vector.scalable.extract %[[LOWER_SLICE_MASK]][4] : vector<[4]xi1> from vector<[8]xi1>
// CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[LOWER_SLICE_I]], %[[C4_VSCALE]]], %[[TILE_3_SLICE_MASK]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
// CHECK-NEXT: }
%c0 = arith.constant 0 : index
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
return
}

// -----

// Tensor semantics are not supported for the store loop lowering.

// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_tensor
// CHECK-NOT: scf.for
func.func @negative_transfer_write_f32_scalable_8x8_tensor(%dest: tensor<?x?xf32>, %vec: vector<[8]x[8]xf32>)
{
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
return
}

// -----

#transpose = affine_map<(d0, d1) -> (d1, d0)>

// Transposes are not supported for the store loop lowering.

// CHECK-LABEL: @negative_transfer_write_f32_scalable_8x8_transpose
// CHECK-NOT: scf.for
func.func @negative_transfer_write_f32_scalable_8x8_transpose(%dest: tensor<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[8]x[8]xf32>)
{
%c0 = arith.constant 0 : index
%mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1>
vector.transfer_write %vec, %dest[%c0, %c0], %mask {permutation_map = #transpose, in_bounds = [true, true]} : vector<[8]x[8]xf32>, tensor<?x?xf32>
return
}

// -----

// Masked writes where any dimension of the mask is > 16 are not supported for the store loop lowering.

// CHECK-LABEL: @negative_transfer_write_f32_scalable_32x32
// CHECK-NOT: scf.for
func.func @negative_transfer_write_f32_scalable_32x32(%dest: memref<?x?xf32>, %dim0: index, %dim1: index, %vec: vector<[32]x[32]xf32>)
{
%c0 = arith.constant 0 : index
%mask = vector.create_mask %dim0, %dim1 : vector<[32]x[32]xi1>
vector.transfer_write %vec, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[32]x[32]xf32>, memref<?x?xf32>
return
}

// -----

#transpose = affine_map<(d0, d1) -> (d1, d0)>

// CHECK-LABEL: @transpose_f32_scalable_4x16_via_read(
@@ -209,6 +299,7 @@ func.func @transfer_write_i8_scalable_16x16_masked(%dest: memref<?x?xi8>, %vec:
func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: memref<?x?xf32>)
{
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
@@ -221,10 +312,19 @@ func.func @transpose_f32_scalable_4x16_via_read(%src: memref<?x?xf32>, %dest: me
// CHECK-DAG: %[[TILE_1:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C4_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK-DAG: %[[TILE_2:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C8_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
// CHECK-DAG: %[[TILE_3:.*]] = vector.transfer_read %[[SRC]][%[[C0]], %[[C12_VSCALE]]], %[[PAD]] {in_bounds = [true, true], permutation_map = #{{.*}}} : memref<?x?xf32>, vector<[4]x[4]xf32>
-// CHECK-DAG: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-// CHECK-DAG: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[C4_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-// CHECK-DAG: vector.transfer_write %[[TILE_2]], %[[DEST]][%[[C8_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
-// CHECK-DAG: vector.transfer_write %[[TILE_3]], %[[DEST]][%[[C12_VSCALE]], %[[C0]]] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref<?x?xf32>
// CHECK-NEXT: scf.for %[[I:.*]] = %[[C0]] to %[[C4_VSCALE]] step %[[C1]] {
// CHECK-NEXT: %[[TILE_0_SLICE:.*]] = vector.extract %[[TILE_0]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.transfer_write %[[TILE_0_SLICE]], %[[DEST]][%[[I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
// CHECK-NEXT: %[[TILE_1_I:.*]] = arith.addi %[[C4_VSCALE]], %[[I]] : index
// CHECK-NEXT: %[[TILE_1_SLICE:.*]] = vector.extract %[[TILE_1]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.transfer_write %[[TILE_1_SLICE]], %[[DEST]][%[[TILE_1_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
// CHECK-NEXT: %[[TILE_2_I:.*]] = arith.addi %[[C8_VSCALE]], %[[I]] : index
// CHECK-NEXT: %[[TILE_2_SLICE:.*]] = vector.extract %[[TILE_2]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.transfer_write %[[TILE_2_SLICE]], %[[DEST]][%[[TILE_2_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
// CHECK-NEXT: %[[TILE_3_I:.*]] = arith.addi %[[C12_VSCALE]], %[[I]] : index
// CHECK-NEXT: %[[TILE_3_SLICE:.*]] = vector.extract %[[TILE_3]][%[[I]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
// CHECK-NEXT: vector.transfer_write %[[TILE_3_SLICE]], %[[DEST]][%[[TILE_3_I]], %[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, memref<?x?xf32>
// CHECK-NEXT: }
// CHECK-NEXT: return
%c0 = arith.constant 0 : index
%pad = arith.constant 0.0 : f32
@@ -1,7 +1,8 @@
// RUN: mlir-opt %s \
// RUN:   -transform-interpreter -test-transform-dialect-erase-schedule \
// RUN:   -one-shot-bufferize="bufferize-function-boundaries" -canonicalize \
-// RUN:   -test-lower-to-arm-sme -test-lower-to-llvm | \
// RUN:   -test-lower-to-arm-sme -convert-vector-to-llvm="enable-arm-sve" \
// RUN:   -test-lower-to-llvm | \
// RUN: %mcr_aarch64_cmd \
// RUN:   -e=main -entry-point-result=void \
// RUN:   -march=aarch64 -mattr="+sve,+sme" \
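
Note: masked slice stores from the new loop lower to `arm_sve.psel` (see the mask restriction in the pattern above), which is presumably why this pipeline now needs `-convert-vector-to-llvm="enable-arm-sve"` before `-test-lower-to-llvm`.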