Skip to content

Commit d1fc59c

Browse files
authored
[mlir][ArmSME] Rewrite illegal shape_casts to vector.transpose ops (#82985)
This adds a rewrite that converts illegal 2D unit-dim `shape_casts` into `vector.transpose` ops. E.g. ```mlir // Case 1: %a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32> // Case 2: %b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32> ``` Becomes: ```mlir // Case 1: %a = vector.transpose %0 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32> // Case 2: %t = vector.transpose %1 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32> %b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32> ``` Various lowerings and drop unit-dims patterns add such shape_casts, however, if they do not cancel out (which they likely won't if we've reached the vector-legalization pass) they will prevent lowering the IR. Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory` a chance to eliminate the illegal types.
1 parent 5c752df commit d1fc59c

File tree

2 files changed

+116
-14
lines changed

2 files changed

+116
-14
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 71 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
4646
"op mask is unsupported for legalization/decomposition");
4747
static constexpr StringLiteral
4848
kMatchFailureNonPermutationMap("op affine map is not a permutation");
49+
static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
50+
"expected transpose from illegal type to legal type");
4951

5052
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
5153
/// larger vector type. The (`row`, `col`) are the position of the tile in the
@@ -416,6 +418,17 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
416418
}
417419
};
418420

421+
/// A vector type where no fixed dimension comes after a scalable dimension.
422+
bool isLegalVectorType(VectorType vType) {
423+
bool seenFixedDim = false;
424+
for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
425+
seenFixedDim |= !scalableFlag;
426+
if (seenFixedDim && scalableFlag)
427+
return false;
428+
}
429+
return true;
430+
}
431+
419432
/// Lifts an illegal vector.transpose and vector.transfer_read to a
420433
/// memref.subview + memref.transpose, followed by a legal read.
421434
///
@@ -448,16 +461,6 @@ struct LiftIllegalVectorTransposeToMemory
448461
: public OpRewritePattern<vector::TransposeOp> {
449462
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
450463

451-
static bool isIllegalVectorType(VectorType vType) {
452-
bool seenFixedDim = false;
453-
for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
454-
seenFixedDim |= !scalableFlag;
455-
if (seenFixedDim && scalableFlag)
456-
return true;
457-
}
458-
return false;
459-
}
460-
461464
static Value getExtensionSource(Operation *op) {
462465
if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
463466
return op->getOperand(0);
@@ -468,9 +471,9 @@ struct LiftIllegalVectorTransposeToMemory
468471
PatternRewriter &rewriter) const override {
469472
auto sourceType = transposeOp.getSourceVectorType();
470473
auto resultType = transposeOp.getResultVectorType();
471-
if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
472-
return rewriter.notifyMatchFailure(
473-
transposeOp, "expected transpose from illegal type to legal type");
474+
if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
475+
return rewriter.notifyMatchFailure(transposeOp,
476+
kMatchFailureNotIllegalToLegal);
474477

475478
// Look through extend for transfer_read.
476479
Value maybeRead = transposeOp.getVector();
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
556559
}
557560
};
558561

562+
/// A rewrite to turn unit dim transpose-like vector.shape_casts into
563+
/// vector.transposes. The shape_cast has to be from an illegal vector type to a
564+
/// legal one (as defined by isLegalVectorType).
565+
///
566+
/// The reasoning for this is if we've got to this pass and we still have
567+
/// shape_casts of illegal types, then they likely will not cancel out. Turning
568+
/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
569+
/// eliminate them.
570+
///
571+
/// Example:
572+
///
573+
/// BEFORE:
574+
/// ```mlir
575+
/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
576+
/// ```
577+
///
578+
/// AFTER:
579+
/// ```mlir
580+
/// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
581+
/// ```
582+
struct ConvertIllegalShapeCastOpsToTransposes
583+
: public OpRewritePattern<vector::ShapeCastOp> {
584+
using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
585+
586+
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
587+
PatternRewriter &rewriter) const override {
588+
auto sourceType = shapeCastOp.getSourceVectorType();
589+
auto resultType = shapeCastOp.getResultVectorType();
590+
if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
591+
return rewriter.notifyMatchFailure(shapeCastOp,
592+
kMatchFailureNotIllegalToLegal);
593+
594+
// Note: If we know that `sourceType` is an illegal vector type (and 2D)
595+
// then dim 0 is scalable and dim 1 is fixed.
596+
if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
597+
return rewriter.notifyMatchFailure(
598+
shapeCastOp, "expected source to be a 2D scalable vector with a "
599+
"trailing unit dim");
600+
601+
auto loc = shapeCastOp.getLoc();
602+
auto transpose = rewriter.create<vector::TransposeOp>(
603+
loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
604+
605+
if (resultType.getRank() == 1)
606+
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
607+
transpose);
608+
else
609+
rewriter.replaceOp(shapeCastOp, transpose);
610+
611+
return success();
612+
}
613+
};
614+
559615
struct VectorLegalizationPass
560616
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
561617
void runOnOperation() override {
@@ -576,7 +632,8 @@ struct VectorLegalizationPass
576632
});
577633

578634
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
579-
LiftIllegalVectorTransposeToMemory>(context);
635+
LiftIllegalVectorTransposeToMemory,
636+
ConvertIllegalShapeCastOpsToTransposes>(context);
580637
// Note: High benefit to ensure masked outer products are lowered first.
581638
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
582639
converter, context, 1024);

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,48 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
388388
%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
389389
return %0 : vector<1x[4]xf32>
390390
}
391+
392+
// -----
393+
394+
// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
395+
// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
396+
func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
397+
// CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
398+
%0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
399+
return %0 : vector<1x[4]xf32>
400+
}
401+
402+
// -----
403+
404+
// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
405+
// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
406+
func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
407+
// CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
408+
// CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
409+
%0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
410+
return %0 : vector<[4]xf32>
411+
}
412+
413+
// -----
414+
415+
// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
416+
func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
417+
// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
418+
// CHECK-NOT: vector.shape_cast
419+
%pad = arith.constant 0.0 : f32
420+
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
421+
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
422+
return %cast : vector<1x[4]xf32>
423+
}
424+
425+
// -----
426+
427+
// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
428+
func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
429+
// CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
430+
// CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
431+
%pad = arith.constant 0.0 : f32
432+
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
433+
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
434+
return %cast : vector<[4]xf32>
435+
}

0 commit comments

Comments
 (0)