Skip to content

Commit b7b6d54

Browse files
authored
[mlir][vector] Add vector.transpose with unit-dim to vector.shape_cast pattern (#72105)
This patch extends the vector.transpose lowering to replace `vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to vector<1xnx<eltty>>` with `vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>`. The inverse case, where the source has a leading unit dim, is also replaced. The unit dim must be fixed-size; the non-unit dim can be scalable. A check is also added to bail out for scalable vectors before unrolling.
1 parent 33b5158 commit b7b6d54

File tree

2 files changed

+92
-0
lines changed

2 files changed

+92
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,27 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
336336
return rewriter.notifyMatchFailure(
337337
op, "Options specifies lowering to shuffle");
338338

339+
// Replace:
340+
// vector.transpose %0, [1, 0] : vector<nx1x<eltty>> to
341+
// vector<1xnx<eltty>>
342+
// with:
343+
// vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<eltty>>
344+
//
345+
// Source with leading unit dim (inverse) is also replaced. Unit dim must
346+
// be fixed. The non-unit dim can be scalable.
347+
if (resType.getRank() == 2 &&
348+
((resType.getShape().front() == 1 &&
349+
!resType.getScalableDims().front()) ||
350+
(resType.getShape().back() == 1 &&
351+
!resType.getScalableDims().back())) &&
352+
transp == ArrayRef<int64_t>({1, 0})) {
353+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, input);
354+
return success();
355+
}
356+
357+
if (inputType.isScalable())
358+
return failure();
359+
339360
// Handle a true 2-D matrix transpose differently when requested.
340361
if (vectorTransformOptions.vectorTransposeLowering ==
341362
vector::VectorTransposeLowering::Flat &&

mlir/test/Dialect/Vector/vector-transpose-lowering.mlir

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,17 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8
7474
return %0 : vector<1x1x8x8xf32>
7575
}
7676

77+
/// Scalable dim should not be unrolled.
78+
79+
// CHECK-LABEL: func @transpose23_scalable
80+
// CHECK-NOT: vector.extract
81+
// CHECK-NOT: vector.insert
82+
// CHECK: vector.transpose
83+
func.func @transpose23_scalable(%arg0: vector<2x[3]xf32>) -> vector<[3]x2xf32> {
84+
%0 = vector.transpose %arg0, [1, 0] : vector<2x[3]xf32> to vector<[3]x2xf32>
85+
return %0 : vector<[3]x2xf32>
86+
}
87+
7788
module attributes {transform.with_named_sequence} {
7889
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
7990
transform.apply_patterns to %func_op {
@@ -778,3 +789,63 @@ module attributes {transform.with_named_sequence} {
778789
transform.yield
779790
}
780791
}
792+
793+
// -----
794+
795+
/// Transpose of rank-2 vector with leading or trailing unit dim to shape_cast.
796+
797+
// CHECK-LABEL: func @transpose10_4x1xf32
798+
func.func @transpose10_4x1xf32(%arg0: vector<4x1xf32>) -> vector<1x4xf32> {
799+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<4x1xf32> to vector<1x4xf32>
800+
%0 = vector.transpose %arg0, [1, 0] : vector<4x1xf32> to vector<1x4xf32>
801+
return %0 : vector<1x4xf32>
802+
}
803+
804+
// CHECK-LABEL: func @transpose10_nx4x1xf32
805+
func.func @transpose10_nx4x1xf32(%arg0: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
806+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<[4]x1xf32> to vector<1x[4]xf32>
807+
%0 = vector.transpose %arg0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
808+
return %0 : vector<1x[4]xf32>
809+
}
810+
811+
// CHECK-LABEL: func @transpose10_1x4xf32
812+
func.func @transpose10_1x4xf32(%arg0: vector<1x4xf32>) -> vector<4x1xf32> {
813+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
814+
%0 = vector.transpose %arg0, [1, 0] : vector<1x4xf32> to vector<4x1xf32>
815+
return %0 : vector<4x1xf32>
816+
}
817+
818+
// CHECK-LABEL: func @transpose10_1xnx4xf32
819+
func.func @transpose10_1xnx4xf32(%arg0: vector<1x[4]xf32>) -> vector<[4]x1xf32> {
820+
// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<1x[4]xf32> to vector<[4]x1xf32>
821+
%0 = vector.transpose %arg0, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
822+
return %0 : vector<[4]x1xf32>
823+
}
824+
825+
/// Scalable unit dim should not be lowered to shape_cast.
826+
827+
// CHECK-LABEL: func @transpose10_4xnx1xf32
828+
func.func @transpose10_4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
829+
// CHECK-NOT: vector.shape_cast
830+
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
831+
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
832+
return %0 : vector<[1]x4xf32>
833+
}
834+
835+
// CHECK-LABEL: func @transpose10_nx4xnx1xf32
836+
func.func @transpose10_nx4xnx1xf32(%arg0: vector<4x[1]xf32>) -> vector<[1]x4xf32> {
837+
// CHECK-NOT: vector.shape_cast
838+
// CHECK: vector.transpose %{{.*}} : vector<4x[1]xf32> to vector<[1]x4xf32>
839+
%0 = vector.transpose %arg0, [1, 0] : vector<4x[1]xf32> to vector<[1]x4xf32>
840+
841+
return %0 : vector<[1]x4xf32>
842+
}
843+
844+
module attributes {transform.with_named_sequence} {
845+
transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
846+
transform.apply_patterns to %func_op {
847+
transform.apply_patterns.vector.lower_transpose
848+
} : !transform.op<"func.func">
849+
transform.yield
850+
}
851+
}

0 commit comments

Comments
 (0)