-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA #98620
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,10 +18,13 @@ | |
#include "mlir/Dialect/ArmSME/Utils/Utils.h" | ||
#include "mlir/Dialect/Func/IR/FuncOps.h" | ||
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" | ||
#include "mlir/Dialect/Index/IR/IndexDialect.h" | ||
#include "mlir/Dialect/Index/IR/IndexOps.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/Dialect/SCF/Transforms/Patterns.h" | ||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||
#include "mlir/Dialect/Vector/Utils/VectorUtils.h" | ||
#include "mlir/Transforms/OneToNTypeConversion.h" | ||
|
||
#define DEBUG_TYPE "arm-sme-vector-legalization" | ||
|
@@ -140,11 +143,11 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask, | |
auto decomposeToSMETiles(OpBuilder &builder, VectorType type, | ||
VectorType smeTileType, | ||
bool transposeIndices = false) { | ||
assert(isMultipleOfSMETileVectorType(type) && | ||
"`type` not multiple of SME tiles"); | ||
return llvm::map_range( | ||
StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0), | ||
smeTileType.getDimSize(1)}), | ||
StaticTileOffsetRange( | ||
type.getShape(), | ||
{std::min(type.getDimSize(0), smeTileType.getDimSize(0)), | ||
std::min(type.getDimSize(1), smeTileType.getDimSize(1))}), | ||
[=](auto indices) { | ||
int row = int(indices[0]); | ||
int col = int(indices[1]); | ||
|
@@ -440,12 +443,8 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop | |
kMatchFailureUnsupportedMaskOp); | ||
|
||
auto loc = writeOp.getLoc(); | ||
auto vscale = rewriter.create<vector::VectorScaleOp>(loc); | ||
auto createVscaleMultiple = [&](int64_t multiplier) { | ||
return rewriter.create<arith::MulIOp>( | ||
loc, vscale, | ||
rewriter.create<arith::ConstantIndexOp>(loc, multiplier)); | ||
}; | ||
auto createVscaleMultiple = | ||
vector::makeVscaleConstantBuilder(rewriter, loc); | ||
|
||
// Get SME tile and slice types. | ||
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); | ||
|
@@ -775,6 +774,149 @@ struct ConvertIllegalShapeCastOpsToTransposes | |
} | ||
}; | ||
|
||
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use | ||
/// the ZA state. This is a workaround rewrite to support these transposes when | ||
/// ZA is available. | ||
/// | ||
/// Example: | ||
/// | ||
/// BEFORE: | ||
/// ```mlir | ||
/// %transpose = vector.transpose %vec, [1, 0] | ||
/// : vector<2x[4]xf32> to vector<[4]x2xf32> | ||
/// vector.transfer_write %transpose, %dest[%y, %x] | ||
/// : vector<[4]x2xf32>, memref<?x?xf32> | ||
/// ``` | ||
/// | ||
/// AFTER: | ||
/// ```mlir | ||
/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32> | ||
/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32> | ||
/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32> | ||
/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32> | ||
/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32> | ||
/// %c4_vscale = arith.muli %vscale, %c4 : index | ||
/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1> | ||
/// vector.transfer_write %4, %dest[%y, %x], %mask | ||
/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} | ||
/// : vector<[4]x[4]xf32>, memref<?x?xf32> | ||
/// ``` | ||
/// | ||
/// Values larger than a single tile are supported via decomposition. | ||
struct LowerIllegalTransposeStoreViaZA | ||
: public OpRewritePattern<vector::TransferWriteOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, | ||
PatternRewriter &rewriter) const override { | ||
if (!isSupportedMaskOp(writeOp.getMask())) | ||
return rewriter.notifyMatchFailure(writeOp, | ||
kMatchFailureUnsupportedMaskOp); | ||
|
||
auto permutationMap = writeOp.getPermutationMap(); | ||
if (!permutationMap.isIdentity()) | ||
return rewriter.notifyMatchFailure(writeOp, | ||
kMatchFailureNonPermutationMap); | ||
|
||
auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>(); | ||
if (!transposeOp) | ||
return failure(); | ||
|
||
auto sourceType = transposeOp.getSourceVectorType(); | ||
auto resultType = transposeOp.getResultVectorType(); | ||
|
||
if (resultType.getRank() != 2) | ||
return rewriter.notifyMatchFailure(transposeOp, "TransposeOp not rank 2"); | ||
|
||
if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType)) | ||
return rewriter.notifyMatchFailure( | ||
transposeOp, "not illegal/unsupported SVE transpose"); | ||
|
||
auto smeTileType = getSMETileTypeForElement(resultType.getElementType()); | ||
VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0); | ||
|
||
if (sourceType.getDimSize(0) <= 1 || | ||
sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0) | ||
return rewriter.notifyMatchFailure(writeOp, "unsupported source shape"); | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
auto loc = writeOp.getLoc(); | ||
auto createVscaleMultiple = | ||
vector::makeVscaleConstantBuilder(rewriter, loc); | ||
|
||
auto transposeMap = AffineMapAttr::get( | ||
AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext())); | ||
|
||
// Note: We need to use `get_tile` as there's no vector-level `undef`. | ||
Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType); | ||
Value destTensorOrMemref = writeOp.getSource(); | ||
auto numSlicesPerTile = | ||
std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0)); | ||
auto numSlices = | ||
rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile); | ||
for (auto [index, smeTile] : llvm::enumerate( | ||
decomposeToSMETiles(rewriter, sourceType, smeTileType))) { | ||
// 1. _Deliberately_ drop a scalable dimension and insert a fixed number | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Where are these dims dropped? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We can't dynamically index the array-of-vectors input (or dynamically select an SME tile). These restrictions mean this lowering just targets the lowest common denominator (that is, vscale = 1). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK, IIUC the scalability is dropped on L853. And, basically, this pattern will store at most 4 rows per tile? |
||
// of slices from the source type into the SME tile. Without checking | ||
// vscale (and emitting multiple implementations) we can't make use of the | ||
// rows of the tile after 1*vscale rows. | ||
Value tile = undefTile; | ||
for (int d = 0; d < numSlicesPerTile; ++d) { | ||
Value vector = rewriter.create<vector::ExtractOp>( | ||
loc, transposeOp.getVector(), | ||
rewriter.getIndexAttr(d + smeTile.row)); | ||
if (vector.getType() != smeSliceType) { | ||
vector = rewriter.create<vector::ScalableExtractOp>( | ||
loc, smeSliceType, vector, smeTile.col); | ||
} | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d); | ||
} | ||
|
||
// 2. Transpose the tile position. | ||
auto transposedRow = createVscaleMultiple(smeTile.col); | ||
auto transposedCol = | ||
rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row); | ||
|
||
// 3. Compute mask for tile store. | ||
Value maskRows; | ||
Value maskCols; | ||
if (auto mask = writeOp.getMask()) { | ||
auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); | ||
maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0), | ||
transposedRow); | ||
maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1), | ||
transposedCol); | ||
maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices); | ||
} else { | ||
maskRows = createVscaleMultiple(smeTileType.getDimSize(0)); | ||
maskCols = numSlices; | ||
} | ||
auto subMask = rewriter.create<vector::CreateMaskOp>( | ||
loc, smeTileType.clone(rewriter.getI1Type()), | ||
ValueRange{maskRows, maskCols}); | ||
|
||
// 4. Emit a transposed tile write. | ||
auto writeIndices = writeOp.getIndices(); | ||
Value destRow = | ||
rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]); | ||
Value destCol = | ||
rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]); | ||
auto smeWrite = rewriter.create<vector::TransferWriteOp>( | ||
loc, tile, destTensorOrMemref, ValueRange{destRow, destCol}, | ||
transposeMap, subMask, writeOp.getInBounds()); | ||
|
||
if (writeOp.hasPureTensorSemantics()) | ||
destTensorOrMemref = smeWrite.getResult(); | ||
} | ||
|
||
if (writeOp.hasPureTensorSemantics()) | ||
rewriter.replaceOp(writeOp, destTensorOrMemref); | ||
else | ||
rewriter.eraseOp(writeOp); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
struct VectorLegalizationPass | ||
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> { | ||
void runOnOperation() override { | ||
|
@@ -796,7 +938,8 @@ struct VectorLegalizationPass | |
|
||
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, | ||
LiftIllegalVectorTransposeToMemory, | ||
ConvertIllegalShapeCastOpsToTransposes>(context); | ||
ConvertIllegalShapeCastOpsToTransposes, | ||
LowerIllegalTransposeStoreViaZA>(context); | ||
// Note: These two patterns are added with a high benefit to ensure: | ||
// - Masked outer products are handled before unmasked ones | ||
// - Multi-tile writes are lowered as a store loop (if possible) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🥳