Skip to content

Commit bbd2b08

Browse files
authored
[mlir][vector] Make TransposeOpLowering configurable (#73915)
Following the discussion here: * #72105 this patch makes the `TransposeOpLowering` configurable so that one can select whether to favour `vector.shape_cast` over `vector.transpose`. As per the discussion in #72105, using `vector.shape_cast` is very beneficial and desirable when targeting `LLVM IR` (CPU lowering), but won't work when targeting `SPIR-V` today (GPU lowering). Hence the need for a mechanism to be able to disable/enable the pattern introduced in #72105. This patch proposes one such mechanism. While this should solve the problem that we are facing today, it's understood to be a temporary workaround. It should be removed once support for lowering `vector.shape_cast` to SPIR-V is added. Also, (once implemented) the following proposal might make this workaround redundant: * https://discourse.llvm.org/t/improving-handling-of-unit-dimensions-in-the-vector-dialect/
1 parent 74e59e7 commit bbd2b08

File tree

2 files changed

+28
-16
lines changed

2 files changed

+28
-16
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ struct VectorTransformsOptions {
5959
vectorTransferSplit = opt;
6060
return *this;
6161
}
62+
63+
/// Option to control if vector.transpose can lower to a vector.shape_cast.
64+
/// TODO: ATM it's not possible to lower `vector.shape_cast` to SPIR-V
65+
/// and hence the need for this opt-out. Once the missing support has been
66+
/// added, this option can be removed.
67+
bool useShapeCast = true;
68+
VectorTransformsOptions &setUseShapeCast(bool opt = true) {
69+
useShapeCast = opt;
70+
return *this;
71+
}
6272
};
6373

6474
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -334,22 +334,24 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
334334
return rewriter.notifyMatchFailure(
335335
op, "Options specifies lowering to shuffle");
336336

337-
// Replace:
338-
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
339-
// vector<1xnxelty>
340-
// with:
341-
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
342-
//
343-
// Source with leading unit dim (inverse) is also replaced. Unit dim must
344-
// be fixed. Non-unit can be scalable.
345-
if (resType.getRank() == 2 &&
346-
((resType.getShape().front() == 1 &&
347-
!resType.getScalableDims().front()) ||
348-
(resType.getShape().back() == 1 &&
349-
!resType.getScalableDims().back())) &&
350-
transp == ArrayRef<int64_t>({1, 0})) {
351-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
352-
return success();
337+
if (vectorTransformOptions.useShapeCast) {
338+
// Replace:
339+
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
340+
// vector<1xnxelty>
341+
// with:
342+
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
343+
//
344+
// Source with leading unit dim (inverse) is also replaced. Unit dim must
345+
// be fixed. Non-unit can be scalable.
346+
if (resType.getRank() == 2 &&
347+
((resType.getShape().front() == 1 &&
348+
!resType.getScalableDims().front()) ||
349+
(resType.getShape().back() == 1 &&
350+
!resType.getScalableDims().back())) &&
351+
transp == ArrayRef<int64_t>({1, 0})) {
352+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
353+
return success();
354+
}
353355
}
354356

355357
if (inputType.isScalable())

0 commit comments

Comments
 (0)