[mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes
#139706
Conversation
@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/139706.diff (4 files affected):
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 95965872f4098..51750f0bb9694 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
}
};
-/// A rewrite to turn unit dim transpose-like vector.shape_casts into
-/// vector.transposes. The shape_cast has to be from an illegal vector type to a
-/// legal one (as defined by isLegalVectorType).
-///
-/// The reasoning for this is if we've got to this pass and we still have
-/// shape_casts of illegal types, then they likely will not cancel out. Turning
-/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
-/// eliminate them.
-///
-/// Example:
-///
-/// BEFORE:
-/// ```mlir
-/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-struct ConvertIllegalShapeCastOpsToTransposes
- : public OpRewritePattern<vector::ShapeCastOp> {
- using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
- PatternRewriter &rewriter) const override {
- auto sourceType = shapeCastOp.getSourceVectorType();
- auto resultType = shapeCastOp.getResultVectorType();
- if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
- return rewriter.notifyMatchFailure(shapeCastOp,
- kMatchFailureNotIllegalToLegal);
-
- // Note: If we know that `sourceType` is an illegal vector type (and 2D)
- // then dim 0 is scalable and dim 1 is fixed.
- if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
- return rewriter.notifyMatchFailure(
- shapeCastOp, "expected source to be a 2D scalable vector with a "
- "trailing unit dim");
-
- auto loc = shapeCastOp.getLoc();
- auto transpose = rewriter.create<vector::TransposeOp>(
- loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
-
- if (resultType.getRank() == 1)
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
- transpose);
- else
- rewriter.replaceOp(shapeCastOp, transpose);
-
- return success();
- }
-};
-
/// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
/// the ZA state. This workaround rewrite to support these transposes when ZA is
/// available.
@@ -943,7 +890,6 @@ struct VectorLegalizationPass
RewritePatternSet rewritePatterns(context);
rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
LiftIllegalVectorTransposeToMemory,
- ConvertIllegalShapeCastOpsToTransposes,
LowerIllegalTransposeStoreViaZA>(context);
if (failed(
applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..83a287d29d773 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5617,18 +5617,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
// shape_cast(transpose(x)) -> shape_cast(x)
if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
- // This folder does
- // shape_cast(transpose) -> shape_cast
- // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
- // shape_cast -> shape_cast(transpose)
- // i.e. the complete opposite. When paired, these 2 patterns can cause
- // infinite cycles in pattern rewriting.
- // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
- // vectors, so by disabling this folder for scalable vectors the
- // cycle is avoided.
- // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
- // still needed. If it's not, then we can fold here.
- if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+ if (isOrderPreserving(transpose)) {
setOperand(transpose.getVector());
return getResult();
}
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index d56df9814f173..6e6615c243d2a 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
// -----
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
- // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
- // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
- // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
- %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
- return %0 : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
-func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
- // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
- // CHECK-NOT: vector.shape_cast
- %pad = arith.constant 0.0 : f32
- %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
- %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
- return %cast : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
-func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
- // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
- // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
- %pad = arith.constant 0.0 : f32
- %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
- %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
- return %cast : vector<[4]xf32>
-}
-
-// -----
-
// CHECK-LABEL: @multi_tile_splat
func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
{
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e47578bc80719..625b4a9c53e42 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
// -----
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// (same as the example above, but one of the dims is scalable)
+// CHECK-LABEL: @transpose_shape_cast_scalable
+// CHECK-SAME: %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<[4]x4xi8>
+func.func @transpose_shape_cast_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+ %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+ : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
+ %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>
+ return %1 : vector<[4]x4xi8>
+}
+
+// -----
+
// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
// 1 -> 2
// 2 -> 1
@@ -225,11 +244,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi
// -----
-// Scalable dimensions should be treated as non-unit dimensions.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK-LABEL: @transpose_shape_cast_scalable
+// CHECK-SAME: %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+// CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+// CHECK-SAME: vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+// CHECK: return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
+func.func @shape_cast_transpose_scalable(%arg : vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+ %0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+ %1 = vector.transpose %0, [0, 2, 1]
+ : vector<[6]x1x1xi8> to vector<[6]x1x1xi8>
+ return %1 : vector<[6]x1x1xi8>
+}
+
+// -----
+
+// Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions
+// (hence no folding).
+// CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit
// CHECK: vector.shape_cast
// CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+func.func @negative_shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
%0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
%1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
return %1 : vector<4x[1]xi8>
Force-pushed `ConvertIllegalShapeCastOpsToTransposes` from 06ca6e9 to 0bf798d.
// CHECK-NOT: vector.shape_cast
%pad = arith.constant 0.0 : f32
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
return %cast : vector<1x[4]xf32>
I'm not sure what you've tested, but to know whether this rewrite is still needed, this test case should still be possible to lower to LLVM.
Thanks Ben!

> I'm not sure what you've tested

I used our e2e tests - from what I can tell, we don't generate such code anymore.

> this test case should still be possible to lower to LLVM

Indeed. @momchil-velikov, since you are working on a generic pattern for "xfer_read with non-trailing scalable dims", could you make sure that this example lowers with your patch?

%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>

I will wait for Momchil to upload his patch before progressing this one.
I can see two PRs from @momchil-velikov in llvm-project that might be related, but I'm just checking in that this PR is still on the radar.
Yes, it's on the radar and hasn't slipped through the cracks.
> %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>

I just tried it and it does not lower. And, AFAICT, it shouldn't, as the last dimension of the memref (`?`) and the vector (`1`) do not match and the read cannot be inferred to be contiguous, e.g. if we're reading from a memref with dynamic dimensions 4 and 2:

[*][]
[*][]
[*][]
[*][]
This one

%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}
    : memref<?x2xf32>, vector<[4]x2xf32>

is lowered, though (with an implication that `%b` is zero, along the way).
> I just tried it and it does not lower. And, AFAICT, it shouldn't, as the last dimension of the memref

This lowering does not depend on the memref/transpose being contiguous. It lowers the `transfer_read` to a `memref.transpose` + `transfer_read`, which lowers to a loop in the case of a non-contiguous read (such as in this test case): https://godbolt.org/z/b96E7aYq4
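For illustration, here is a rough sketch of that rewrite; the types and the strided layout are illustrative, not the exact IR the pattern emits:

```mlir
// An illegal "column" read: scalable leading dim, trailing unit dim.
%read = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}
    : memref<?x?xf32>, vector<[4]x1xf32>

// Conceptually, the rewrite transposes the *memref* (a metadata-only view
// change), so a legal row-shaped read can be used instead; a non-contiguous
// read through the transposed view is later lowered to a loop.
%t = memref.transpose %memref (d0, d1) -> (d1, d0)
    : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
%row = vector.transfer_read %t[%b, %a], %pad {in_bounds = [true, false]}
    : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>, vector<1x[4]xf32>
```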
Sorry, I've not had a chance to return to this yet. It's one of 3 things that are "next" on my list :)
As a follow-up to PR #135841 (see discussion for context), this patch removes `ConvertIllegalShapeCastOpsToTransposes` from the SME legalization pass and unblocks `ShapeCastOp::fold` for scalable vectors.

AFAIK, `ConvertIllegalShapeCastOpsToTransposes` was originally needed because we were generating `vector.shape_cast` ops that couldn't be lowered otherwise. To confirm it's no longer required, I tested this patch locally using end-to-end tests.

Notably, this also removes a special case from `ShapeCastOp::fold`.
Add `LowerColumnTransferReadToLoops`. Note, this is to address Ben's comment here: https://github.com/llvm/llvm-project/pull/139706/files#r2088605443
Force-pushed from 0bf798d to d04d335.
Note: I've rebased this on top of main to keep it up to date - please let me know if that's problematic, and I'll avoid doing it in the future.

@MacDue, I've restored the ability to lower the following (thanks for flagging this!):

%illegalRead = vector.transfer_read %memref[%a, %b], %pad : memref<?x?xf32>, vector<[4]x1xf32>

This is now handled in `LowerColumnTransferReadToLoops`.

Personally, I'd suggest leaving these as TODOs for now, and focusing this PR on removing `ConvertIllegalShapeCastOpsToTransposes`. WDYT?
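A minimal sketch of what such a loop-based lowering of a column read could look like; this is an illustration only, not the exact output of `LowerColumnTransferReadToLoops`, and the out-of-bounds/padding handling is omitted for brevity:

```mlir
// Read one element of the column per iteration and insert it into the
// accumulator vector. %a/%b are the base indices of the original read.
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
%vs = vector.vscale
%rows = arith.muli %c4, %vs : index
%init = arith.constant dense<0.0> : vector<[4]x1xf32>
%res = scf.for %i = %c0 to %rows step %c1
    iter_args(%acc = %init) -> (vector<[4]x1xf32>) {
  %row = arith.addi %a, %i : index
  %elt = memref.load %memref[%row, %b] : memref<?x?xf32>
  %next = vector.insert %elt, %acc [%i, 0] : f32 into vector<[4]x1xf32>
  scf.yield %next : vector<[4]x1xf32>
}
```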
I think you could just update the
I'd prefer not to, mainly because there's a fundamental difference between `vector.transpose` and `vector.shape_cast`. The former involves actual data movement, while the latter is not supposed to - as noted in the docs. This is consistent with recent improvements from @newling - the general goal has been to treat `vector.shape_cast` as a no-op wherever possible.

Agreed, avoiding duplication is ideal. That said, in this specific case, the duplication introduced is relatively minor. Happy to add TODOs, though. WDYT?
The problem I had is described here.

Another 2 cents, not directly related to this PR:
> I'd prefer not to, mainly because there's a fundamental difference between vector.transpose and vector.shape_cast. The former involves actual data movement, while the latter is not supposed to - as noted in the docs.

Not blocking this PR (feel free to land it), but I don't really agree with this stance. `vector.shape_cast` has never really been a no-op (the docs have been outdated for a long time). It's no different from any other operation, e.g. `vector.transpose`, that may mean data movement, or may fold away (e.g. into a load/store). 🙂
Thanks Ben!

Agreed - it certainly isn't a no-op today. That said, I believe we should still aim to preserve the original design intent where possible, treating deviations as exceptions rather than the default.

@newling, are you OK with these changes? If so, could you please approve?
LGTM. I expect my 'canonicalize towards shape_cast' changes will be smoother now, thank you!
As a follow-up to PR #135841 (see discussion for background), this patch removes the `ConvertIllegalShapeCastOpsToTransposes` pattern from the SME legalization pass. This change unblocks folding for `ShapeCastOp` involving scalable vectors.

Originally, the `ConvertIllegalShapeCastOpsToTransposes` pattern was introduced to rewrite certain `vector.shape_cast` ops that could not be lowered otherwise. Based on local end-to-end testing, this workaround is no longer required, and the pattern can now be safely removed.

This patch also removes a special case from `ShapeCastOp::fold`, simplifying the fold logic.

As a side effect of removing `ConvertIllegalShapeCastOpsToTransposes`, we lose the mechanism that enabled lowering of certain ops like:

%res = vector.transfer_read %mem[%a, %b] (...) : memref<?x?xf32>, vector<[4]x1xf32>

Previously, such cases were handled by:
* Rewriting a nearby `vector.shape_cast` to a `vector.transpose` (via `ConvertIllegalShapeCastOpsToTransposes`)
* Then lowering the result with `LiftIllegalVectorTransposeToMemory`.

This patch introduces a new dedicated pattern, `LowerColumnTransferReadToLoops`, that directly handles illegal `vector.transfer_read` ops involving leading scalable dimensions.
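As a concrete effect of unblocking the fold, the scalable case from the updated canonicalization test (shown in the diff above) now folds to a single `shape_cast`:

```mlir
// Before: an order-preserving transpose feeding a shape_cast.
%0 = vector.transpose %arg, [1, 0, 3, 4, 2]
    : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
%1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>

// After canonicalization: the transpose folds away.
%1 = vector.shape_cast %arg : vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
```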