Skip to content

[mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes #139706

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented May 13, 2025

As a follow-up to PR #135841 (see discussion for background), this patch
removes the ConvertIllegalShapeCastOpsToTransposes pattern from the SME
legalization pass. This change unblocks folding for ShapeCastOp involving
scalable vectors.

Originally, the ConvertIllegalShapeCastOpsToTransposes pattern was introduced
to rewrite certain vector.shape_cast ops that could not be lowered otherwise.
Based on local end-to-end testing, this workaround is no longer required, and
the pattern can now be safely removed.

This patch also removes a special case from ShapeCastOp::fold, simplifying the
fold logic.

As a side effect of removing ConvertIllegalShapeCastOpsToTransposes, we lose
the mechanism that enabled lowering of certain ops like:

%res = vector.transfer_read %mem[%a, %b] (...) :
       memref<?x?xf32>, vector<[4]x1xf32>

Previously, such cases were handled by:

  • Rewriting a nearby vector.shape_cast to a vector.transpose (via
    ConvertIllegalShapeCastOpsToTransposes)
  • Then lowering the result with LiftIllegalVectorTransposeToMemory.

This patch introduces a new dedicated pattern,
LowerColumnTransferReadToLoops, that directly handles illegal
vector.transfer_read ops involving leading scalable dimensions.

@banach-space banach-space changed the base branch from main to users/banach-space/vector/document_tests May 13, 2025 10:54
@llvmbot
Copy link
Member

llvmbot commented May 13, 2025

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector][nfc] Update comments in vector-transpose.mlir
  • [mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes

Full diff: https://github.com/llvm/llvm-project/pull/139706.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (-54)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-12)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (-45)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir (+37-3)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 95965872f4098..51750f0bb9694 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
-/// A rewrite to turn unit dim transpose-like vector.shape_casts into
-/// vector.transposes. The shape_cast has to be from an illegal vector type to a
-/// legal one (as defined by isLegalVectorType).
-///
-/// The reasoning for this is if we've got to this pass and we still have
-/// shape_casts of illegal types, then they likely will not cancel out. Turning
-/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
-/// eliminate them.
-///
-/// Example:
-///
-///  BEFORE:
-///  ```mlir
-///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  ```
-struct ConvertIllegalShapeCastOpsToTransposes
-    : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto sourceType = shapeCastOp.getSourceVectorType();
-    auto resultType = shapeCastOp.getResultVectorType();
-    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
-      return rewriter.notifyMatchFailure(shapeCastOp,
-                                         kMatchFailureNotIllegalToLegal);
-
-    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
-    // then dim 0 is scalable and dim 1 is fixed.
-    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
-      return rewriter.notifyMatchFailure(
-          shapeCastOp, "expected source to be a 2D scalable vector with a "
-                       "trailing unit dim");
-
-    auto loc = shapeCastOp.getLoc();
-    auto transpose = rewriter.create<vector::TransposeOp>(
-        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
-
-    if (resultType.getRank() == 1)
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
-                                                       transpose);
-    else
-      rewriter.replaceOp(shapeCastOp, transpose);
-
-    return success();
-  }
-};
-
 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
 /// the ZA state. This workaround rewrite to support these transposes when ZA is
 /// available.
@@ -943,7 +890,6 @@ struct VectorLegalizationPass
     RewritePatternSet rewritePatterns(context);
     rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                         LiftIllegalVectorTransposeToMemory,
