Skip to content

[mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA #98620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@ def VectorLegalization
"func::FuncDialect",
"arm_sme::ArmSMEDialect",
"vector::VectorDialect",
"arith::ArithDialect"
"arith::ArithDialect",
"index::IndexDialect"
];
}

Expand Down
19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_
#define MLIR_DIALECT_VECTOR_UTILS_VECTORUTILS_H_

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
Expand Down Expand Up @@ -101,6 +102,24 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
std::optional<StaticTileOffsetRange>
createUnrollIterator(VectorType vType, int64_t targetRank = 1);

/// Creates a callback (int64_t -> Value) that builds `multiplier *
/// vector.vscale` constants at the given location.
///
/// The underlying `vector.vscale` op is materialized lazily on first use and
/// cached, so repeated calls share a single vscale value.
///
/// Example:
/// ```c++
/// auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
/// auto c4Vscale = createVscaleMultiple(4); // 4 * vector.vscale
/// ```
inline auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
  // The cached vscale Value lives inside the closure via an init-capture;
  // `mutable` lets the first invocation fill it in.
  return [loc, &rewriter, vscale = Value{}](int64_t multiplier) mutable {
    if (!vscale)
      vscale = rewriter.create<vector::VectorScaleOp>(loc);
    auto constantMultiplier =
        rewriter.create<arith::ConstantIndexOp>(loc, multiplier);
    return rewriter.create<arith::MulIOp>(loc, vscale, constantMultiplier);
  };
}

/// A wrapper for getMixedSizes for vector.transfer_read and
/// vector.transfer_write Ops (for source and destination, respectively).
///
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
MLIRIndexDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRFuncTransforms
Expand Down
165 changes: 154 additions & 11 deletions mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Transforms/OneToNTypeConversion.h"

#define DEBUG_TYPE "arm-sme-vector-legalization"
Expand Down Expand Up @@ -140,11 +143,11 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
VectorType smeTileType,
bool transposeIndices = false) {
assert(isMultipleOfSMETileVectorType(type) &&
"`type` not multiple of SME tiles");
return llvm::map_range(
StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
smeTileType.getDimSize(1)}),
StaticTileOffsetRange(
type.getShape(),
{std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
[=](auto indices) {
int row = int(indices[0]);
int col = int(indices[1]);
Expand Down Expand Up @@ -440,12 +443,8 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
kMatchFailureUnsupportedMaskOp);

auto loc = writeOp.getLoc();
auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
auto createVscaleMultiple = [&](int64_t multiplier) {
return rewriter.create<arith::MulIOp>(
loc, vscale,
rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
};
auto createVscaleMultiple =
vector::makeVscaleConstantBuilder(rewriter, loc);

// Get SME tile and slice types.
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
Expand Down Expand Up @@ -775,6 +774,149 @@ struct ConvertIllegalShapeCastOpsToTransposes
}
};

/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This is a workaround rewrite that supports these transposes
/// when ZA is available.
///
/// Example:
///
/// BEFORE:
/// ```mlir
/// %transpose = vector.transpose %vec, [1, 0]
/// : vector<2x[4]xf32> to vector<[4]x2xf32>
/// vector.transfer_write %transpose, %dest[%y, %x]
/// : vector<[4]x2xf32>, memref<?x?xf32>
/// ```
///
/// AFTER:
/// ```mlir
/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
/// %c4_vscale = arith.muli %vscale, %c4 : index
/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
/// vector.transfer_write %4, %dest[%y, %x], %mask
/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
/// : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```
///
/// Values larger than a single tile are supported via decomposition.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🥳

struct LowerIllegalTransposeStoreViaZA
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;

// Matches a `vector.transfer_write(vector.transpose)` whose transpose result
// is an unsupported SVE type, and rewrites it as one transposed tile write
// (via ZA) per SME tile of the source vector.
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
// Only no-mask or vector.create_mask masks are handled.
if (!isSupportedMaskOp(writeOp.getMask()))
return rewriter.notifyMatchFailure(writeOp,
kMatchFailureUnsupportedMaskOp);

auto permutationMap = writeOp.getPermutationMap();
if (!permutationMap.isIdentity())
return rewriter.notifyMatchFailure(writeOp,
kMatchFailureNonPermutationMap);

// The written vector must come directly from a transpose.
auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
if (!transposeOp)
return failure();

auto sourceType = transposeOp.getSourceVectorType();
auto resultType = transposeOp.getResultVectorType();

if (resultType.getRank() != 2)
return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2");

// The transpose input must be a legal SVE type while its result is not —
// i.e. exactly the case this ZA workaround exists for.
if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
return rewriter.notifyMatchFailure(
transposeOp, "not illegal/unsupported SVE transpose");

auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);

