-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector]Add Vector bitwidth target to Linearize Vectorizable and Constant Ops #83314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
on user specified vector-lengths.
@llvm/pr-subscribers-mlir Author: Balaji V. Iyer. (bviyer) ChangesAdded a new flag Full diff: https://github.com/llvm/llvm-project/pull/83314.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 46bb3ddec0baf6..453fa73429dd1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -387,7 +387,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target);
+ ConversionTarget &target, unsigned targetBitWidth);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c5352043955579..28a7de22954f99 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -19,10 +19,27 @@
using namespace mlir;
+static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+ auto resultTypes = op->getResultTypes();
+ for (auto resType : resultTypes) {
+ VectorType vecType = cast<VectorType>(resType);
+ unsigned trailingVecDimBitWidth =
+ vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ if (trailingVecDimBitWidth >= targetBitWidth)
+ return false;
+ }
+ return true;
+}
+
namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
-
+ LinearizeConstant(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -31,7 +48,9 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
getTypeConverter()->convertType<VectorType>(constOp.getType());
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
-
+ if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ loc, "Can't flatten since targetBitWidth <= OpSize");
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
if (!dstElementsAttr)
return rewriter.notifyMatchFailure(loc, "unsupported attr type");
@@ -41,15 +60,28 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
dstElementsAttr);
return success();
}
+
+private:
+ unsigned targetVectorBitWidth;
};
struct LinearizeVectorizable final
: OpTraitConversionPattern<OpTrait::Vectorizable> {
using OpTraitConversionPattern::OpTraitConversionPattern;
+public:
+ LinearizeVectorizable(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpTraitConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+ if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
if (failed(newOp))
@@ -58,12 +90,16 @@ struct LinearizeVectorizable final
rewriter.replaceOp(op, (*newOp)->getResults());
return success();
}
+
+private:
+ unsigned targetVectorBitWidth;
};
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target) {
+ ConversionTarget &target, unsigned targetBitWidth) {
+
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
// Ignore scalable vectors for now.
if (type.getRank() <= 1 || type.isScalable())
@@ -83,15 +119,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addArgumentMaterialization(materializeCast);
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
-
target.markUnknownOpDynamicallyLegal(
- [&](Operation *op) -> std::optional<bool> {
- if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
- return typeConverter.isLegal(op);
-
+ [=](Operation *op) -> std::optional<bool> {
+ if ((isa<arith::ConstantOp>(op) ||
+ op->hasTrait<OpTrait::Vectorizable>())) {
+ return (isLessThanTargetBitWidth(op, targetBitWidth)
+ ? typeConverter.isLegal(op)
+ : true);
+ }
return std::nullopt;
});
- patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
- patterns.getContext());
+ patterns.add<LinearizeConstant, LinearizeVectorizable>(
+ typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 85e23103eaedb7..659bb021846d89 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,17 +1,25 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=12 | FileCheck %s --check-prefix=CHECK12
// CHECK-LABEL: test_linearize
+// CHECK12-LABEL: test_linearize
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+// CHECK12: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
// Arith and math ops are handled in generic way, check some of them
// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
+// CHECK12: %{{.*}} = math.sin %{{.*}} : vector<2x2xf32>
%1 = math.sin %arg0 : vector<2x2xf32>
// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+// CHECK12: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
+
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
// CHECK: return %[[RES]] : vector<2x2xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 178a58e796b246..74d2dfa44f4fe9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -842,6 +842,9 @@ struct TestVectorLinearize final
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+ TestVectorLinearize() = default;
+ TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+
StringRef getArgument() const override { return "test-vector-linearize"; }
StringRef getDescription() const override {
return "Linearizes ND vectors for N >= 2 into 1D vectors";
@@ -850,6 +853,11 @@ struct TestVectorLinearize final
registry.insert<vector::VectorDialect>();
}
+ Option<unsigned> targetVectorBitwidth{
+ *this, "target-vector-bitwidth",
+ llvm::cl::desc(
+ "Minimum vector bitwidth to enable the flattening transformation"),
+ llvm::cl::init(std::numeric_limits<unsigned>::max())};
void runOnOperation() override {
auto *context = &getContext();
@@ -857,8 +865,8 @@ struct TestVectorLinearize final
RewritePatternSet patterns(context);
ConversionTarget target(*context);
- vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
- patterns, target);
+ vector::populateVectorLinearizeTypeConversionsAndLegality(
+ typeConverter, patterns, target, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
|
@llvm/pr-subscribers-mlir-vector Author: Balaji V. Iyer. (bviyer) ChangesAdded a new flag Full diff: https://github.com/llvm/llvm-project/pull/83314.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 46bb3ddec0baf6..453fa73429dd1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -387,7 +387,7 @@ void populateVectorTransposeNarrowTypeRewritePatterns(
/// the ops to get converted properly.
void populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target);
+ ConversionTarget &target, unsigned targetBitWidth);
} // namespace vector
} // namespace mlir
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index c5352043955579..28a7de22954f99 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -19,10 +19,27 @@
using namespace mlir;
+static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
+ auto resultTypes = op->getResultTypes();
+ for (auto resType : resultTypes) {
+ VectorType vecType = cast<VectorType>(resType);
+ unsigned trailingVecDimBitWidth =
+ vecType.getShape().back() * vecType.getElementTypeBitWidth();
+ if (trailingVecDimBitWidth >= targetBitWidth)
+ return false;
+ }
+ return true;
+}
+
namespace {
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern::OpConversionPattern;
-
+ LinearizeConstant(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(arith::ConstantOp constOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -31,7 +48,9 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
getTypeConverter()->convertType<VectorType>(constOp.getType());
if (!resType)
return rewriter.notifyMatchFailure(loc, "can't convert return type");
-
+ if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ loc, "Can't flatten since targetBitWidth <= OpSize");
auto dstElementsAttr = dyn_cast<DenseElementsAttr>(constOp.getValue());
if (!dstElementsAttr)
return rewriter.notifyMatchFailure(loc, "unsupported attr type");
@@ -41,15 +60,28 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
dstElementsAttr);
return success();
}
+
+private:
+ unsigned targetVectorBitWidth;
};
struct LinearizeVectorizable final
: OpTraitConversionPattern<OpTrait::Vectorizable> {
using OpTraitConversionPattern::OpTraitConversionPattern;
+public:
+ LinearizeVectorizable(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpTraitConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+ if (!isLessThanTargetBitWidth(op, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "Can't flatten since targetBitWidth <= OpSize");
FailureOr<Operation *> newOp =
convertOpResultTypes(op, operands, *getTypeConverter(), rewriter);
if (failed(newOp))
@@ -58,12 +90,16 @@ struct LinearizeVectorizable final
rewriter.replaceOp(op, (*newOp)->getResults());
return success();
}
+
+private:
+ unsigned targetVectorBitWidth;
};
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
- ConversionTarget &target) {
+ ConversionTarget &target, unsigned targetBitWidth) {
+
typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
// Ignore scalable vectors for now.
if (type.getRank() <= 1 || type.isScalable())
@@ -83,15 +119,17 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
typeConverter.addArgumentMaterialization(materializeCast);
typeConverter.addSourceMaterialization(materializeCast);
typeConverter.addTargetMaterialization(materializeCast);
-
target.markUnknownOpDynamicallyLegal(
- [&](Operation *op) -> std::optional<bool> {
- if (isa<arith::ConstantOp>(op) || op->hasTrait<OpTrait::Vectorizable>())
- return typeConverter.isLegal(op);
-
+ [=](Operation *op) -> std::optional<bool> {
+ if ((isa<arith::ConstantOp>(op) ||
+ op->hasTrait<OpTrait::Vectorizable>())) {
+ return (isLessThanTargetBitWidth(op, targetBitWidth)
+ ? typeConverter.isLegal(op)
+ : true);
+ }
return std::nullopt;
});
- patterns.add<LinearizeConstant, LinearizeVectorizable>(typeConverter,
- patterns.getContext());
+ patterns.add<LinearizeConstant, LinearizeVectorizable>(
+ typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 85e23103eaedb7..659bb021846d89 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -1,17 +1,25 @@
// RUN: mlir-opt %s -split-input-file -test-vector-linearize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=128 | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-vector-linearize=target-vector-bitwidth=12 | FileCheck %s --check-prefix=CHECK12
// CHECK-LABEL: test_linearize
+// CHECK12-LABEL: test_linearize
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x2xf32>)
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// CHECK: %[[C1:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+// CHECK12: %[[C1:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
+
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// CHECK: %[[RES:.*]] = vector.shape_cast %[[C1]] : vector<4xf32> to vector<2x2xf32>
// Arith and math ops are handled in generic way, check some of them
// CHECK: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
+// CHECK12: %{{.*}} = math.sin %{{.*}} : vector<2x2xf32>
%1 = math.sin %arg0 : vector<2x2xf32>
// CHECK: %{{.*}} = arith.addf %[[ARG]], %[[C1]] : vector<4xf32>
+// CHECK12: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
+
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
// CHECK: return %[[RES]] : vector<2x2xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 178a58e796b246..74d2dfa44f4fe9 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -842,6 +842,9 @@ struct TestVectorLinearize final
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+ TestVectorLinearize() = default;
+ TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {}
+
StringRef getArgument() const override { return "test-vector-linearize"; }
StringRef getDescription() const override {
return "Linearizes ND vectors for N >= 2 into 1D vectors";
@@ -850,6 +853,11 @@ struct TestVectorLinearize final
registry.insert<vector::VectorDialect>();
}
+ Option<unsigned> targetVectorBitwidth{
+ *this, "target-vector-bitwidth",
+ llvm::cl::desc(
+ "Minimum vector bitwidth to enable the flattening transformation"),
+ llvm::cl::init(std::numeric_limits<unsigned>::max())};
void runOnOperation() override {
auto *context = &getContext();
@@ -857,8 +865,8 @@ struct TestVectorLinearize final
RewritePatternSet patterns(context);
ConversionTarget target(*context);
- vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
- patterns, target);
+ vector::populateVectorLinearizeTypeConversionsAndLegality(
+ typeConverter, patterns, target, targetVectorBitwidth);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
|
Hey @Hardcode84! Just to provide a bit more context, the goal here is to add a bit more control to the flattening. The ultimate state should allow us to (optionally) unroll up to the first multiple of the provided vector bitwidth, and leave the rest of the unrolling to the passes intended for that. Hopefully that makes sense to you! We are also adding the same functionality to the vector transfer read/write counterpart: #81966 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of curiosity, is this for CPU codegen? When do we need it? I thought that we want to get rid of shape_cast. Instead, we want the mermef.collapse_shape version? Is it for "scalar loads/stores" + "fully utilization for vector computation"?
|
Oh I see the point, very good point! |
I was the one who added this transform initially, in our case we have a vectors which are logically 2D, but eventually lowered to LLVM/SPIR-V intrinsics which only support 1D types so we operate them as 2D and flatten as the last step. |
…136581) [NFC] Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results. In #83314 an option to ignore (make 'legal') operations with large inner-most dimensions was added. This current PR is a step towards making that option live outside of upstream MLIR. The motivation is to remove non-core functionality (I would like to use this pass, but would prefer not to deal with 'targetVectorBitWidth` at all). As a follow-up to this PR, I propose that user(s) of the `targetVectorBitWidth` move the relevant code (now in mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code bases, and then eventually remove it from upstream. In addition the tests need to split out (I've intentionally not modified the lit tests here, to make it easier to confirm that this is a NFC). I'm happy to help make it easier to do this final step! The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out the patterns, and into the conversion target, which the end user can control outside of core MLIR.
…lvm#136581) [NFC] Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results. In llvm#83314 an option to ignore (make 'legal') operations with large inner-most dimensions was added. This current PR is a step towards making that option live outside of upstream MLIR. The motivation is to remove non-core functionality (I would like to use this pass, but would prefer not to deal with 'targetVectorBitWidth` at all). As a follow-up to this PR, I propose that user(s) of the `targetVectorBitWidth` move the relevant code (now in mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code bases, and then eventually remove it from upstream. In addition the tests need to split out (I've intentionally not modified the lit tests here, to make it easier to confirm that this is a NFC). I'm happy to help make it easier to do this final step! The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out the patterns, and into the conversion target, which the end user can control outside of core MLIR.
…lvm#136581) [NFC] Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results. In llvm#83314 an option to ignore (make 'legal') operations with large inner-most dimensions was added. This current PR is a step towards making that option live outside of upstream MLIR. The motivation is to remove non-core functionality (I would like to use this pass, but would prefer not to deal with 'targetVectorBitWidth` at all). As a follow-up to this PR, I propose that user(s) of the `targetVectorBitWidth` move the relevant code (now in mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code bases, and then eventually remove it from upstream. In addition the tests need to split out (I've intentionally not modified the lit tests here, to make it easier to confirm that this is a NFC). I'm happy to help make it easier to do this final step! The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out the patterns, and into the conversion target, which the end user can control outside of core MLIR.
…lvm#136581) [NFC] Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results. In llvm#83314 an option to ignore (make 'legal') operations with large inner-most dimensions was added. This current PR is a step towards making that option live outside of upstream MLIR. The motivation is to remove non-core functionality (I would like to use this pass, but would prefer not to deal with 'targetVectorBitWidth` at all). As a follow-up to this PR, I propose that user(s) of the `targetVectorBitWidth` move the relevant code (now in mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code bases, and then eventually remove it from upstream. In addition the tests need to split out (I've intentionally not modified the lit tests here, to make it easier to confirm that this is a NFC). I'm happy to help make it easier to do this final step! The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out the patterns, and into the conversion target, which the end user can control outside of core MLIR.
…f patterns (#136581) [NFC] Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results. In llvm/llvm-project#83314 an option to ignore (make 'legal') operations with large inner-most dimensions was added. This current PR is a step towards making that option live outside of upstream MLIR. The motivation is to remove non-core functionality (I would like to use this pass, but would prefer not to deal with 'targetVectorBitWidth` at all). As a follow-up to this PR, I propose that user(s) of the `targetVectorBitWidth` move the relevant code (now in mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code bases, and then eventually remove it from upstream. In addition the tests need to split out (I've intentionally not modified the lit tests here, to make it easier to confirm that this is a NFC). I'm happy to help make it easier to do this final step! The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out the patterns, and into the conversion target, which the end user can control outside of core MLIR.
…lvm#136581) [NFC] Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results. In llvm#83314 an option to ignore (make 'legal') operations with large inner-most dimensions was added. This current PR is a step towards making that option live outside of upstream MLIR. The motivation is to remove non-core functionality (I would like to use this pass, but would prefer not to deal with 'targetVectorBitWidth` at all). As a follow-up to this PR, I propose that user(s) of the `targetVectorBitWidth` move the relevant code (now in mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code bases, and then eventually remove it from upstream. In addition the tests need to split out (I've intentionally not modified the lit tests here, to make it easier to confirm that this is a NFC). I'm happy to help make it easier to do this final step! The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out the patterns, and into the conversion target, which the end user can control outside of core MLIR.
…lvm#136581) [NFC] Vector linearization is a collection of rewrite patterns that reduce the rank of vector operands and results. In llvm#83314 an option to ignore (make 'legal') operations with large inner-most dimensions was added. This current PR is a step towards making that option live outside of upstream MLIR. The motivation is to remove non-core functionality (I would like to use this pass, but would prefer not to deal with 'targetVectorBitWidth` at all). As a follow-up to this PR, I propose that user(s) of the `targetVectorBitWidth` move the relevant code (now in mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code bases, and then eventually remove it from upstream. In addition the tests need to split out (I've intentionally not modified the lit tests here, to make it easier to confirm that this is a NFC). I'm happy to help make it easier to do this final step! The approach I've used is to move the logic pertaining to `targetVectorBitWidth` out the patterns, and into the conversion target, which the end user can control outside of core MLIR.
Added a new flag
targetVectorBitwidth
to capture bit-width input.