[mlir][vector]Add Vector bitwidth target to Linearize Vectorizable and Constant Ops #83314

Merged
merged 6 commits into from
Mar 5, 2024

Conversation

bviyer
Contributor

@bviyer bviyer commented Feb 28, 2024

Added a new flag targetVectorBitwidth to capture bit-width input.

@bviyer bviyer changed the title from "Added a flag to enable flattening of Constants and Vectors based" to "Add Vector bitwidth target to Linearize Vectorizable and Constant Ops" Feb 29, 2024
@llvmbot
Member

llvmbot commented Feb 29, 2024

@llvm/pr-subscribers-mlir

Author: Balaji V. Iyer. (bviyer)

Changes

Added a new flag targetVectorBitwidth to capture bit-width input.


Full diff: https://github.com/llvm/llvm-project/pull/83314.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+48-10)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+8)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+10-2)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 46bb3ddec0baf6..453fa73429dd1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -387,7 +387,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
 /// the ops to get converted properly.
 void populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target);
+    ConversionTarget &target, unsigned targetBitWidth);
 
 } // namespace vector
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c5352043955579..28a7de22954f99 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -19,10 +19,27 @@
 
 using namespace mlir;
 
+static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+  auto resultTypes = op->getResultTypes();
+  for (auto resType : resultTypes) {
+    VectorType vecType = cast<VectorType>(resType);
+    unsigned trailingVecDimBitWidth =
+        vecType.getShape().back() * vecType.getElementTypeBitWidth();
+    if (trailingVecDimBitWidth >= targetBitWidth)
+      return false;
+  }
+  return true;
+}
+
 namespace {
 struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
   using OpConversionPattern::OpConversionPattern;
-
+  LinearizeConstant(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
   LogicalResult
   matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -31,7 +48,9 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
         getTypeConverter()->convertType<VectorType>(constOp.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(loc, "can't convert return type");
-
+    if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          loc, "Can't flatten since targetBitWidth <= OpSize");
     auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
     if (!dstElementsAttr)
       return rewriter.notifyMatchFailure(loc, "unsupported attr type");
@@ -41,15 +60,28 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
                                                    dstElementsAttr);
     return success();
   }
+
+private:
+  unsigned targetVectorBitWidth;
 };
 
 struct LinearizeVectorizable final
     : OpTraitConversionPattern<OpTrait::Vectorizable> {
   using OpTraitConversionPattern::OpTraitConversionPattern;
 
+public:
+  LinearizeVectorizable(
+      const TypeConverter &typeConverter, MLIRContext *context,
+      unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+      PatternBenefit benefit = 1)
+      : OpTraitConversionPattern(typeConverter, context, benefit),
+        targetVectorBitWidth(targetVectBitWidth) {}
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+    if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
     FailureOr<Operation *> newOp =
         convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
     if (failed(newOp))
@@ -58,12 +90,16 @@ struct LinearizeVectorizable final
     rewriter.replaceOp(op, (*newOp)->getResults());
     return success();
   }
+
+private:
+  unsigned targetVectorBitWidth;
 };
 } // namespace
 
 void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
     TypeConverter &typeConverter, RewritePatternSet &patterns,
-    ConversionTarget &target) {
+    ConversionTarget &target, unsigned targetBitWidth) {
+
   typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
     // Ignore scalable vectors for now.
     if (type.getRank() <= 1 || type.isScalable())
@@ -83,15 +119,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
   typeConverter.addArgumentMaterialization(materializeCast);
   typeConverter.addSourceMaterialization(materializeCast);
   typeConverter.addTargetMaterialization(materializeCast);
-
   target.markUnknownOpDynamicallyLegal(
-      [&](Operation *op) -> std::optional<bool> {
-        if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
-          return typeConverter.isLegal(op);
-
+      [=](Operation *op) -> std::optional<bool> {
+        if ((isa<arith::ConstantOp>(op) ||
+             op->hasTrait<OpTrait::Vectorizable>())) {
+          return (isLessThanTargetBitWidth(op, targetBitWidth)
+                      ? typeConverter.isLegal(op)
+                      : true);
+        }
         return std::nullopt;
       });
 
-  patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
-                                                         patterns.getContext());
+  patterns.add<LinearizeConstant, LinearizeVectorizable>(
+      typeConverter, patterns.getContext(), targetBitWidth);
 }
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 85e23103eaedb7..659bb021846d89 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,17 +1,25 @@
 // RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=12 | FileCheck %s --check-prefix=CHECK12
 
 // CHECK-LABEL: test_linearize
+// CHECK12-LABEL: test_linearize
 //  CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
 //       CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
 func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
 //       CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+//       CHECK12: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
   %0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
 //       CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
 
 // Arith and math ops are handled in generic way, check some of them
 //       CHECK: %{{.*}} =  math.sin %[[ARG]] : vector<4xf32>
+//       CHECK12: %{{.*}} =  math.sin %{{.*}} : vector<2x2xf32>
   %1 = math.sin %arg0 : vector<2x2xf32>
 //       CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+//       CHECK12: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
+
   %2 = arith.addf %arg0, %0 :  vector<2x2xf32>
 
 //       CHECK: return %[[RES]] : vector<2x2xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 178a58e796b246..74d2dfa44f4fe9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -842,6 +842,9 @@ struct TestVectorLinearize final
     : public PassWrapper<TestVectorLinearize, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
 
+  TestVectorLinearize() = default;
+  TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+
   StringRef getArgument() const override { return "test-vector-linearize"; }
   StringRef getDescription() const override {
     return "Linearizes ND vectors for N >= 2 into 1D vectors";
@@ -850,6 +853,11 @@ struct TestVectorLinearize final
     registry.insert<vector::VectorDialect>();
   }
 
+  Option<unsigned> targetVectorBitwidth{
+      *this, "target-vector-bitwidth",
+      llvm::cl::desc(
+          "Minimum vector bitwidth to enable the flattening transformation"),
+      llvm::cl::init(std::numeric_limits<unsigned>::max())};
   void runOnOperation() override {
     auto *context = &getContext();
 
@@ -857,8 +865,8 @@ struct TestVectorLinearize final
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);
 
-    vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
-                                                              patterns, target);
+    vector::populateVectorLinearizeTypeConversionsAndLegality(
+        typeConverter, patterns, target, targetVectorBitwidth);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       return signalPassFailure();
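
The check that gates both patterns in the diff above can be tried outside MLIR. The following is a minimal sketch, not the upstream code: `ToyVectorType` and `checkLitTestCases` are hypothetical names that stand in for `mlir::VectorType` and the lit test, and treating an empty shape as a single element is an assumption made only so the sketch is total. It reproduces the comparison (innermost dimension times element bit width, strictly below the target) and replays the three `RUN` configurations for `vector<2x2xf32>`.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for mlir::VectorType (not an MLIR class): a static
// shape plus the bit width of a single element.
struct ToyVectorType {
  std::vector<std::int64_t> shape;
  unsigned elementBitWidth;
};

// Same comparison as the helper in the diff: the innermost dimension times the
// element width must be strictly below the target for linearization to apply.
// Assumption of this sketch: an empty shape counts as one element.
bool isLessThanTargetBitWidth(const ToyVectorType &type,
                              unsigned targetBitWidth) {
  const std::uint64_t trailingDim =
      type.shape.empty() ? 1 : static_cast<std::uint64_t>(type.shape.back());
  const std::uint64_t trailingBits = trailingDim * type.elementBitWidth;
  return trailingBits < targetBitWidth;
}

// Replays the three RUN lines of linearize.mlir for vector<2x2xf32>.
void checkLitTestCases() {
  const ToyVectorType v2x2xf32{{2, 2}, 32}; // innermost row: 2 * 32 = 64 bits
  // Default option value: every type is below the target, so it is linearized.
  assert(isLessThanTargetBitWidth(v2x2xf32,
                                  std::numeric_limits<unsigned>::max()));
  assert(isLessThanTargetBitWidth(v2x2xf32, 128)); // 64 < 128: linearized
  assert(!isLessThanTargetBitWidth(v2x2xf32, 12)); // 64 >= 12: stays 2x2
}
```

Because the comparison is strict, a target equal to the innermost row width (64 here) also leaves the op untouched. The sketch widens the product to 64 bits, whereas the diff computes it in `unsigned`.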

@llvmbot
Member

llvmbot commented Feb 29, 2024

@llvm/pr-subscribers-mlir-vector


@dcaballe
Contributor

dcaballe commented Mar 1, 2024

Hey @Hardcode84! Just to provide a bit more context, the goal here is to add a bit more control to the flattening. The ultimate state should allow us to (optionally) unroll up to the first multiple of the provided vector bitwidth, and leave the rest of the unrolling to the passes intended for that. Hopefully that makes sense to you! We are also adding the same functionality to the vector transfer read/write counterpart: #81966

Contributor

@hanhanW hanhanW left a comment

Out of curiosity, is this for CPU codegen? When do we need it? I thought we wanted to get rid of shape_cast and use the memref.collapse_shape version instead. Is it for "scalar loads/stores" + "full utilization for vector computation"?

@dcaballe
Contributor

dcaballe commented Mar 1, 2024

Out of curiosity, is this for CPU codegen? When do we need it? I thought we wanted to get rid of shape_cast and use the memref.collapse_shape version instead. Is it for "scalar loads/stores" + "full utilization for vector computation"?

memref.collapse_shape is only for memrefs (transfer ops). This is "collapsing" vectors. If we only flatten transfer ops but not their producers/consumers, we would end up with vector.shape_cast ops that won't be optimized away. The vector.shape_cast ops introduced by this pass should cancel out with the ones introduced by the transfer read/write flattening pass.

@hanhanW
Contributor

hanhanW commented Mar 1, 2024

memref.collapse_shape is only for memrefs (transfer ops). This is "collapsing" vectors. If we only flatten transfer ops but not their producers/consumers, we would end up with vector.shape_cast ops that won't be optimized away. The vector.shape_cast ops introduced by this pass should cancel out with the ones introduced by the transfer read/write flattening pass.

Oh I see the point, very good point!

@bviyer bviyer requested review from Hardcode84 and hanhanW March 1, 2024 19:52
@Hardcode84
Contributor

Out of curiosity, is this for CPU codegen? When do we need it? I thought we wanted to get rid of shape_cast and use the memref.collapse_shape version instead. Is it for "scalar loads/stores" + "full utilization for vector computation"?

I was the one who added this transform initially. In our case we have vectors that are logically 2D but are eventually lowered to LLVM/SPIR-V intrinsics, which only support 1D types, so we operate on them as 2D and flatten as the last step.

@bviyer bviyer changed the title from "Add Vector bitwidth target to Linearize Vectorizable and Constant Ops" to "[mlir][vector]Add Vector bitwidth target to Linearize Vectorizable and Constant Ops" Mar 4, 2024
@bviyer bviyer merged commit 6f5c4f2 into llvm:main Mar 5, 2024
newling added a commit that referenced this pull request Apr 30, 2025
…136581)

[NFC]
Vector linearization is a collection of rewrite patterns that reduce the
rank of vector operands and results.

In #83314 an option to ignore
(make 'legal') operations with large inner-most dimensions was added.
This current PR is a step towards making that option live outside of
upstream MLIR. The motivation is to remove non-core functionality (I
would like to use this pass, but would prefer not to deal with
`targetVectorBitWidth` at all).

As a follow-up to this PR, I propose that user(s) of
`targetVectorBitWidth` move the relevant code (now in
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code
bases, and then eventually remove it from upstream. In addition, the tests need to
be split out (I've intentionally not modified the lit tests here, to make
it easier to confirm that this is an NFC). I'm happy to help make it
easier to do this final step!

The approach I've used is to move the logic pertaining to
`targetVectorBitWidth` out of the patterns and into the conversion target,
which the end user can control outside of core MLIR.
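
As a rough illustration of the idea in the last paragraph, the toy model below keeps the bit-width threshold entirely inside a legality callback that the user builds, so the rewrite logic itself never sees it. It is a sketch under stated assumptions, not the code from either PR: `ToyOp`, `isAlreadyLinear`, `LegalityFn`, and `makeBitWidthAwareLegality` are hypothetical names, and the real conversion target works on `mlir::Operation` rather than on a struct like this one.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <vector>

// Hypothetical stand-in for an operation with a single vector result.
struct ToyOp {
  std::vector<std::int64_t> resultShape;
  unsigned elementBitWidth;
};

// Stand-in for the core legality rule: rank <= 1 needs no linearization.
bool isAlreadyLinear(const ToyOp &op) { return op.resultShape.size() <= 1; }

using LegalityFn = std::function<bool(const ToyOp &)>;

// The threshold lives only in this user-built callback. The branch structure
// follows the ternary in the legality lambda of the #83314 diff: ops whose
// innermost row is at least targetBitWidth bits wide are declared legal.
LegalityFn makeBitWidthAwareLegality(unsigned targetBitWidth) {
  return [targetBitWidth](const ToyOp &op) {
    // Assumption of this sketch: an empty shape counts as one element.
    const std::uint64_t trailingDim =
        op.resultShape.empty()
            ? 1
            : static_cast<std::uint64_t>(op.resultShape.back());
    const std::uint64_t trailingBits = trailingDim * op.elementBitWidth;
    if (trailingBits >= targetBitWidth)
      return true; // wide enough: legal as-is, so no pattern rewrites it
    return isAlreadyLinear(op);
  };
}
```

A caller who does not care about bit widths can pass `isAlreadyLinear` directly as the `LegalityFn`, which is the separation the commit message argues for.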
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…lvm#136581)

llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request May 6, 2025
…f patterns (#136581)

GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
…lvm#136581)

Ankur-0429 pushed a commit to Ankur-0429/llvm-project that referenced this pull request May 9, 2025
…lvm#136581)


5 participants