// Require at least two source rows and a scalable row length that divides
// evenly into SME tile slices.
if (sourceType.getDimSize(0) <= 1 ||
sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");

auto loc = writeOp.getLoc();
auto createVscaleMultiple =
vector::makeVscaleConstantBuilder(rewriter, loc);

// Permutation (d0, d1) -> (d1, d0) applied to each emitted tile write.
auto transposeMap = AffineMapAttr::get(
AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));

// Note: We need to use `get_tile` as there's no vector-level `undef`.
Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
Value destTensorOrMemref = writeOp.getSource();
// Number of source rows placed into each SME tile (at most one tile's
// worth of rows).
auto numSlicesPerTile =
std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
auto numSlices =
rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
for (auto [index, smeTile] : llvm::enumerate(
decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
// 1. _Deliberately_ drop a scalable dimension and insert a fixed number
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are these dims dropped?

Copy link
Member Author

@MacDue MacDue Jul 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't dynamically index the array-of-vectors input (or dynamically select an SME tile). These restrictions mean this lowering just targets the lowest common denominator (that is vscale = 1).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, IIUC the scalability is dropped on L853. And, basically, this pattern will store at most 4 rows per tile?

// of slices from the source type into the SME tile. Without checking
// vscale (and emitting multiple implementations) we can't make use of the
// rows of the tile after 1*vscale rows.
Value tile = undefTile;
for (int d = 0; d < numSlicesPerTile; ++d) {
Value vector = rewriter.create<vector::ExtractOp>(
loc, transposeOp.getVector(),
rewriter.getIndexAttr(d + smeTile.row));
// Narrow the extracted row to an SME slice when the source row is wider
// than one slice (multi-tile-wide sources).
if (vector.getType() != smeSliceType) {
vector = rewriter.create<vector::ScalableExtractOp>(
loc, smeSliceType, vector, smeTile.col);
}
tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
}

// 2. Transpose the tile position.
auto transposedRow = createVscaleMultiple(smeTile.col);
auto transposedCol =
rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);

// 3. Compute mask for tile store.
Value maskRows;
Value maskCols;
if (auto mask = writeOp.getMask()) {
// Shift the original mask bounds by this tile's (transposed) offset,
// and clamp the columns to the number of slices actually inserted.
auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
transposedRow);
maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
transposedCol);
maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
} else {
maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
maskCols = numSlices;
}
auto subMask = rewriter.create<vector::CreateMaskOp>(
loc, smeTileType.clone(rewriter.getI1Type()),
ValueRange{maskRows, maskCols});

// 4. Emit a transposed tile write.
auto writeIndices = writeOp.getIndices();
Value destRow =
rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
Value destCol =
rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
auto smeWrite = rewriter.create<vector::TransferWriteOp>(
loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
transposeMap, subMask, writeOp.getInBounds());

// Thread the updated tensor through to the next tile write (tensor
// semantics only; memref writes have no result).
if (writeOp.hasPureTensorSemantics())
destTensorOrMemref = smeWrite.getResult();
}

if (writeOp.hasPureTensorSemantics())
rewriter.replaceOp(writeOp, destTensorOrMemref);
else
rewriter.eraseOp(writeOp);

return success();
}
};

struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
Expand All @@ -796,7 +938,8 @@ struct VectorLegalizationPass

patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
ConvertIllegalShapeCastOpsToTransposes>(context);
ConvertIllegalShapeCastOpsToTransposes,
LowerIllegalTransposeStoreViaZA>(context);
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
Expand Down
102 changes: 102 additions & 0 deletions mlir/test/Dialect/ArmSME/vector-legalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -544,3 +544,105 @@ func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
%0 = arith.constant dense<42> : vector<[8]x[8]xi32>
return %0 : vector<[8]x[8]xi32>
}

// -----

// CHECK: #[[$TRANSPOSE_MAP_0:.*]] = affine_map<(d0, d1) -> (d1, d0)>

// CHECK-LABEL: @transpose_store_scalable_via_za(
// CHECK-SAME: %[[VEC:.*]]: vector<2x[4]xf32>
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[I:.*]]: index,
// CHECK-SAME: %[[J:.*]]: index)
func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
// Single-tile case: both rows of %vec are inserted into one [4]x[4] ZA tile,
// which is then written back with a transposed permutation map and a
// (4*vscale) x 2 mask.
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
// CHECK-NEXT: %[[INIT:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
// CHECK-NEXT: %[[V0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32>
// CHECK-NEXT: %[[R0:.*]] = vector.insert %[[V0]], %[[INIT]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
// CHECK-NEXT: %[[V1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.insert %[[V1]], %[[R0]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// CHECK-NEXT: %[[VSCALE:.*]] = vector.vscale
// CHECK-NEXT: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
// CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C2]] : vector<[4]x[4]xi1>
// CHECK-NEXT: vector.transfer_write %[[RES]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32>
%tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32>
return
}