-                        ConvertIllegalShapeCastOpsToTransposes,
                         LowerIllegalTransposeStoreViaZA>(context);
     if (failed(
             applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..83a287d29d773 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5617,18 +5617,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(transpose(x)) -> shape_cast(x)
   if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
-    // This folder does
-    //    shape_cast(transpose) -> shape_cast
-    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
-    //    shape_cast -> shape_cast(transpose)
-    // i.e. the complete opposite. When paired, these 2 patterns can cause
-    // infinite cycles in pattern rewriting.
-    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
-    // vectors, so by disabling this folder for scalable vectors the
-    // cycle is avoided.
-    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
-    // still needed. If it's not, then we can fold here.
-    if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+    if (isOrderPreserving(transpose)) {
       setOperand(transpose.getVector());
       return getResult();
     }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index d56df9814f173..6e6615c243d2a 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
 
 // -----
 
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
-// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
-// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
-  return %0 : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
-func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %cast : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
-func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
-  return %cast : vector<[4]xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @multi_tile_splat
 func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
 {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e47578bc80719..625b4a9c53e42 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
 
 // -----
 
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// (same as the example above, but one of the dims is scalable)
+// CHECK-LABEL: @transpose_shape_cast_scalable
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<[4]x4xi8>
+func.func @transpose_shape_cast_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+     : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>
+  return %1 : vector<[4]x4xi8>
+}
+
+// -----
+
 // In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
 // 1 -> 2
 // 2 -> 1
@@ -225,11 +244,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) ->  vector<6x1x1xi
 
 // -----
 
-// Scalable dimensions should be treated as non-unit dimensions.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK-LABEL: @transpose_shape_cast_scalable
+//  CHECK-SAME:   %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
+func.func @shape_cast_transpose_scalable(%arg : vector<[2]x3x1x1xi8>) ->  vector<[6]x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+     : vector<[6]x1x1xi8> to vector<[6]x1x1xi8>
+  return %1 : vector<[6]x1x1xi8>
+}
+
+// -----
+
+// Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions
+// (hence no folding).
+// CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit
 //       CHECK: vector.shape_cast
 //       CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+func.func @negative_shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
   %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
   %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
   return %1 : vector<4x[1]xi8>

@llvmbot
Copy link
Member

llvmbot commented May 13, 2025

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector][nfc] Update comments in vector-transpose.mlir
  • [mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes

Full diff: https://github.com/llvm/llvm-project/pull/139706.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (-54)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+1-12)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (-45)
  • (modified) mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir (+37-3)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 95965872f4098..51750f0bb9694 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
-/// A rewrite to turn unit dim transpose-like vector.shape_casts into
-/// vector.transposes. The shape_cast has to be from an illegal vector type to a
-/// legal one (as defined by isLegalVectorType).
-///
-/// The reasoning for this is if we've got to this pass and we still have
-/// shape_casts of illegal types, then they likely will not cancel out. Turning
-/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
-/// eliminate them.
-///
-/// Example:
-///
-///  BEFORE:
-///  ```mlir
-///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  ```
-///
-///  AFTER:
-///  ```mlir
-///  %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-///  ```
-struct ConvertIllegalShapeCastOpsToTransposes
-    : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto sourceType = shapeCastOp.getSourceVectorType();
-    auto resultType = shapeCastOp.getResultVectorType();
-    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
-      return rewriter.notifyMatchFailure(shapeCastOp,
-                                         kMatchFailureNotIllegalToLegal);
-
-    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
-    // then dim 0 is scalable and dim 1 is fixed.
-    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
-      return rewriter.notifyMatchFailure(
-          shapeCastOp, "expected source to be a 2D scalable vector with a "
-                       "trailing unit dim");
-
-    auto loc = shapeCastOp.getLoc();
-    auto transpose = rewriter.create<vector::TransposeOp>(
-        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
-
-    if (resultType.getRank() == 1)
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
-                                                       transpose);
-    else
-      rewriter.replaceOp(shapeCastOp, transpose);
-
-    return success();
-  }
-};
-
 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
 /// the ZA state. This workaround rewrite to support these transposes when ZA is
 /// available.
@@ -943,7 +890,6 @@ struct VectorLegalizationPass
     RewritePatternSet rewritePatterns(context);
     rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
                         LiftIllegalVectorTransposeToMemory,
