Skip to content

Commit cbd72cb

Browse files
authored
[mlir][vector] Split TransposeOpLowering into 2 patterns (#91935)
Splits `TransposeOpLowering` into two patterns: 1. `Transpose2DWithUnitDimToShapeCast` - rewrites 2D `vector.transpose` as `vector.shape_cast` (there has to be at least one unit dim), 2. `TransposeOpLowering` - the original pattern without the part extracted into `Transpose2DWithUnitDimToShapeCast`. The rationale behind the split: * the output generated by `Transpose2DWithUnitDimToShapeCast` doesn't really match the intended output from `TransposeOpLowering` as documented in the source file - it doesn't make much sense to keep it embedded inside `TransposeOpLowering`, * `Transpose2DWithUnitDimToShapeCast` _does_ work for scalable vectors, `TransposeOpLowering` _does_ not.
1 parent 363258a commit cbd72cb

File tree

1 file changed

+64
-22
lines changed

1 file changed

+64
-22
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,10 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
326326
VectorType inputType = op.getSourceVectorType();
327327
VectorType resType = op.getResultVectorType();
328328

329+
if (inputType.isScalable())
330+
return rewriter.notifyMatchFailure(
331+
op, "This lowering does not support scalable vectors");
332+
329333
// Set up convenience transposition table.
330334
ArrayRef<int64_t> transp = op.getPermutation();
331335

@@ -334,28 +338,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
334338
return rewriter.notifyMatchFailure(
335339
op, "Options specifies lowering to shuffle");
336340

337-
// Replace:
338-
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
339-
// vector<1xnxelty>
340-
// with:
341-
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnxelty>
342-
//
343-
// Source with leading unit dim (inverse) is also replaced. Unit dim must
344-
// be fixed. Non-unit can be scalable.
345-
if (resType.getRank() == 2 &&
346-
((resType.getShape().front() == 1 &&
347-
!resType.getScalableDims().front()) ||
348-
(resType.getShape().back() == 1 &&
349-
!resType.getScalableDims().back())) &&
350-
transp == ArrayRef<int64_t>({1, 0})) {
351-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
352-
return success();
353-
}
354-
355-
// TODO: Add support for scalable vectors
356-
if (inputType.isScalable())
357-
return failure();
358-
359341
// Handle a true 2-D matrix transpose differently when requested.
360342
if (vectorTransformOptions.vectorTransposeLowering ==
361343
vector::VectorTransposeLowering::Flat &&
@@ -411,6 +393,64 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
411393
vector::VectorTransformsOptions vectorTransformOptions;
412394
};
413395

396+
/// Rewrites vector.transpose as vector.shape_cast. This pattern is only applied
397+
/// to 2D vectors with at least one unit dim. For example:
398+
///
399+
/// Replace:
400+
/// vector.transpose %0, [1, 0] : vector<4x1xi32>> to
401+
/// vector<1x4xi32>
402+
/// with:
403+
/// vector.shape_cast %0 : vector<4x1xi32> to vector<1x4xi32>
404+
///
405+
/// Source with leading unit dim (inverse) is also replaced. Unit dim must
406+
/// be fixed. Non-unit dim can be scalable.
407+
///
408+
/// TODO: This pattern was introduced specifically to help lower scalable
409+
/// vectors. In hindsight, a more specialised canonicalization (for shape_cast's
410+
/// to cancel out) would be preferable:
411+
///
412+
/// BEFORE:
413+
/// %0 = some_op
414+
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<[4]x1xf32>
415+
/// %2 = vector.transpose %1 [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
416+
/// AFTER:
417+
/// %0 = some_op
418+
/// %1 = vector.shape_cast %0 : vector<[4]xf32> to vector<1x[4]xf32>
419+
///
420+
/// Given the context above, we may want to consider (re-)moving this pattern
421+
/// at some later time. I am leaving it for now in case there are other users
422+
/// that I am not aware of.
423+
class Transpose2DWithUnitDimToShapeCast
424+
: public OpRewritePattern<vector::TransposeOp> {
425+
public:
426+
using OpRewritePattern::OpRewritePattern;
427+
428+
Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
429+
PatternBenefit benefit = 1)
430+
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
431+
432+
LogicalResult matchAndRewrite(vector::TransposeOp op,
433+
PatternRewriter &rewriter) const override {
434+
Value input = op.getVector();
435+
VectorType resType = op.getResultVectorType();
436+
437+
// Set up convenience transposition table.
438+
ArrayRef<int64_t> transp = op.getPermutation();
439+
440+
if (resType.getRank() == 2 &&
441+
((resType.getShape().front() == 1 &&
442+
!resType.getScalableDims().front()) ||
443+
(resType.getShape().back() == 1 &&
444+
!resType.getScalableDims().back())) &&
445+
transp == ArrayRef<int64_t>({1, 0})) {
446+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
447+
return success();
448+
}
449+
450+
return failure();
451+
}
452+
};
453+
414454
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
415455
/// If the strategy is Shuffle1D, it will be lowered to:
416456
/// vector.shape_cast 2D -> 1D
@@ -483,6 +523,8 @@ class TransposeOp2DToShuffleLowering
483523
void mlir::vector::populateVectorTransposeLoweringPatterns(
484524
RewritePatternSet &patterns, VectorTransformsOptions options,
485525
PatternBenefit benefit) {
526+
patterns.add<Transpose2DWithUnitDimToShapeCast>(patterns.getContext(),
527+
benefit);
486528
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
487529
options, patterns.getContext(), benefit);
488530
}

0 commit comments

Comments
 (0)