Skip to content

Commit bd64a9b

Browse files
committed
[mlir][Vector] Add fold transpose(shape_cast) -> shape_cast
This folds transpose(shape_cast) into a new shape_cast when the transpose just permutes a unit dim from the result of the shape_cast.

Example:
```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
```
Folds to:
```
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<1x[4]xf32>
```
This is an (alternate) fix for lowering matmuls to ArmSME.

---

Corrected version of llvm#73951.
1 parent a65363d commit bd64a9b

File tree

2 files changed

+122
-1
lines changed

2 files changed

+122
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5548,12 +5548,69 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
55485548
}
55495549
};
55505550

5551+
/// Folds transpose(shape_cast) into a new shape_cast, when the transpose just
5552+
/// permutes a unit dim from the result of the shape_cast.
5553+
class FoldTransposeShapeCast : public OpRewritePattern<TransposeOp> {
5554+
using OpRewritePattern::OpRewritePattern;
5555+
5556+
LogicalResult matchAndRewrite(TransposeOp transpOp,
5557+
PatternRewriter &rewriter) const override {
5558+
Value transposeSrc = transpOp.getVector();
5559+
auto shapeCastOp = transposeSrc.getDefiningOp<vector::ShapeCastOp>();
5560+
if (!shapeCastOp)
5561+
return rewriter.notifyMatchFailure(
5562+
transpOp, "TransposeOp source is not ShapeCastOp");
5563+
5564+
auto shapeCastSourceType = shapeCastOp.getSourceVectorType();
5565+
auto sourceType = transpOp.getSourceVectorType();
5566+
auto resultType = transpOp.getResultVectorType();
5567+
auto permutation = transpOp.getPermutation();
5568+
5569+
auto getSourceDim = [&](int64_t index) {
5570+
return std::make_pair(sourceType.getDimSize(index),
5571+
sourceType.getScalableDims()[index]);
5572+
};
5573+
5574+
auto unitDim = std::make_pair(int64_t(1), false);
5575+
for (auto [i, resultIndex] : llvm::enumerate(permutation)) {
5576+
int64_t sourceIndex = int64_t(i);
5577+
if (sourceIndex == resultIndex)
5578+
continue;
5579+
// Is the transpose crosses any non-unit dims this is also a non-unit
5580+
// transpose, so we restrict the index distance to 1:
5581+
// e.g.:
5582+
// vector.transpose %0, [0, 3, 2, 1] : vector<2x1x2x5xi32> to
5583+
// vector<2x5x2x1xi32>
5584+
// This could be relaxed to checking if all dims between the `sourceIndex`
5585+
// and `resultIndex` are unit dims (in both the source and result vector
5586+
// type).
5587+
if (std::abs(sourceIndex - resultIndex) != 1 ||
5588+
(getSourceDim(sourceIndex) != unitDim &&
5589+
getSourceDim(resultIndex) != unitDim)) {
5590+
return rewriter.notifyMatchFailure(
5591+
transpOp, "TransposeOp has non-unit permutation");
5592+
}
5593+
}
5594+
5595+
if (!isValidShapeCast(shapeCastSourceType.getShape(),
5596+
resultType.getShape()))
5597+
return rewriter.notifyMatchFailure(
5598+
transpOp, "TransposeOp cannot fold into valid ShapeCastOp");
5599+
5600+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(transpOp, resultType,
5601+
shapeCastOp.getSource());
5602+
5603+
return success();
5604+
};
5605+
};
5606+
55515607
} // namespace
55525608

55535609
/// Adds the canonicalization patterns for vector.transpose to `results`.
/// In this hunk the `-` line is the old tail of the `results.add<...>` call;
/// the `+` lines are its replacement, which additionally lists
/// FoldTransposeShapeCast.
void vector::TransposeOp::getCanonicalizationPatterns(
55545610
RewritePatternSet &results, MLIRContext *context) {
55555611
results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
5556-
TransposeFolder, FoldTransposeSplat>(context);
5612+
TransposeFolder, FoldTransposeSplat, FoldTransposeShapeCast>(
5613+
context);
55575614
}
55585615