-                        ConvertIllegalShapeCastOpsToTransposes,
                         LowerIllegalTransposeStoreViaZA>(context);
     if (failed(
             applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f6c3c6a61afb6..83a287d29d773 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5617,18 +5617,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(transpose(x)) -> shape_cast(x)
   if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
-    // This folder does
-    //    shape_cast(transpose) -> shape_cast
-    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
-    //    shape_cast -> shape_cast(transpose)
-    // i.e. the complete opposite. When paired, these 2 patterns can cause
-    // infinite cycles in pattern rewriting.
-    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
-    // vectors, so by disabling this folder for scalable vectors the
-    // cycle is avoided.
-    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
-    // still needed. If it's not, then we can fold here.
-    if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+    if (isOrderPreserving(transpose)) {
       setOperand(transpose.getVector());
       return getResult();
     }
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index d56df9814f173..6e6615c243d2a 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
 
 // -----
 
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
-// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
-// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
-  return %0 : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
-func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %cast : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
-func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
-  return %cast : vector<[4]xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @multi_tile_splat
 func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
 {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
index e47578bc80719..625b4a9c53e42 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-transpose.mlir
@@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
 
 // -----
 
+// In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
+// 1 -> 0
+// 2 -> 4
+// Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
+// (same as the example above, but one of the dims is scalable)
+// CHECK-LABEL: @transpose_shape_cast_scalable
+//  CHECK-SAME:   %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<[4]x4xi8>
+func.func @transpose_shape_cast_scalable(%arg : vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
+  %0 = vector.transpose %arg, [1, 0, 3, 4, 2]
+     : vector<1x[4]x4x1x1xi8> to vector<[4]x1x1x1x4xi8>
+  %1 = vector.shape_cast %0 : vector<[4]x1x1x1x4xi8> to vector<[4]x4xi8>
+  return %1 : vector<[4]x4xi8>
+}
+
+// -----
+
 // In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
 // 1 -> 2
 // 2 -> 1
@@ -225,11 +244,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) ->  vector<6x1x1xi
 
 // -----
 
-// Scalable dimensions should be treated as non-unit dimensions.
-// CHECK-LABEL: @transpose_of_shape_cast_scalable
+// CHECK-LABEL: @transpose_shape_cast_scalable
+//  CHECK-SAME:   %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
+//       CHECK:   %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
+//  CHECK-SAME:   vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+//       CHECK:   return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
+func.func @shape_cast_transpose_scalable(%arg : vector<[2]x3x1x1xi8>) ->  vector<[6]x1x1xi8> {
+  %0 = vector.shape_cast %arg : vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
+  %1 = vector.transpose %0, [0, 2, 1]
+     : vector<[6]x1x1xi8> to vector<[6]x1x1xi8>
+  return %1 : vector<[6]x1x1xi8>
+}
+
+// -----
+
+// Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions
+// (hence no folding).
+// CHECK-LABEL: @negative_shape_cast_transpose_scalable_unit
 //       CHECK: vector.shape_cast
 //       CHECK: vector.transpose
-func.func @transpose_of_shape_cast_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
+func.func @negative_shape_cast_transpose_scalable_unit(%arg : vector<[1]x4x1xi8>) -> vector<4x[1]xi8> {
   %0 = vector.shape_cast %arg : vector<[1]x4x1xi8> to vector<[1]x4xi8>
   %1 = vector.transpose %0, [1, 0] : vector<[1]x4xi8> to vector<4x[1]xi8>
   return %1 : vector<4x[1]xi8>

@banach-space banach-space changed the title users/banach space/sme/remove ConvertIllegalShapeCastOpsToTransposes [mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes May 13, 2025
@banach-space banach-space requested review from newling and MacDue May 13, 2025 10:54
Base automatically changed from users/banach-space/vector/document_tests to main May 14, 2025 09:13
@banach-space banach-space force-pushed the users/banach-space/sme/remove_ConvertIllegalShapeCastOpsToTransposes branch from 06ca6e9 to 0bf798d Compare May 14, 2025 09:28
Comment on lines -518 to -522
// CHECK-NOT: vector.shape_cast
%pad = arith.constant 0.0 : f32
%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
%cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
return %cast : vector<1x[4]xf32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what you've tested, but to know if this rewrite is still needed or not this test case should still be possible to lower to LLVM.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Ben!

I'm not sure what you've tested

I used our e2e tests - from what I can tell, we don't generate such code anymore.

this test case should still be possible to lower to LLVM

Indeed. @momchil-velikov , since you are working on a generic pattern for "xfer_read with non-trailing scalable dims", could you make sure that this example lowers with your patch?

  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>

I will wait for Momchil to upload his patch before progressing this one.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see 2 PRs of @momchil-velikov in llvm-project that might be related, but just checking in that this PR is still on the radar.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's on the radar and hasn't slipped through the cracks.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>

I just tried it and it does not lower. And, AFAICT, it shouldn't, as the last dimension of the memerf (?) and the vector (1) do not match and the read cannot be inferred to be contiguous, e.g. if we're reading from a memref with dynamic dimensions 4 and 2:

[*][]
[*][]
[*][]
[*][]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one

%illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}
  : memref<?x2xf32>, vector<[4]x2xf32>

is lowered, though (with an implication that %b is zero, along the way).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried it and it does not lower. And, AFAICT, it shouldn't, as the last dimension of the memerf

This lowering does not depend on the memref/transpose being contiguous. It lowers the transfer_read to a memref.transpose + transfer_read, which lowers to a loop in the case of a non-contiguous read (such as in this test case): https://godbolt.org/z/b96E7aYq4

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I've not had a chance to return to this yet. It's one of 3 things that are "next" on my list :)

As a follow-up to PR #135841 (see discussion for context), this patch
removes `ConvertIllegalShapeCastOpsToTransposes` from the SME legalization
pass and unblocks `ShapeCastOp::fold` for scalable vectors.

AFAIK, `ConvertIllegalShapeCastOpsToTransposes` was originally needed
because we were generating `vector.shape_cast` ops that couldn't be
lowered otherwise. To confirm it's no longer required, I tested this
patch locally using end-to-end tests.

Notably, this also removes a special case from `ShapeCastOp::fold`.
@banach-space banach-space force-pushed the users/banach-space/sme/remove_ConvertIllegalShapeCastOpsToTransposes branch from 0bf798d to d04d335 Compare June 18, 2025 08:18
@banach-space
Copy link
Contributor Author

Note: I’ve rebased this on top of main to keep it up to date — please let me know if that’s problematic, and I’ll avoid doing it in the future.

@MacDue, I’ve restored the ability to lower the following (thanks for flagging this!):

  %illegalRead = vector.transfer_read %memref[%a, %b], %pad: memref<?x?xf32>, vector<[4]x1xf32>

This is now handled in LowerColumnTransferReadToLoops. There are two follow-up points worth discussing:

  • I might have been better off using TransferOpConversion - I’ll take another look.
  • in_bounds isn’t handled properly yet.

Personally, I’d suggest leaving these as TODOs for now, and focus this PR on removing ConvertIllegalShapeCastOpsToTransposes, since we now have a clear forward path. That said, I don’t feel strongly - just trying to prioritise, especially considering that ConvertIllegalShapeCastOpsToTransposes currently isn’t exercised at all.

WDYT?

@MacDue
Copy link
Member

MacDue commented Jun 18, 2025

I think you could just update the LiftIllegalVectorTransposeToMemory rewrite to match vector.shape_cast for this case (rather than vector.transpose). That should lead to less duplication as the existing lowerings for memref.transpose and the vector transfer operations would introduce the necessary loops.

@banach-space
Copy link
Contributor Author

I think you could just update the LiftIllegalVectorTransposeToMemory rewrite to match vector.shape_cast for this case (rather than vector.transpose).

I’d prefer not to, mainly because there’s a fundamental difference between vector.transpose and vector.shape_cast. The former involves actual data movement, while the latter is not supposed to - as noted in the docs:

It is currently assumed that this operation does not require moving data, and that it will be folded away before lowering vector operations.

This is consistent with recent improvements from @newling - the general goal has been to treat vector.shape_cast as a no-op wherever possible.

That should lead to less duplication as the existing lowerings for memref.transpose and the vector transfer operations would introduce the necessary loops.

Agreed, avoiding duplication is ideal. That said, in this specific case, the duplication introduced is relatively minor, and keeping shape_cast logic separate from loop-emitting rewrites aligns better with its semantics.

That said, happy to add TODOs. WDYT?

@newling
Copy link
Contributor

newling commented Jun 18, 2025

The problem I had is described here.

the general goal has been to treat vector.shape_cast as a no-op wherever possible.

Another 2 cents, not directly related to this PR: shape_cast ops should ideally be eliminated before converting to LLVM/SPIRV. I think the approach of "don't worry about creating shape_casts, because shape_casts are most likely to be folded before conversion" is mostly fine. But ultimately, if shape_casts are not eliminated, they lower in a very similar way to broadcast/extract/transpose -- in LLVM as a bunch of extractelement, extractvalue, and shufflevector ops -- which may or may not be register-register copies. Which I think is fine, too.

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’d prefer not to, mainly because there’s a fundamental difference between vector.transpose and vector.shape_cast. The former involves actual data movement, while the latter is not supposed to - as noted in the docs:

Not blocking this PR (feel free to land it), but I don't really agree with this stance, vector.shape_cast has never really been a no-op (the docs have been outdated for a long time). It's no different from any other operation e.g. vector.transpose, that may mean data movement, or may fold away (e.g. into a load/store). 🙂

@banach-space
Copy link
Contributor Author

Thanks Ben!

vector.shape_cast has never really been a no-op (the docs have been outdated for a long time)

Agreed - it certainly isn’t a no-op today. That said, I believe we should still aim to preserve the original design intent where possible, treating deviations as exceptions rather than the default.

@newling, are you OK with these changes? If so, could you please approve?

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I expect my 'canonicalize towards shape_cast' changes will be smoother now, thank you!

@banach-space banach-space merged commit c92b580 into main Jun 25, 2025
7 checks passed
@banach-space banach-space deleted the users/banach-space/sme/remove_ConvertIllegalShapeCastOpsToTransposes branch June 25, 2025 15:40
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
…139706)

As a follow-up to PR llvm#135841 (see discussion for background), this patch
removes the `ConvertIllegalShapeCastOpsToTransposes` pattern from the SME
legalization pass. This change unblocks folding for ShapeCastOp involving
scalable vectors.

Originally, the `ConvertIllegalShapeCastOpsToTransposes` pattern was introduced
to rewrite certain `vector.shape_cast` ops that could not be lowered otherwise.
Based on local end-to-end testing, this workaround is no longer required, and
the pattern can now be safely removed.

This patch also removes a special case from `ShapeCastOp::fold`, simplifying
the fold logic.

As a side effect of removing `ConvertIllegalShapeCastOpsToTransposes`, we lose
the mechanism that enabled lowering of certain ops like:

```mlir
%res = vector.transfer_read %mem[%a, %b] (...) :
       memref<?x?xf32>, vector<[4]x1xf32>
```
Previously, such cases were handled by:
* Rewriting a nearby `vector.shape_cast` to a `vector.transpose` (via
  `ConvertIllegalShapeCastOpsToTransposes`)
* Then lowering the result with `LiftIllegalVectorTransposeToMemory`.

This patch introduces a new dedicated pattern,
`LowerColumnTransferReadToLoops`, that directly handles illegal
`vector.transfer_read` ops involving leading scalable dimensions.
rlavaee pushed a commit to rlavaee/llvm-project that referenced this pull request Jul 1, 2025
…139706)

