Skip to content

Commit 24f7531

Browse files
committed
ArmSME fix
1 parent f2e5417 commit 24f7531

File tree

4 files changed

+14
-27
lines changed

4 files changed

+14
-27
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2351,14 +2351,10 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
23512351
return success();
23522352
}
23532353

2354-
/// For example,
2355-
/// ```
2354+
/// BEFORE:
23562355
/// %0 = vector.extract %arg0[0] : vector<4xf32> from vector<1x4xf32>
2357-
/// ```
2358-
/// becomes
2359-
/// ```
2356+
/// AFTER:
23602357
/// %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
2361-
/// ```
23622358
struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
23632359
using OpRewritePattern::OpRewritePattern;
23642360
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
@@ -2368,8 +2364,8 @@ struct ExtractToShapeCast final : public OpRewritePattern<vector::ExtractOp> {
23682364
if (!outType)
23692365
return failure();
23702366

2371-
// Negative values in `position` indicates poison, cannot convert to
2372-
// shape_cast
2367+
// Negative values in `position` indicates poison, which cannot be
2368+
// represented with a shape_cast
23732369
if (llvm::any_of(extractOp.getMixedPosition(),
23742370
[](OpFoldResult v) { return !isConstantIntValue(v, 0); }))
23752371
return failure();
@@ -2902,14 +2898,10 @@ struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
29022898
}
29032899
};
29042900

2905-
/// For example,
2906-
/// ```
2901+
/// BEFORE:
29072902
/// %0 = vector.broadcast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2908-
/// ```
2909-
/// becomes
2910-
/// ```
2903+
/// AFTER:
29112904
/// %0 = vector.shape_cast %arg0 : vector<4xi8> to vector<1x1x4xi8>
2912-
/// ```
29132905
struct BroadcastToShapeCast final
29142906
: public OpRewritePattern<vector::BroadcastOp> {
29152907
using OpRewritePattern::OpRewritePattern;
@@ -6465,16 +6457,12 @@ static bool isOrderPreserving(TransposeOp transpose) {
64656457
return true;
64666458
}
64676459

6468-
/// For example,
6469-
/// ```
6460+
/// BEFORE:
64706461
/// %0 = vector.transpose %arg0, [0, 2, 1] :
64716462
/// vector<2x1x2xf32> to vector<2x2x1xf32>
6472-
/// ```
6473-
/// becomes
6474-
/// ```
6463+
/// AFTER:
64756464
/// %0 = vector.shape_cast %arg0 :
64766465
/// vector<2x1x2xf32> to vector<2x2x1xf32>
6477-
/// ```
64786466
struct TransposeToShapeCast final
64796467
: public OpRewritePattern<vector::TransposeOp> {
64806468
using OpRewritePattern::OpRewritePattern;

mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313

14-
#include "mlir/Dialect/Arith/IR/Arith.h"
1514
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1615
#include "mlir/Dialect/UB/IR/UBOps.h"
1716
#include "mlir/Dialect/Utils/IndexingUtils.h"
@@ -382,7 +381,6 @@ class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
382381
vector::VectorTransposeLowering vectorTransposeLowering;
383382
};
384383

385-
386384
/// Rewrite a 2-D vector.transpose as a sequence of shuffle ops.
387385
/// If the strategy is Shuffle1D, it will be lowered to:
388386
/// vector.shape_cast 2D -> 1D
@@ -454,7 +452,6 @@ class TransposeOp2DToShuffleLowering
454452
void mlir::vector::populateVectorTransposeLoweringPatterns(
455453
RewritePatternSet &patterns,
456454
VectorTransposeLowering vectorTransposeLowering, PatternBenefit benefit) {
457-
BroadcastOp::getCanonicalizationPatterns(patterns, patterns.getContext());
458455
patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
459456
vectorTransposeLowering, patterns.getContext(), benefit);
460457
}

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -480,11 +480,11 @@ func.func @lift_illegal_transpose_to_memory_with_in_bounds_attr(%a: index, %b: i
480480

481481
// -----
482482

483-
// The pass should do nothing (and not crash).
484-
// CHECK-LABEL: @illegal_transpose_no_defining_source_op
485-
func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
483+
// CHECK-LABEL: @transpose_no_defining_source_op
484+
func.func @transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32>
486485
{
487-
// CHECK: vector.transpose
486+
// CHECK: vector.shape_cast
487+
// CHECK-SAME: vector<[4]x1xf32> to vector<1x[4]xf32>
488488
%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
489489
return %0 : vector<1x[4]xf32>
490490
}

mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,8 @@ func.func @negative_broadcast_transpose_021(%arg0 : vector<3x1x3xi8>) -> vector<
141141
return %1 : vector<3x3x3xi8>
142142
}
143143

144+
// -----
145+
144146
/// +--------------------------------------------------------------------------
145147
/// Tests of ShapeCastOp::fold: shape_cast(transpose) -> shape_cast
146148
/// +--------------------------------------------------------------------------

0 commit comments

Comments
 (0)