// -----

// CHECK-LABEL: @transpose_store_scalable_via_za_masked(
// CHECK-SAME: %[[A:[a-z0-9]+]]: index,
// CHECK-SAME: %[[B:[a-z0-9]+]]: index)
func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) {
// Masked case: the original %a x %b mask is carried over to the tile write,
// with the column count clamped to the 2 inserted slices (index.mins).
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[MIN:.*]] = index.mins %[[B]], %[[C2]]
// CHECK: %[[MASK:.*]] = vector.create_mask %[[A]], %[[MIN]] : vector<[4]x[4]xi1>
// CHECK: vector.transfer_write {{.*}} %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
%c0 = arith.constant 0 : index
%mask = vector.create_mask %a, %b : vector<[4]x2xi1>
%tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32>
return
}

// -----

// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile(
// CHECK-SAME: %[[VEC:.*]]: vector<8x[4]xf32>
// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[I:.*]]: index,
// CHECK-SAME: %[[J:.*]]: index)
func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
// Multi-tile case: the 8 source rows decompose into two [4]x[4] tiles; the
// second tile's write is column-offset by 4 from %j.
// CHECK: %[[C4:.*]] = arith.constant 4 : index

// <skip 3x other extract+insert chain>
// CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32>
// CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
// CHECK: %[[VSCALE:.*]] = vector.vscale
// CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
// CHECK: %[[MASK:.*]] = vector.create_mask %c4_vscale, %c4 : vector<[4]x[4]xi1>
// CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>

// <skip 3x other extract+insert chain>
// CHECK: %[[V7:.*]] = vector.extract %arg0[7] : vector<[4]xf32> from vector<8x[4]xf32>
// CHECK: %[[TILE_1:.*]] = vector.insert %[[V7]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
// CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[C4]] : index
// CHECK: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[I]], %[[J_OFFSET]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
%tr = vector.transpose %vec, [1, 0] : vector<8x[4]xf32> to vector<[4]x8xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>, memref<?x?xf32>
return
}

// -----

// CHECK-LABEL: @transpose_store_scalable_via_za_multi_tile_wide
func.func @transpose_store_scalable_via_za_multi_tile_wide(%vec: vector<2x[8]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
// Wide case: each [8] row spans two tiles, so slices come from
// vector.scalable.extract at offsets [0] and [4]; the second tile's write is
// row-offset by 4*vscale from %i.
// <check extracts from lower 4 x vscale of %vec>
// CHECK: vector.scalable.extract
// CHECK: %[[ROW_2_LOWER:.*]] = vector.scalable.extract %{{.*}}[0] : vector<[4]xf32> from vector<[8]xf32>
// CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_LOWER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I:.[a-z0-9]+]], %[[J:[a-z0-9]+]]]

// <check extracts from upper 4 x vscale of %vec>
// CHECK: vector.scalable.extract
// CHECK: %[[ROW_2_UPPER:.*]] = vector.scalable.extract %{{.*}}[4] : vector<[4]xf32> from vector<[8]xf32>
// CHECK: %[[TILE_0:.*]] = vector.insert %[[ROW_2_UPPER]], %{{.*}}[1] : vector<[4]xf32> into vector<[4]x[4]xf32>
// CHECK: %[[I_OFFSET:.*]] = arith.addi %c4_vscale, %[[I]] : index
// CHECK: vector.transfer_write %[[TILE_0]], %{{.*}}[%[[I_OFFSET]], %[[J]]]
%tr = vector.transpose %vec, [1, 0] : vector<2x[8]xf32> to vector<[8]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[8]x2xf32>, memref<?x?xf32>
return
}

// -----

// CHECK-LABEL: @negative_transpose_store_scalable_via_za__bad_source_shape
// CHECK-NOT: arm_sme.get_tile
func.func @negative_transpose_store_scalable_via_za__bad_source_shape(%vec: vector<2x[7]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
// Negative test: [7] does not divide evenly into [4]-wide SME slices
// (7 % 4 != 0), so the pattern must not fire (no arm_sme.get_tile emitted).
%tr = vector.transpose %vec, [1, 0] : vector<2x[7]xf32> to vector<[7]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[7]x2xf32>, memref<?x?xf32>
return
}
Loading