As a follow-up to PR llvm#135841 (see discussion for background), this patch
removes the `ConvertIllegalShapeCastOpsToTransposes` pattern from the SME
legalization pass. This change unblocks folding for ShapeCastOp involving
scalable vectors.

Originally, the `ConvertIllegalShapeCastOpsToTransposes` pattern was introduced
to rewrite certain `vector.shape_cast` ops that could not be lowered otherwise.
Based on local end-to-end testing, this workaround is no longer required, and
the pattern can now be safely removed.

This patch also removes a special case from `ShapeCastOp::fold`, simplifying
the fold logic.

As a side effect of removing `ConvertIllegalShapeCastOpsToTransposes`, we lose
the mechanism that enabled lowering of certain ops like:

```mlir
%res = vector.transfer_read %mem[%a, %b] (...) :
       memref<?x?xf32>, vector<[4]x1xf32>
```
Previously, such cases were handled by:
* Rewriting a nearby `vector.shape_cast` to a `vector.transpose` (via
  `ConvertIllegalShapeCastOpsToTransposes`)
* Then lowering the result with `LiftIllegalVectorTransposeToMemory`.

This patch introduces a new dedicated pattern,
`LowerColumnTransferReadToLoops`, that directly handles illegal
`vector.transfer_read` ops involving leading scalable dimensions.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants