[mlir][ArmSME] Add rewrite to handle unsupported SVE transposes via SME/ZA #98620
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

Changes

This adds a workaround rewrite that allows stores of unsupported SVE transposes such as:

```mlir
%tr = vector.transpose %vec, [1, 0]
  : vector<2x[4]xf32> to vector<[4]x2xf32>
vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
  : vector<[4]x2xf32>, memref<?x?xf32>
```

to instead use SME tiles, which are possible to lower (when SME is available):

```mlir
// Insert vector<2x[4]xf32> into an SME tile:
%0 = arm_sme.get_tile : vector<[4]x[4]xf32>
%1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
%2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
%3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
%4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>

// Store the tile with a transpose + mask:
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
  {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
  : vector<[4]x[4]xf32>, memref<?x?xf32>
```

Full diff: https://github.com/llvm/llvm-project/pull/98620.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index dfd64f995546a..921234daad1f1 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -202,7 +202,8 @@ def VectorLegalization
"func::FuncDialect",
"arm_sme::ArmSMEDialect",
"vector::VectorDialect",
- "arith::ArithDialect"
+ "arith::ArithDialect",
+ "index::IndexDialect"
];
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index 600f2ecdb51bc..8f9b5080e82db 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRArmSMETransforms
MLIRFuncDialect
MLIRLLVMCommonConversion
MLIRVectorDialect
+ MLIRIndexDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRFuncTransforms
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 96dad6518fec8..028e2327e2a4f 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -18,6 +18,8 @@
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
@@ -140,11 +142,11 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
VectorType smeTileType,
bool transposeIndices = false) {
- assert(isMultipleOfSMETileVectorType(type) &&
- "`type` not multiple of SME tiles");
return llvm::map_range(
- StaticTileOffsetRange(type.getShape(), {smeTileType.getDimSize(0),
- smeTileType.getDimSize(1)}),
+ StaticTileOffsetRange(
+ type.getShape(),
+ {std::min(type.getDimSize(0), smeTileType.getDimSize(0)),
+ std::min(type.getDimSize(1), smeTileType.getDimSize(1))}),
[=](auto indices) {
int row = int(indices[0]);
int col = int(indices[1]);
@@ -374,6 +376,14 @@ struct LegalizeTransferWriteOpsByDecomposition
}
};
+auto makeVscaleConstantBuilder(PatternRewriter &rewriter, Location loc) {
+ Value vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ return [loc, vscale, &rewriter](int64_t multiplier) {
+ return rewriter.create<arith::MulIOp>(
+ loc, vscale, rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
+ };
+}
+
/// Legalize a multi-tile transfer_write as a single store loop. This is done as
/// part of type decomposition as at this level we know each tile write is
/// disjoint, but that information is lost after decomposition (without analysis
@@ -440,12 +450,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
kMatchFailureUnsupportedMaskOp);
auto loc = writeOp.getLoc();
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
- auto createVscaleMultiple = [&](int64_t multiplier) {
- return rewriter.create<arith::MulIOp>(
- loc, vscale,
- rewriter.create<arith::ConstantIndexOp>(loc, multiplier));
- };
+ auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
// Get SME tile and slice types.
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -775,6 +780,148 @@ struct ConvertIllegalShapeCastOpsToTransposes
}
};
+/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead
+/// use the ZA state. This is a workaround to support these transposes when ZA
+/// is available.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %transpose = vector.transpose %vec, [1, 0]
+/// : vector<2x[4]xf32> to vector<[4]x2xf32>
+/// vector.transfer_write %transpose, %dest[%y, %x]
+/// : vector<[4]x2xf32>, memref<?x?xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = arm_sme.get_tile : vector<[4]x[4]xf32>
+/// %1 = vector.extract %vec[0] : vector<[4]xf32> from vector<2x[4]xf32>
+/// %2 = vector.insert %1, %0 [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+/// %3 = vector.extract %vec[1] : vector<[4]xf32> from vector<2x[4]xf32>
+/// %4 = vector.insert %3, %2 [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// %mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
+/// vector.transfer_write %4, %arg1[%arg2, %arg3], %mask
+/// {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
+/// : vector<[4]x[4]xf32>, memref<?x?xf32>
+/// ```
+///
+/// Values larger than a single tile are supported via decomposition.
+struct LowerIllegalTransposeStoreViaZA
+ : public OpRewritePattern<vector::TransferWriteOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) const override {
+ if (!isSupportedMaskOp(writeOp.getMask()))
+ return rewriter.notifyMatchFailure(writeOp,
+ kMatchFailureUnsupportedMaskOp);
+
+ auto permutationMap = writeOp.getPermutationMap();
+ if (!permutationMap.isIdentity())
+ return rewriter.notifyMatchFailure(writeOp,
+ kMatchFailureNonPermutationMap);
+
+ auto transposeOp = writeOp.getVector().getDefiningOp<vector::TransposeOp>();
+ if (!transposeOp)
+ return failure();
+
+ auto sourceType = transposeOp.getSourceVectorType();
+ auto resultType = transposeOp.getResultVectorType();
+
+ if (resultType.getRank() != 2)
+ return rewriter.notifyMatchFailure(transposeOp, "not rank 2");
+
+ if (!isLegalVectorType(sourceType) || isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(
+ transposeOp, "not illegal/unsupported SVE transpose");
+
+ auto smeTileType = getSMETileTypeForElement(resultType.getElementType());
+ VectorType smeSliceType = VectorType::Builder(smeTileType).dropDim(0);
+
+ if (sourceType.getDimSize(0) <= 1 ||
+ sourceType.getDimSize(1) % smeSliceType.getDimSize(0) != 0)
+ return rewriter.notifyMatchFailure(writeOp, "unsupported source shape");
+
+ auto loc = writeOp.getLoc();
+ auto createVscaleMultiple = makeVscaleConstantBuilder(rewriter, loc);
+
+ auto transposeMap = AffineMapAttr::get(
+ AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext()));
+
+ // Note: We need to use `get_tile` as there's no vector-level `undef`.
+ Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType);
+ Value destTensorOrMemref = writeOp.getSource();
+ auto numSlicesPerTile =
+ std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0));
+ auto numSlices =
+ rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
+ for (auto [index, smeTile] : llvm::enumerate(
+ decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
+ // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
+ // of slices from the source type into the SME tile. Without checking
+ // vscale (and emitting multiple implementations) we can't make use of the
+ // rows of the tile after 1*vscale rows.
+ Value tile = undefTile;
+ for (int d = 0, e = numSlicesPerTile; d < e; ++d) {
+ Value vector = rewriter.create<vector::ExtractOp>(
+ loc, transposeOp.getVector(),
+ rewriter.getIndexAttr(d + smeTile.row));
+ if (vector.getType() != smeSliceType) {
+ vector = rewriter.create<vector::ScalableExtractOp>(
+ loc, smeSliceType, vector, smeTile.col);
+ }
+ tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d);
+ }
+
+ // 2. Transpose the tile position.
+ auto transposedRow = createVscaleMultiple(smeTile.col);
+ auto transposedCol =
+ rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row);
+
+ // 3. Compute mask for tile store.
+ Value maskRows;
+ Value maskCols;
+ if (auto mask = writeOp.getMask()) {
+ auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
+ maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0),
+ transposedRow);
+ maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1),
+ transposedCol);
+ maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices);
+ } else {
+ maskRows = createVscaleMultiple(smeTileType.getDimSize(0));
+ maskCols = numSlices;
+ }
+ auto subMask = rewriter.create<vector::CreateMaskOp>(
+ loc, smeTileType.clone(rewriter.getI1Type()),
+ ValueRange{maskRows, maskCols});
+
+ // 4. Emit a transposed tile write.
+ auto writeIndices = writeOp.getIndices();
+ Value destRow =
+ rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]);
+ Value destCol =
+ rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]);
+ auto smeWrite = rewriter.create<vector::TransferWriteOp>(
+ loc, tile, destTensorOrMemref, ValueRange{destRow, destCol},
+ transposeMap, subMask, writeOp.getInBounds().value_or(ArrayAttr{}));
+
+ if (writeOp.hasPureTensorSemantics())
+ destTensorOrMemref = smeWrite.getResult();
+ }
+
+ if (writeOp.hasPureTensorSemantics())
+ rewriter.replaceOp(writeOp, destTensorOrMemref);
+ else
+ rewriter.eraseOp(writeOp);
+
+ return success();
+ }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
@@ -796,7 +943,8 @@ struct VectorLegalizationPass
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
- ConvertIllegalShapeCastOpsToTransposes>(context);
+ ConvertIllegalShapeCastOpsToTransposes,
+ LowerIllegalTransposeStoreViaZA>(context);
// Note: These two patterns are added with a high benefit to ensure:
// - Masked outer products are handled before unmasked ones
// - Multi-tile writes are lowered as a store loop (if possible)
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index 71d80bc16ea12..951b29b6e3805 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -544,3 +544,74 @@ func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
%0 = arith.constant dense<42> : vector<[8]x[8]xi32>
return %0 : vector<[8]x[8]xi32>
}
+
+// -----
+
+// CHECK: #[[$TRANSPOSE_MAP_0:.*]] = affine_map<(d0, d1) -> (d1, d0)>
+
+// CHECK-LABEL: @transpose_store_scalable_via_za(
+// CHECK-SAME: %[[VEC:.*]]: vector<2x[4]xf32>
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[I:.*]]: index,
+// CHECK-SAME: %[[J:.*]]: index)
+func.func @transpose_store_scalable_via_za(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+ // CHECK-NEXT: %[[INIT:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[V0:.*]] = vector.extract %[[VEC]][0] : vector<[4]xf32> from vector<2x[4]xf32>
+ // CHECK-NEXT: %[[R0:.*]] = vector.insert %[[V0]], %[[INIT]] [0] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[V1:.*]] = vector.extract %[[VEC]][1] : vector<[4]xf32> from vector<2x[4]xf32>
+ // CHECK-NEXT: %[[RES:.*]] = vector.insert %[[V1]], %[[R0]] [1] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK-NEXT: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+ // CHECK-NEXT: %[[MASK:.*]] = vector.create_mask %[[C4_VSCALE]], %[[C2]] : vector<[4]x[4]xi1>
+ // CHECK-NEXT: vector.transfer_write %[[RES]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {in_bounds = [true, true], permutation_map = #[[$TRANSPOSE_MAP_0]]} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+ vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK: @transpose_store_scalable_via_za_masked(
+// CHECK-SAME: %[[A:[a-z0-9]+]]: index,
+// CHECK-SAME: %[[B:[a-z0-9]+]]: index)
+func.func @transpose_store_scalable_via_za_masked(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %a: index, %b: index) {
+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
+ // CHECK: %[[MIN:.*]] = index.mins %[[B]], %[[C2]]
+ // CHECK: %[[MASK:.*]] = vector.create_mask %[[A]], %[[MIN]] : vector<[4]x[4]xi1>
+ // CHECK: vector.transfer_write {{.*}} %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ %c0 = arith.constant 0 : index
+ %mask = vector.create_mask %a, %b : vector<[4]x2xi1>
+ %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
+ vector.transfer_write %tr, %dest[%c0, %c0], %mask {in_bounds = [true, true]} : vector<[4]x2xf32>, memref<?x?xf32>
+ return
+}
+
+// -----
+
+// CHECK: @transpose_store_scalable_via_za_multi_tile(
+// CHECK-SAME: %[[VEC:.*]]: vector<8x[4]xf32>
+// CHECK-SAME: %[[DEST:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[I:.*]]: index,
+// CHECK-SAME: %[[J:.*]]: index)
+func.func @transpose_store_scalable_via_za_multi_tile(%vec: vector<8x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
+ // CHECK: %[[C4:.*]] = arith.constant 4 : index
+ // CHECK: %[[VSCALE:.*]] = vector.vscale
+
+ // <skip 3x other extract+insert chain>
+ // CHECK: %[[V3:.*]] = vector.extract %[[VEC]][3] : vector<[4]xf32> from vector<8x[4]xf32>
+ // CHECK: %[[TILE_0:.*]] = vector.insert %[[V3]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+ // CHECK: %[[MASK:.*]] = vector.create_mask %c4_vscale, %c4 : vector<[4]x[4]xi1>
+ // CHECK: vector.transfer_write %[[TILE_0]], %[[DEST]][%[[I]], %[[J]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+
+ // <skip 3x other extract+insert chain>
+ // CHECK: %[[V7:.*]] = vector.extract %arg0[7] : vector<[4]xf32> from vector<8x[4]xf32>
+ // CHECK: %[[TILE_1:.*]] = vector.insert %[[V7]], %{{.*}} [3] : vector<[4]xf32> into vector<[4]x[4]xf32>
+ // CHECK: %[[J_OFFSET:.*]] = arith.addi %[[J]], %[[C4]] : index
+ // CHECK: vector.transfer_write %[[TILE_1]], %[[DEST]][%[[I]], %[[J_OFFSET]]], %[[MASK]] {{.*}} : vector<[4]x[4]xf32>, memref<?x?xf32>
+ %tr = vector.transpose %vec, [1, 0] : vector<8x[4]xf32> to vector<[4]x8xf32>
+ vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]} : vector<[4]x8xf32>, memref<?x?xf32>
+ return
+}
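As a usage note: the tests above run under the ArmSME vector legalization pass. A minimal standalone repro, written as a sketch in the style of the test file (the `-arm-sme-vector-legalization` flag is inferred from the pass definition and the test's RUN line, which are not shown in this diff):

```mlir
// Hypothetical repro: save as repro.mlir and run, e.g.:
//   mlir-opt repro.mlir -arm-sme-vector-legalization -cse -canonicalize
// The illegal vector<[4]x2xf32> transpose + store below should be rewritten
// to an SME tile insert chain plus a masked, transposed transfer_write.
func.func @repro(%vec: vector<2x[4]xf32>, %dest: memref<?x?xf32>, %i: index, %j: index) {
  %tr = vector.transpose %vec, [1, 0] : vector<2x[4]xf32> to vector<[4]x2xf32>
  vector.transfer_write %tr, %dest[%i, %j] {in_bounds = [true, true]}
    : vector<[4]x2xf32>, memref<?x?xf32>
  return
}
```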
I have one question, but you can answer after landing. Thanks, LGTM!
/// : vector<[4]x[4]xf32>, memref<?x?xf32>
/// ```
///
/// Values larger than a single tile are supported via decomposition.
🥳
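The decomposition this refers to is exercised by the `@transpose_store_scalable_via_za_multi_tile` test in the diff above: a `vector<8x[4]xf32>` source is split into two `vector<[4]x[4]xf32>` tiles, and after transposition the second tile's store lands at a fixed column offset. A condensed sketch of that second store (SSA names are illustrative, taken loosely from the test):

```mlir
// The second tile holds source rows 4..7. After transposition its store
// starts at column %j + 4 -- a fixed offset, because the transposed column
// comes from the fixed tile row offset, not a vscale-scaled quantity.
%j_offset = arith.addi %j, %c4 : index
vector.transfer_write %tile_1, %dest[%i, %j_offset], %mask
  {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
  : vector<[4]x[4]xf32>, memref<?x?xf32>
```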
rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile);
for (auto [index, smeTile] : llvm::enumerate(
    decomposeToSMETiles(rewriter, sourceType, smeTileType))) {
  // 1. _Deliberately_ drop a scalable dimension and insert a fixed number
Where are these dims dropped?
We can't dynamically index the array-of-vectors input (or dynamically select an SME tile). These restrictions mean this lowering just targets the lowest common denominator (that is, vscale = 1).
OK, IIUC the scalability is dropped on L853. And, basically, this pattern will store at most 4 rows per tile?
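For reference, the drop happens in the slice-insertion loop quoted above: only `numSlicesPerTile = std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0))` slices are ever inserted, so for f32 (a `vector<[4]x[4]xf32>` tile) that is indeed at most 4 rows per tile. A minimal sketch of the resulting mask arithmetic for the `2x[4]` example, with illustrative SSA names:

```mlir
// For a vector<2x[4]xf32> source and a vector<[4]x[4]xf32> tile: after
// transposition the store covers all 4*vscale rows, but the mask caps the
// columns at the 2 populated slices, so tile rows past 1*vscale are never
// read -- correct for any vscale, but only exploiting vscale = 1.
%c2 = arith.constant 2 : index           // numSlicesPerTile = min(2, 4)
%c4 = arith.constant 4 : index
%vscale = vector.vscale
%c4_vscale = arith.muli %vscale, %c4 : index
%mask = vector.create_mask %c4_vscale, %c2 : vector<[4]x[4]xi1>
```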