-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] add support for linearizing vector.bitcast in VectorLinearize #123110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Chao Chen (chencha3) ChangesThis PR adds support for converting Full diff: https://github.com/llvm/llvm-project/pull/123110.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 68535ae5a7a5c6..b450ea91fef651 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -72,13 +72,14 @@ struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
auto resType =
getTypeConverter()->convertType<VectorType>(constOp.getType());
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type");
+
if (resType.isScalable() && !isa<SplatElementsAttr>(constOp.getValue()))
return rewriter.notifyMatchFailure(
loc,
"Cannot linearize a constant scalable vector that's not a splat");
- if (!resType)
- return rewriter.notifyMatchFailure(loc, "can't convert return type");
if (!isLessThanTargetBitWidth(constOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
loc, "Can't flatten since targetBitWidth <= OpSize");
@@ -459,6 +460,35 @@ struct LinearizeVectorInsert final
private:
unsigned targetVectorBitWidth;
};
+
+struct LinearizeVectorBitCast final
+ : public OpConversionPattern<vector::BitCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LinearizeVectorBitCast(
+ const TypeConverter &typeConverter, MLIRContext *context,
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned>::max(),
+ PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit),
+ targetVectorBitWidth(targetVectBitWidth) {}
+ LogicalResult
+ matchAndRewrite(vector::BitCastOp castOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = castOp.getLoc();
+ auto resType = getTypeConverter()->convertType(castOp.getType());
+ if (!resType)
+ return rewriter.notifyMatchFailure(loc, "can't convert return type.");
+
+ if (!isLessThanTargetBitWidth(castOp, targetVectorBitWidth))
+ return rewriter.notifyMatchFailure(
+ loc, "Can't flatten since targetBitWidth <= OpSize");
+
+ rewriter.replaceOpWithNewOp<vector::BitCastOp>(castOp, resType, adaptor.getSource());
+ return mlir::success();
+ }
+private:
+ unsigned targetVectorBitWidth;
+};
+
} // namespace
void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
@@ -486,6 +516,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
if ((isa<arith::ConstantOp>(op) ||
+ isa<vector::BitCastOp>(op) ||
op->hasTrait<OpTrait::Vectorizable>())) {
return (isLessThanTargetBitWidth(op, targetBitWidth)
? typeConverter.isLegal(op)
@@ -494,7 +525,7 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
return std::nullopt;
});
- patterns.add<LinearizeConstant, LinearizeVectorizable>(
+ patterns.add<LinearizeConstant, LinearizeVectorizable, LinearizeVectorBitCast>(
typeConverter, patterns.getContext(), targetBitWidth);
}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 543e76b5b26e0c..0358c2637f72b2 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -179,7 +179,7 @@ func.func @test_extract_strided_slice_1(%arg0 : vector<4x8xf32>) -> vector<2x2xf
// ALL-LABEL: func.func @test_extract_strided_slice_1_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
-func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
+func.func @test_extract_strided_slice_1_scalable(%arg0: vector<4x[8]xf32>) -> vector<2x[8]xf32> {
// ALL-NOT: vector.shuffle
// ALL-NOT: vector.shape_cast
// ALL: %[[RES:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [1, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x[8]xf32> to vector<2x[8]xf32>
@@ -318,3 +318,22 @@ func.func @test_vector_extract_scalar() {
%0 = vector.extract %cst[0] : i32 from vector<4xi32>
return
}
+
+// -----
+
+// ALL-LABEL: test_vector_bitcast
+// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>)
+func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> {
+
+ // DEFAULT: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
+ // DEFAULT: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
+ // DEFAULT: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+
+ // BW-128: %[[R0:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4xf32>
+ // BW-128: %[[R1:.*]] = vector.bitcast %[[R0]] : vector<4xf32> to vector<8xf16>
+ // BW-128: %[[R2:.*]] = vector.shape_cast %[[R1]] : vector<8xf16> to vector<4x2xf16>
+
+ // BW-0: %[[R2:.*]] = vector.bitcast %[[ORIG_ARG]] : vector<4x1xf32> to vector<4x2xf16>
+ %1 = vector.bitcast %arg0 : vector<4x1xf32> to vector<4x2xf16>
+ return %1 : vector<4x2xf16>
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
|
||
// ALL-LABEL: test_vector_bitcast | ||
// ALL-SAME: (%[[ORIG_ARG:.*]]: vector<4x1xf32>) | ||
func.func @test_vector_bitcast(%arg0: vector<4x1xf32>) -> vector<4x2xf16> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you mind changing the test so that it doesn't use a unit dim?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I updated it. Thanks for your feedback.
@@ -459,6 +460,37 @@ struct LinearizeVectorInsert final | |||
private: | |||
unsigned targetVectorBitWidth; | |||
}; | |||
|
|||
struct LinearizeVectorBitCast final |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add documentation and tests for scalable vectors (bailing out for scalable vectors would be fine, though this should just work).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I added comments and tests for scalable vectors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@banach-space Do you have any more concerns?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for the delay—this slipped off my radar.
Everything looks good, but I have a small request regarding the tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
I've made some small requests re test formatting, but these are nits.
@@ -459,6 +460,37 @@ struct LinearizeVectorInsert final | |||
private: | |||
unsigned targetVectorBitWidth; | |||
}; | |||
|
|||
struct LinearizeVectorBitCast final |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies for the delay—this slipped off my radar.
Everything looks good, but I have a small request regarding the tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you mind using more descriptive LIT variables ? (e.g. RES
or DOWNCAST
/UPCAST
for vector.shape_cast
) This would be more consistent with the existing convention within the file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, fixed them and thanks.
// ----- | ||
|
||
// ALL-LABEL: test_vector_bitcast | ||
// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Since the name of the argument is arg0
.
// ALL-SAME: %[[ORIG_ARG:.*]]: vector<4x4xf32> | |
// ALL-SAME: %[[ARG_0:.*]]: vector<4x4xf32> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, fixed and thanks.
This PR adds support for converting
vector.bitcast
working on ND (N > 1) vectors into the same op working on linearized (1D) vectors.