Skip to content

Commit e673522

Browse files
committed
reinstate folders
1 parent 24f7531 commit e673522

File tree

5 files changed

+99
-46
lines changed

5 files changed

+99
-46
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 55 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5869,6 +5869,30 @@ LogicalResult ShapeCastOp::verify() {
58695869
return success();
58705870
}
58715871

5872+
/// Return true if `transpose` does not permute a pair of non-unit dims.
5873+
/// By `order preserving` we mean that the flattened versions of the input and
5874+
/// output vectors are (numerically) identical. In other words `transpose` is
5875+
/// effectively a shape cast.
5876+
static bool isOrderPreserving(TransposeOp transpose) {
5877+
ArrayRef<int64_t> permutation = transpose.getPermutation();
5878+
VectorType sourceType = transpose.getSourceVectorType();
5879+
ArrayRef<int64_t> inShape = sourceType.getShape();
5880+
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
5881+
auto isNonScalableUnitDim = [&](int64_t dim) {
5882+
return inShape[dim] == 1 && !inDimIsScalable[dim];
5883+
};
5884+
int64_t current = 0;
5885+
for (auto p : permutation) {
5886+
if (!isNonScalableUnitDim(p)) {
5887+
if (p < current) {
5888+
return false;
5889+
}
5890+
current = p;
5891+
}
5892+
}
5893+
return true;
5894+
}
5895+
58725896
OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
58735897

58745898
VectorType resultType = getType();
@@ -5883,6 +5907,22 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
58835907
return getResult();
58845908
}
58855909

5910+
// shape_cast(transpose(x)) -> shape_cast(x)
5911+
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
5912+
if (isOrderPreserving(transpose)) {
5913+
setOperand(transpose.getVector());
5914+
return getResult();
5915+
}
5916+
return {};
5917+
}
5918+
5919+
// Y = shape_cast(broadcast(X))
5920+
// -> X, if X and Y have same type
5921+
if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
5922+
if (bcastOp.getSourceType() == resultType)
5923+
return bcastOp.getSource();
5924+
}
5925+
58865926
// shape_cast(constant) -> constant
58875927
if (auto splatAttr =
58885928
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
@@ -6219,6 +6259,21 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) {
62196259
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getVector()))
62206260
return ub::PoisonAttr::get(getContext());
62216261

6262+
// Eliminate identity transposes, and more generally any transposes that
6263+
// preserves the shape without permuting elements.
6264+
//
6265+
// Examples of what to fold:
6266+
// %0 = vector.transpose %arg, [0, 1] : vector<1x1xi8> to vector<1x1xi8>
6267+
// %0 = vector.transpose %arg, [0, 1] : vector<2x2xi8> to vector<2x2xi8>
6268+
// %0 = vector.transpose %arg, [1, 0] : vector<1x1xi8> to vector<1x1xi8>
6269+
//
6270+
// Example of what NOT to fold:
6271+
//
6272+
// %0 = vector.transpose %arg, [1, 0] : vector<2x2xi8> to vector<2x2xi8>
6273+
if (getSourceVectorType() == getResultVectorType() &&
6274+
isOrderPreserving(*this))
6275+
return getVector();
6276+
62226277
return {};
62236278
}
62246279

@@ -6433,30 +6488,6 @@ class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
64336488
}
64346489
};
64356490

6436-
/// Return true if `transpose` does not permute a pair of non-unit dims.
6437-
/// By `order preserving` we mean that the flattened versions of the input and
6438-
/// output vectors are (numerically) identical. In other words `transpose` is
6439-
/// effectively a shape cast.
6440-
static bool isOrderPreserving(TransposeOp transpose) {
6441-
ArrayRef<int64_t> permutation = transpose.getPermutation();
6442-
VectorType sourceType = transpose.getSourceVectorType();
6443-
ArrayRef<int64_t> inShape = sourceType.getShape();
6444-
ArrayRef<bool> inDimIsScalable = sourceType.getScalableDims();
6445-
auto isNonScalableUnitDim = [&](int64_t dim) {
6446-
return inShape[dim] == 1 && !inDimIsScalable[dim];
6447-
};
6448-
int64_t current = 0;
6449-
for (auto p : permutation) {
6450-
if (!isNonScalableUnitDim(p)) {
6451-
if (p < current) {
6452-
return false;
6453-
}
6454-
current = p;
6455-
}
6456-
}
6457-
return true;
6458-
}
6459-
64606491
/// BEFORE:
64616492
/// %0 = vector.transpose %arg0, [0, 2, 1] :
64626493
/// vector<2x1x2xf32> to vector<2x2x1xf32>

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,16 +451,25 @@ func.func @extract_strided_fold_insert(%a: vector<2x8xf32>, %b: vector<1x4xf32>,
451451
// -----
452452

453453
// CHECK-LABEL: transpose_3D_identity
454-
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
454+
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3x2xf32>)
455+
// CHECK-NEXT: return [[ARG]]
455456
func.func @transpose_3D_identity(%arg : vector<4x3x2xf32>) -> vector<4x3x2xf32> {
456-
// CHECK-NOT: transpose
457457
%0 = vector.transpose %arg, [0, 1, 2] : vector<4x3x2xf32> to vector<4x3x2xf32>
458-
// CHECK-NEXT: return [[ARG]]
459458
return %0 : vector<4x3x2xf32>
460459
}
461460

462461
// -----
463462

463+
// CHECK-LABEL: transpose_0D_identity
464+
// CHECK-SAME: ([[ARG:%.*]]: vector<i8>)
465+
// CHECK-NEXT: return [[ARG]]
466+
func.func @transpose_0D_identity(%arg : vector<i8>) -> vector<i8> {
467+
%0 = vector.transpose %arg, [] : vector<i8> to vector<i8>
468+
return %0 : vector<i8>
469+
}
470+
471+
// -----
472+
464473
// CHECK-LABEL: transpose_2D_sequence
465474
// CHECK-SAME: ([[ARG:%.*]]: vector<4x3xf32>)
466475
func.func @transpose_2D_sequence(%arg : vector<4x3xf32>) -> vector<4x3xf32> {

mlir/test/Dialect/Vector/canonicalize/vector-shape-cast.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
// +----------------------------------------
77

88
// CHECK-LABEL: @broadcast_to_shape_cast
9-
// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
10-
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
11-
// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
9+
// CHECK-SAME: %[[ARG0:.*]]: vector<4xi8>
10+
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
11+
// CHECK-NEXT: return %[[SCAST]] : vector<1x1x4xi8>
1212
func.func @broadcast_to_shape_cast(%arg0 : vector<4xi8>) -> vector<1x1x4xi8> {
1313
%0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
1414
return %0 : vector<1x1x4xi8>
@@ -49,9 +49,9 @@ func.func @negative_broadcast_scalar_to_shape_cast(%arg0 : i8) -> vector<1xi8> {
4949
// 2 -> 1
5050
// Because 0 < 1, this permutation is order preserving and effectively a shape_cast.
5151
// CHECK-LABEL: @transpose_to_shape_cast
52-
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
53-
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
54-
// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
52+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
53+
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
54+
// CHECK-NEXT: return %[[SCAST]] : vector<2x2x1xf32>
5555
func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
5656
%0 = vector.transpose %arg0, [0, 2, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
5757
return %0 : vector<2x2x1xf32>
@@ -64,10 +64,10 @@ func.func @transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf3
6464
// 2 -> 4
6565
// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
6666
// CHECK-LABEL: @shape_cast_of_transpose
67-
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
68-
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
69-
// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
70-
// CHECK: return %[[SHAPE_CAST]]
67+
// CHECK-SAME: %[[ARG:.*]]: vector<1x4x4x1x1xi8>)
68+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
69+
// CHECK-SAME: vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
70+
// CHECK: return %[[SHAPE_CAST]]
7171
func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x1x1x1x4xi8> {
7272
%0 = vector.transpose %arg, [1, 0, 3, 4, 2] : vector<1x4x4x1x1xi8> to vector<4x1x1x1x4xi8>
7373
return %0 : vector<4x1x1x1x4xi8>
@@ -101,8 +101,8 @@ func.func @negative_transpose_to_shape_cast(%arg : vector<1x4x4x1xi8>) -> vector
101101
// -----
102102

103103
// CHECK-LABEL: @shape_cast_of_transpose_scalable
104-
// CHECK-NEXT: vector.shape_cast
105-
// CHECK-NEXT: return
104+
// CHECK-NEXT: vector.shape_cast
105+
// CHECK-NEXT: return
106106
func.func @shape_cast_of_transpose_scalable(%arg : vector<[4]x1xi8>) -> vector<[4]xi8> {
107107
%0 = vector.transpose %arg, [1, 0] : vector<[4]x1xi8> to vector<1x[4]xi8>
108108
%1 = vector.shape_cast %0 : vector<1x[4]xi8> to vector<[4]xi8>
@@ -125,9 +125,9 @@ func.func @transpose_of_shape_cast_scalable(%arg : vector<[4]xi8>) -> vector<[4]
125125
// A test where a transpose cannot be transformed to a shape_cast because it is not order
126126
// preserving
127127
// CHECK-LABEL: @negative_transpose_to_shape_cast
128-
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
129-
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
130-
// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
128+
// CHECK-SAME: %[[ARG0:.*]]: vector<2x1x2xf32>
129+
// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[ARG0]], [2, 0, 1]
130+
// CHECK-NEXT: return %[[TRANSPOSE]] : vector<2x2x1xf32>
131131
func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector<2x2x1xf32> {
132132
%0 = vector.transpose %arg0, [2, 0, 1] : vector<2x1x2xf32> to vector<2x2x1xf32>
133133
return %0 : vector<2x2x1xf32>
@@ -140,9 +140,9 @@ func.func @negative_transpose_to_shape_cast(%arg0 : vector<2x1x2xf32>) -> vector
140140
// +----------------------------------------
141141

142142
// CHECK-LABEL: @extract_to_shape_cast
143-
// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
144-
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
145-
// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
143+
// CHECK-SAME: %[[ARG0:.*]]: vector<1x4xf32>
144+
// CHECK-NEXT: %[[SCAST:.*]] = vector.shape_cast %[[ARG0]]
145+
// CHECK-NEXT: return %[[SCAST]] : vector<4xf32>
146146
func.func @extract_to_shape_cast(%arg0 : vector<1x4xf32>) -> vector<4xf32> {
147147
%0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
148148
return %0 : vector<4xf32>

mlir/test/Dialect/Vector/drop-unit-dims-with-shape-cast.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,18 @@ func.func @transpose_with_scalable_unit_dims(%vec: vector<[1]x1x2x4x1xf32>) -> v
188188

189189
// -----
190190

191+
func.func @transpose_with_all_unit_dims(%vec: vector<1x1x1xf32>) -> vector<1x1x1xf32> {
192+
%res = vector.transpose %vec, [0, 2, 1] : vector<1x1x1xf32> to vector<1x1x1xf32>
193+
return %res : vector<1x1x1xf32>
194+
}
195+
// The `vec` is returned because there are other flattening patterns that fold
196+
// vector.shape_cast ops away.
197+
// CHECK-LABEL: func.func @transpose_with_all_unit_dims
198+
// CHECK-SAME: %[[VEC:.[a-zA-Z0-9]+]]
199+
// CHECK-NEXT: return %[[VEC]]
200+
201+
// -----
202+
191203
func.func @negative_transpose_with_no_unit_dims(%vec: vector<4x2x3xf32>) -> vector<4x3x2xf32> {
192204
%res = vector.transpose %vec, [0, 2, 1] : vector<4x2x3xf32> to vector<4x3x2xf32>
193205
return %res : vector<4x3x2xf32>

mlir/test/Dialect/Vector/single-fold.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ func.func @fold_insert_in_single_pass() -> vector<2xf16> {
3535
// CHECK: arith.constant dense<[0.000000e+00, 2.500000e+00]> : vector<2xf16>
3636
%0 = vector.insert %c2, %cst [%c1] : f16 into vector<2xf16>
3737
return %0 : vector<2xf16>
38-
}
38+
}
39+

0 commit comments

Comments
 (0)