-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Rewrite illegal shape_casts
to vector.transpose
ops
#82985
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesThis adds a rewrite that converts illegal 2D unit-dim E.g. // Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32> Becomes: // Case 1:
%a = vector.transpose %0 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32> Various lowerings and drop unit-dims patterns add such shape_casts, however, if they do not cancel out (which they likely won't if we've reached the vector-legalization pass) they will prevent lowering the IR. Rewriting them as a transpose gives Full diff: https://github.com/llvm/llvm-project/pull/82985.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 11f8bc04b21844..55b20e5a477d4e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -46,6 +46,8 @@ static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
"op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
kMatchFailureNonPermutationMap("op affine map is not a permutation");
+static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
+ "expected transpose from illegal type to legal type");
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
@@ -416,6 +418,17 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
}
};
+/// A vector type where no fixed dimension comes after a scalable dimension.
+bool isLegalVectorType(VectorType vType) {
+ bool seenFixedDim = false;
+ for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+ seenFixedDim |= !scalableFlag;
+ if (seenFixedDim && scalableFlag)
+ return false;
+ }
+ return true;
+}
+
/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
@@ -448,16 +461,6 @@ struct LiftIllegalVectorTransposeToMemory
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
- static bool isIllegalVectorType(VectorType vType) {
- bool seenFixedDim = false;
- for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
- seenFixedDim |= !scalableFlag;
- if (seenFixedDim && scalableFlag)
- return true;
- }
- return false;
- }
-
static Value getExtensionSource(Operation *op) {
if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
return op->getOperand(0);
@@ -468,9 +471,9 @@ struct LiftIllegalVectorTransposeToMemory
PatternRewriter &rewriter) const override {
auto sourceType = transposeOp.getSourceVectorType();
auto resultType = transposeOp.getResultVectorType();
- if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
- return rewriter.notifyMatchFailure(
- transposeOp, "expected transpose from illegal type to legal type");
+ if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(transposeOp,
+ kMatchFailureNotIllegalToLegal);
// Look through extend for transfer_read.
Value maybeRead = transposeOp.getVector();
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
}
};
+/// A rewrite to turn unit dim transpose-like vector.shape_cast into a
+/// vector.transpose. The shape_cast has to be from an illegal vector type to a
+/// legal one (as defined by isLegalVectorType).
+///
+/// The reasoning for this is if we've got to this pass and we still have
+/// shape_casts of illegal types, then they likely will not cancel out. Turning
+/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
+/// eliminate them.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// ```
+struct ConvertIllegalShapeCastOpsToTransposes
+ : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = shapeCastOp.getSourceVectorType();
+ auto resultType = shapeCastOp.getResultVectorType();
+ if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(shapeCastOp,
+ kMatchFailureNotIllegalToLegal);
+
+ // Note: If we know the that is source is an illegal vector type (and 2D)
+ // then dim 0 is scalable and dim 1 is fixed.
+ if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "expected source to be a 2D scalable vector with a "
+ "trailing unit dim");
+
+ auto loc = shapeCastOp.getLoc();
+ auto transpose = rewriter.create<vector::TransposeOp>(
+ loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
+
+ if (resultType.getRank() == 1)
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
+ transpose);
+ else
+ rewriter.replaceOp(shapeCastOp, transpose);
+
+ return success();
+ }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
@@ -576,7 +632,8 @@ struct VectorLegalizationPass
});
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
- LiftIllegalVectorTransposeToMemory>(context);
+ LiftIllegalVectorTransposeToMemory,
+ ConvertIllegalShapeCastOpsToTransposes>(context);
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index bf0b58ff4cf073..f8be697548c197 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -388,3 +388,48 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+ // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+ %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
+ return %0 : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+ // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
+ %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
+func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+ // CHECK-NOT: vector.shape_cast
+ %pad = arith.constant 0.0 : f32
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+ %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
+ return %cast : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
+func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+ // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
+ %pad = arith.constant 0.0 : f32
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+ %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
+ return %cast : vector<[4]xf32>
+}
|
@llvm/pr-subscribers-mlir-sme Author: Benjamin Maxwell (MacDue) ChangesThis adds a rewrite that converts illegal 2D unit-dim E.g. // Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32> Becomes: // Case 1:
%a = vector.transpose %0 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32> Various lowerings and drop unit-dims patterns add such shape_casts, however, if they do not cancel out (which they likely won't if we've reached the vector-legalization pass) they will prevent lowering the IR. Rewriting them as a transpose gives Full diff: https://github.com/llvm/llvm-project/pull/82985.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 11f8bc04b21844..55b20e5a477d4e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -46,6 +46,8 @@ static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
"op mask is unsupported for legalization/decomposition");
static constexpr StringLiteral
kMatchFailureNonPermutationMap("op affine map is not a permutation");
+static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
+ "expected transpose from illegal type to legal type");
/// An SMESubTile represents a single SME-sized sub-tile from decomposing a
/// larger vector type. The (`row`, `col`) are the position of the tile in the
@@ -416,6 +418,17 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
}
};
+/// A vector type where no fixed dimension comes after a scalable dimension.
+bool isLegalVectorType(VectorType vType) {
+ bool seenFixedDim = false;
+ for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+ seenFixedDim |= !scalableFlag;
+ if (seenFixedDim && scalableFlag)
+ return false;
+ }
+ return true;
+}
+
/// Lifts an illegal vector.transpose and vector.transfer_read to a
/// memref.subview + memref.transpose, followed by a legal read.
///
@@ -448,16 +461,6 @@ struct LiftIllegalVectorTransposeToMemory
: public OpRewritePattern<vector::TransposeOp> {
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
- static bool isIllegalVectorType(VectorType vType) {
- bool seenFixedDim = false;
- for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
- seenFixedDim |= !scalableFlag;
- if (seenFixedDim && scalableFlag)
- return true;
- }
- return false;
- }
-
static Value getExtensionSource(Operation *op) {
if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
return op->getOperand(0);
@@ -468,9 +471,9 @@ struct LiftIllegalVectorTransposeToMemory
PatternRewriter &rewriter) const override {
auto sourceType = transposeOp.getSourceVectorType();
auto resultType = transposeOp.getResultVectorType();
- if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
- return rewriter.notifyMatchFailure(
- transposeOp, "expected transpose from illegal type to legal type");
+ if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(transposeOp,
+ kMatchFailureNotIllegalToLegal);
// Look through extend for transfer_read.
Value maybeRead = transposeOp.getVector();
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
}
};
+/// A rewrite to turn unit dim transpose-like vector.shape_cast into a
+/// vector.transpose. The shape_cast has to be from an illegal vector type to a
+/// legal one (as defined by isLegalVectorType).
+///
+/// The reasoning for this is if we've got to this pass and we still have
+/// shape_casts of illegal types, then they likely will not cancel out. Turning
+/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
+/// eliminate them.
+///
+/// Example:
+///
+/// BEFORE:
+/// ```mlir
+/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+/// ```
+struct ConvertIllegalShapeCastOpsToTransposes
+ : public OpRewritePattern<vector::ShapeCastOp> {
+ using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ auto sourceType = shapeCastOp.getSourceVectorType();
+ auto resultType = shapeCastOp.getResultVectorType();
+ if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+ return rewriter.notifyMatchFailure(shapeCastOp,
+ kMatchFailureNotIllegalToLegal);
+
+ // Note: If we know the that is source is an illegal vector type (and 2D)
+ // then dim 0 is scalable and dim 1 is fixed.
+ if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "expected source to be a 2D scalable vector with a "
+ "trailing unit dim");
+
+ auto loc = shapeCastOp.getLoc();
+ auto transpose = rewriter.create<vector::TransposeOp>(
+ loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
+
+ if (resultType.getRank() == 1)
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
+ transpose);
+ else
+ rewriter.replaceOp(shapeCastOp, transpose);
+
+ return success();
+ }
+};
+
struct VectorLegalizationPass
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
void runOnOperation() override {
@@ -576,7 +632,8 @@ struct VectorLegalizationPass
});
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
- LiftIllegalVectorTransposeToMemory>(context);
+ LiftIllegalVectorTransposeToMemory,
+ ConvertIllegalShapeCastOpsToTransposes>(context);
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index bf0b58ff4cf073..f8be697548c197 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -388,3 +388,48 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
%0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
return %0 : vector<1x[4]xf32>
}
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+ // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+ %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
+ return %0 : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
+// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
+ // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+ // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
+ %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
+ return %0 : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
+func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+ // CHECK-NOT: vector.shape_cast
+ %pad = arith.constant 0.0 : f32
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+ %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
+ return %cast : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
+func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
+ // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+ // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
+ %pad = arith.constant 0.0 : f32
+ %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+ %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
+ return %cast : vector<[4]xf32>
+}
|
9865dd5
to
b7edfde
Compare
...
something not clear to me here, |
In IREE the generic vector lowering happens quite a bit earlier than the ArmSME pipeline (and it's not part of |
ok, thanks for clarifying |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one minor comment otherwise LGTM cheers
This adds a rewrite that converts illegal 2D unit-dim `shape_casts` into `vector.transpose` ops. E.g. ```mlir // Case 1: %a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32> // Case 2: %b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32> ``` Becomes: ``` // Case 1: %a = vector.transpose %0 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32> // Case 2: %t = vector.transpose %1 : [1, 0] vector<[4]x1xf32> to vector<1x[4]xf32> %b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32> ``` Various lowerings and drop unit-dims patterns add such shape_casts, however, if they do not cancel out (which they likely won't if we've reached the vector-legalization pass) they will prevent lowering the IR. Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory` a chance to eliminate the illegal types.
b7edfde
to
517e5f0
Compare
This adds a rewrite that converts illegal 2D unit-dim
shape_casts
intovector.transpose
ops.E.g.
Becomes:
Various lowerings and drop unit-dims patterns add such shape_casts, however, if they do not cancel out (which they likely won't if we've reached the vector-legalization pass) they will prevent lowering the IR.
Rewriting them as a transpose gives
LiftIllegalVectorTransposeToMemory
a chance to eliminate the illegal types.