Skip to content

Commit 952f78a

Browse files
committed
[mlir][vector] Update tests/patterns for vector.transpose
Pretty much all logic that we have today for lowering vector.transpose assumes fixed length vectors (it's done via vector.shuffle that don't support scalable vectors). This patch updates related tests and patterns to capture and document that limitation more explicitly. Note that `vector.transpose` is a valid operation in the context of scalable vectors, but we are yet to implement the missing lowerings. Most changes are implemented in the test file. Here's a summary: * @transpose_nx8x2xf32 is renamed as @transpose_scalable and move near other test using `lowering_strategy = "shuffle_1d"` (to avoid duplicating TD sequences) * tests specific to X86 (`avx2_lowering_strategy = true`) are moved to a dedicated file (to seperate generic tests from target-specific tests) * `@transpose10_nx4xnx1xf32` duplicated `@transpose10_4xnx1xf32` and was deleted (the latter is renamed as `@transpose10_4x1xf32_scalable` to match its fixed-width counterpart: `@transpose10_4x1xf32`) * The changes in LowerVectorTranspose.cpp are NFCs - they just make sure that "scalable" vectors are filtered out at the very beginning
1 parent 33e16ca commit 952f78a

File tree

2 files changed

+43
-525
lines changed

2 files changed

+43
-525
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
320320

321321
LogicalResult matchAndRewrite(vector::TransposeOp op,
322322
PatternRewriter &rewriter) const override {
323+
if (op.getSourceVectorType().isScalable())
324+
return rewriter.notifyMatchFailure(
325+
op, "scalable vectors are not supported by this pattern");
326+
323327
auto loc = op.getLoc();
324328

325329
Value input = op.getVector();
@@ -352,9 +356,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
352356
return success();
353357
}
354358

355-
if (inputType.isScalable())
356-
return failure();
357-
358359
// Handle a true 2-D matrix transpose differently when requested.
359360
if (vectorTransformOptions.vectorTransposeLowering ==
360361
vector::VectorTransposeLowering::Flat &&

0 commit comments

Comments
 (0)