[mlir][ArmSME] Rewrite illegal shape_casts to vector.transpose ops #82985

Merged on Mar 7, 2024 (2 commits)

MacDue commented Feb 26, 2024

This adds a rewrite that converts illegal 2D unit-dim shape_casts into vector.transpose ops.

E.g.

// Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32>

Becomes:

// Case 1:
%a = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>

Various lowerings and drop-unit-dims patterns add such shape_casts. However, if they do not cancel out (which they likely won't once we've reached the vector-legalization pass), they will prevent lowering the IR.

Rewriting them as a transpose gives LiftIllegalVectorTransposeToMemory a chance to eliminate the illegal types.
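
The notion of an "illegal" type used here can be sketched outside of MLIR. The following is a minimal standalone model, not the code in the patch: it assumes a vector type is described only by its per-dimension scalable flags (a hypothetical `std::vector<bool>` standing in for what `VectorType::getScalableDims()` returns), and the name `isLegalScalableDims` is made up for illustration. A type is treated as legal when no fixed dimension comes after a scalable one.

```cpp
#include <vector>

// Standalone model of the legality rule (an assumption for illustration).
// scalableDims[i] is true when dimension i is scalable, e.g. the flags for
// vector<[4]x1xf32> are {true, false}.
bool isLegalScalableDims(const std::vector<bool> &scalableDims) {
  bool seenFixedDim = false;
  // Walk the dimensions from innermost to outermost.
  for (auto it = scalableDims.rbegin(); it != scalableDims.rend(); ++it) {
    bool scalableFlag = *it;
    seenFixedDim |= !scalableFlag;
    // A scalable dim that sits outside (before) a fixed dim is illegal.
    if (seenFixedDim && scalableFlag)
      return false;
  }
  return true;
}
```

Under this model `vector<[4]x1xf32>` (flags `{true, false}`) is illegal, while `vector<1x[4]xf32>` (`{false, true}`) and `vector<[4]xf32>` (`{true}`) are legal, which is why both cases above start from an illegal type and end at a legal one.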

llvmbot commented Feb 26, 2024

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/82985.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp (+71-14)
  • (modified) mlir/test/Dialect/ArmSME/vector-legalization.mlir (+45)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 11f8bc04b21844..55b20e5a477d4e 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -46,6 +46,8 @@ static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
     "op mask is unsupported for legalization/decomposition");
 static constexpr StringLiteral
     kMatchFailureNonPermutationMap("op affine map is not a permutation");
+static constexpr StringLiteral kMatchFailureNotIllegalToLegal(
+    "expected transpose from illegal type to legal type");
 
 /// An SMESubTile represents a single SME-sized sub-tile from decomposing a
 /// larger vector type. The (`row`, `col`) are the position of the tile in the
@@ -416,6 +418,17 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
   }
 };
 
+/// A vector type where no fixed dimension comes after a scalable dimension.
+bool isLegalVectorType(VectorType vType) {
+  bool seenFixedDim = false;
+  for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
+    seenFixedDim |= !scalableFlag;
+    if (seenFixedDim && scalableFlag)
+      return false;
+  }
+  return true;
+}
+
 /// Lifts an illegal vector.transpose and vector.transfer_read to a
 /// memref.subview + memref.transpose, followed by a legal read.
 ///
@@ -448,16 +461,6 @@ struct LiftIllegalVectorTransposeToMemory
     : public OpRewritePattern<vector::TransposeOp> {
   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
 
-  static bool isIllegalVectorType(VectorType vType) {
-    bool seenFixedDim = false;
-    for (bool scalableFlag : llvm::reverse(vType.getScalableDims())) {
-      seenFixedDim |= !scalableFlag;
-      if (seenFixedDim && scalableFlag)
-        return true;
-    }
-    return false;
-  }
-
   static Value getExtensionSource(Operation *op) {
     if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
       return op->getOperand(0);
@@ -468,9 +471,9 @@ struct LiftIllegalVectorTransposeToMemory
                                 PatternRewriter &rewriter) const override {
     auto sourceType = transposeOp.getSourceVectorType();
     auto resultType = transposeOp.getResultVectorType();
-    if (!isIllegalVectorType(sourceType) || isIllegalVectorType(resultType))
-      return rewriter.notifyMatchFailure(
-          transposeOp, "expected transpose from illegal type to legal type");
+    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(transposeOp,
+                                         kMatchFailureNotIllegalToLegal);
 
     // Look through extend for transfer_read.
     Value maybeRead = transposeOp.getVector();
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
+/// A rewrite to turn unit dim transpose-like vector.shape_cast into a
+/// vector.transpose. The shape_cast has to be from an illegal vector type to a
+/// legal one (as defined by isLegalVectorType).
+///
+/// The reasoning for this is if we've got to this pass and we still have
+/// shape_casts of illegal types, then they likely will not cancel out. Turning
+/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
+/// eliminate them.
+///
+/// Example:
+///
+///  BEFORE:
+///  ```mlir
+///  %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
+///  ```
+///
+///  AFTER:
+///  ```mlir
+///  %0 = vector.transpose %a, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+///  ```
+struct ConvertIllegalShapeCastOpsToTransposes
+    : public OpRewritePattern<vector::ShapeCastOp> {
+  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto sourceType = shapeCastOp.getSourceVectorType();
+    auto resultType = shapeCastOp.getResultVectorType();
+    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
+      return rewriter.notifyMatchFailure(shapeCastOp,
+                                         kMatchFailureNotIllegalToLegal);
+
+    // Note: If we know that the source is an illegal vector type (and 2D)
+    // then dim 0 is scalable and dim 1 is fixed.
+    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
+      return rewriter.notifyMatchFailure(
+          shapeCastOp, "expected source to be a 2D scalable vector with a "
+                       "trailing unit dim");
+
+    auto loc = shapeCastOp.getLoc();
+    auto transpose = rewriter.create<vector::TransposeOp>(
+        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
+
+    if (resultType.getRank() == 1)
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
+                                                       transpose);
+    else
+      rewriter.replaceOp(shapeCastOp, transpose);
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
@@ -576,7 +632,8 @@ struct VectorLegalizationPass
         });
 
     patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
-                 LiftIllegalVectorTransposeToMemory>(context);
+                 LiftIllegalVectorTransposeToMemory,
+                 ConvertIllegalShapeCastOpsToTransposes>(context);
     // Note: High benefit to ensure masked outer products are lowered first.
     patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
         converter, context, 1024);
diff --git a/mlir/test/Dialect/ArmSME/vector-legalization.mlir b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
index bf0b58ff4cf073..f8be697548c197 100644
--- a/mlir/test/Dialect/ArmSME/vector-legalization.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-legalization.mlir
@@ -388,3 +388,48 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
   %0 = vector.transpose %vec, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
   return %0 : vector<1x[4]xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
+// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
+  // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %0 : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
+// CHECK-SAME:                                      %[[VEC:.*]]: vector<[4]x1xf32>)
+func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
+  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
+  // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
+  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
+  return %0 : vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
+func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+  // CHECK-NOT: vector.shape_cast
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
+  return %cast : vector<1x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
+func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
+  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
+  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
+  %pad = arith.constant 0.0 : f32
+  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
+  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
+  return %cast : vector<[4]xf32>
+}
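
The match conditions of `ConvertIllegalShapeCastOpsToTransposes` in the diff above can be summarised as a small decision function. This is a hedged standalone sketch rather than the pattern itself: `VecType`, `isLegal`, `Rewrite` and `classifyShapeCast` are hypothetical names introduced for illustration, and no MLIR APIs are used. It only models which of the two rewrites (or neither) would be chosen for a given source and result type.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a vector type: a shape plus per-dim scalable flags.
struct VecType {
  std::vector<int64_t> shape;
  std::vector<bool> scalable;
};

// Legal when no fixed dimension comes after a scalable dimension.
bool isLegal(const VecType &type) {
  bool seenFixedDim = false;
  for (auto it = type.scalable.rbegin(); it != type.scalable.rend(); ++it) {
    seenFixedDim |= !*it;
    if (seenFixedDim && *it)
      return false;
  }
  return true;
}

enum class Rewrite { NoMatch, Transpose, TransposeThenShapeCast };

// Mirrors the checks in the pattern: illegal source, legal result, and a
// 2D source with a trailing unit dim. A 1D result needs an extra shape_cast
// after the transpose.
Rewrite classifyShapeCast(const VecType &source, const VecType &result) {
  assert(source.shape.size() == source.scalable.size());
  assert(result.shape.size() == result.scalable.size());
  if (isLegal(source) || !isLegal(result))
    return Rewrite::NoMatch;
  if (source.shape.size() != 2 || source.shape[1] != 1)
    return Rewrite::NoMatch;
  return result.shape.size() == 1 ? Rewrite::TransposeThenShapeCast
                                  : Rewrite::Transpose;
}
```

In this model, `vector<[4]x1xf32>` to `vector<1x[4]xf32>` maps to a plain transpose, and `vector<[4]x1xf32>` to `vector<[4]xf32>` maps to a transpose followed by a shape_cast, matching the `_2d` and `_1d` tests in the diff.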

llvmbot commented Feb 26, 2024

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

MacDue force-pushed the shape_cast_to_transpose branch from 9865dd5 to b7edfde on February 26, 2024 11:38

c-rhodes commented

> // Case 1:
> %a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
>
> ...
>
> Various lowerings and drop unit-dims patterns add such shape_casts, however, if they do not cancel out (which they likely won't if we've reached the vector-legalization pass) they will prevent lowering the IR.

Something is not clear to me here: Case 1 would be introduced by TransposeOpLowering in -convert-vector-to-llvm, and vector legalization runs well before then? Or is the ordering slightly different in IREE?

MacDue commented Feb 26, 2024

> Something is not clear to me here: Case 1 would be introduced by TransposeOpLowering in -convert-vector-to-llvm, and vector legalization runs well before then? Or is the ordering slightly different in IREE?

In IREE the generic vector lowering happens quite a bit earlier than the ArmSME pipeline (and it's not part of -convert-vector-to-llvm; it's just a general vector-dialect-level pass).

c-rhodes commented

> In IREE the generic vector lowering happens quite a bit earlier than the ArmSME pipeline (and it's not part of -convert-vector-to-llvm, it's just a general vector dialect level pass).

OK, thanks for clarifying.

c-rhodes left a comment

One minor comment, otherwise LGTM. Cheers.

MacDue added 2 commits March 7, 2024 16:23
This adds a rewrite that converts illegal 2D unit-dim `shape_casts`
into `vector.transpose` ops.

E.g.

```mlir
// Case 1:
%a = vector.shape_cast %0 : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%b = vector.shape_cast %1 : vector<[4]x1xf32> to vector<[4]xf32>
```

Becomes:

```mlir
// Case 1:
%a = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
// Case 2:
%t = vector.transpose %1, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%b = vector.shape_cast %t : vector<1x[4]xf32> to vector<[4]xf32>
```

Various lowerings and drop-unit-dims patterns add such shape_casts.
However, if they do not cancel out (which they likely won't once we've
reached the vector-legalization pass), they will prevent lowering the IR.

Rewriting them as a transpose gives `LiftIllegalVectorTransposeToMemory`
a chance to eliminate the illegal types.

MacDue force-pushed the shape_cast_to_transpose branch from b7edfde to 517e5f0 on March 7, 2024 16:32
MacDue merged commit d1fc59c into llvm:main on Mar 7, 2024
MacDue deleted the shape_cast_to_transpose branch on March 7, 2024 17:04