55595616
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,70 @@ func.func @create_mask_transpose_to_transposed_create_mask(
6767

6868
// -----
6969

70+
// Positive case: the transpose [1, 0] only moves the fixed-size unit dim of
// vector<[4]x1xf32> in front of the scalable dim, so the shape_cast +
// transpose pair is expected to become one shape_cast to vector<1x[4]xf32>.
// CHECK-LABEL: transposed_unit_dim_shape_cast_to_shape_cast
71+
// CHECK-SAME: %[[VEC:.*]]: vector<[4]xf32>
72+
func.func @transposed_unit_dim_shape_cast_to_shape_cast(%vec: vector<[4]xf32>) -> vector<1x[4]xf32> {
73+
// CHECK: vector.shape_cast %[[VEC]] : vector<[4]xf32> to vector<1x[4]xf32>
74+
// CHECK-NOT: vector.transpose
75+
%0 = vector.shape_cast %vec : vector<[4]xf32> to vector<[4]x1xf32>
76+
// 0 -> 1 is a unit dim:
77+
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
78+
return %1 : vector<1x[4]xf32>
79+
}
80+
81+
// -----
82+
83+
// Positive case with two independent adjacent swaps (dims 0 <-> 1 and
// dims 4 <-> 5), each involving a unit dim. The shape_cast + transpose pair
// is expected to become one shape_cast to vector<2x1x3x4x1x5xf32>.
// CHECK-LABEL: transposed_multiple_unit_dim_shape_cast_to_shape_cast
84+
// CHECK-SAME: %[[VEC:.*]]: vector<120xf32>
85+
func.func @transposed_multiple_unit_dim_shape_cast_to_shape_cast(%vec: vector<120xf32>) -> vector<2x1x3x4x1x5xf32> {
86+
// CHECK: vector.shape_cast %[[VEC]] : vector<120xf32> to vector<2x1x3x4x1x5xf32>
87+
// CHECK-NOT: vector.transpose
88+
%0 = vector.shape_cast %vec : vector<120xf32> to vector<1x2x3x4x5x1xf32>
89+
// 0 -> 1 and 4 -> 5 are unit dims:
90+
%1 = vector.transpose %0, [1, 0, 2, 3, 5, 4] : vector<1x2x3x4x5x1xf32> to vector<2x1x3x4x1x5xf32>
91+
return %1 : vector<2x1x3x4x1x5xf32>
92+
}
93+
94+
// -----
95+
96+
// Negative case: besides the unit-dim swap 4 <-> 5, the permutation also
// swaps dims 1 and 2 (sizes 2 and 3), neither of which is a unit dim. Both
// the shape_cast and the transpose are expected to remain.
// CHECK-LABEL: transposed_non_unit_dim_shape_cast_0
97+
// CHECK-SAME: %[[VEC:.*]]: vector<120xf32>
98+
func.func @transposed_non_unit_dim_shape_cast_0(%vec: vector<120xf32>) -> vector<1x3x2x4x1x5xf32> {
99+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<120xf32> to vector<1x2x3x4x5x1xf32>
100+
// CHECK-NEXT: vector.transpose %[[SHAPE_CAST]], [0, 2, 1, 3, 5, 4] : vector<1x2x3x4x5x1xf32> to vector<1x3x2x4x1x5xf32>
101+
%0 = vector.shape_cast %vec : vector<120xf32> to vector<1x2x3x4x5x1xf32>
102+
// 1 -> 2 is a non-unit dim:
103+
%1 = vector.transpose %0, [0, 2, 1, 3, 5, 4] : vector<1x2x3x4x5x1xf32> to vector<1x3x2x4x1x5xf32>
104+
return %1 : vector<1x3x2x4x1x5xf32>
105+
}
106+
// -----
107+
108+
// Negative case: the transpose [1, 0] on vector<256x256xf32> swaps two
// non-unit dims, and the shape_cast reduces rank (3-D source to 2-D). Both
// the shape_cast and the transpose are expected to remain.
// CHECK-LABEL: transposed_non_unit_dim_shape_cast_1
109+
// CHECK-SAME: %[[VEC:.*]]: vector<1x256x256xf32>
110+
func.func @transposed_non_unit_dim_shape_cast_1(%vec: vector<1x256x256xf32>) -> vector<256x256xf32> {
111+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<1x256x256xf32> to vector<256x256xf32>
112+
// CHECK-NEXT: vector.transpose %[[SHAPE_CAST]], [1, 0] : vector<256x256xf32> to vector<256x256xf32>
113+
%0 = vector.shape_cast %vec : vector<1x256x256xf32> to vector<256x256xf32>
114+
// 0 -> 1 is a non-unit dim:
115+
%1 = vector.transpose %0, [1, 0] : vector<256x256xf32> to vector<256x256xf32>
116+
return %1 : vector<256x256xf32>
117+
}
118+
119+
// -----
120+
121+
// Negative case: the permutation swaps dim 1 (size 1) with dim 3 (size 5).
// One of the two is a unit dim, but they are not adjacent and the move
// crosses dim 2 (size 2). Both the shape_cast and the transpose are expected
// to remain.
// CHECK-LABEL: transposed_non_unit_dim_shape_cast_2
122+
// CHECK-SAME: %[[VEC:.*]]: vector<20xf32>
123+
func.func @transposed_non_unit_dim_shape_cast_2(%vec: vector<20xf32>) -> vector<2x5x2x1xf32> {
124+
// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[VEC]] : vector<20xf32> to vector<2x1x2x5xf32>
125+
// CHECK-NEXT: vector.transpose %[[SHAPE_CAST]], [0, 3, 2, 1] : vector<2x1x2x5xf32> to vector<2x5x2x1xf32>
126+
%0 = vector.shape_cast %vec : vector<20xf32> to vector<2x1x2x5xf32>
127+
// 1 -> 3 transposes non-unit dims:
128+
%1 = vector.transpose %0, [0, 3, 2, 1] : vector<2x1x2x5xf32> to vector<2x5x2x1xf32>
129+
return %1 : vector<2x5x2x1xf32>
130+
}
131+
132+
// -----
133+
70134
// CHECK-LABEL: extract_from_create_mask
71135
// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index
72136
func.func @extract_from_create_mask(%dim0: index, %dim1: index) -> vector<[4]x[4]xi1> {

0 commit comments

Comments
 (0)