-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Vector] Allow Scalable Dim in OneDimMultiReductionToTwoDim #89978
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
To correctly lower multi_reduction of 1-dim scalable vector, e.g.: <[4]xf32> (cherry picked from commit 35f95c21e7895181084cc3a14c4e70eb8d1e6eee)
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Zhaoshi Zheng (zhaoshiz) ChangesTo correctly lower multi_reduction of 1-dim scalable vector, e.g.: <[4]xf32> Full diff: https://github.com/llvm/llvm-project/pull/89978.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2f21c50c63473b..240fedf4ccb651 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -438,7 +438,9 @@ struct OneDimMultiReductionToTwoDim
auto srcVectorType = multiReductionOp.getSourceVectorType();
auto srcShape = srcVectorType.getShape();
auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
- srcVectorType.getElementType());
+ srcVectorType.getElementType(),
+ ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
+
auto accType =
VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
@@ -455,10 +457,11 @@ struct OneDimMultiReductionToTwoDim
loc, accType, multiReductionOp.getAcc());
Value castMask;
if (maskableOp.isMasked()) {
- auto maskType = llvm::cast<ShapedType>(mask.getType());
+ auto maskType = llvm::cast<VectorType>(mask.getType());
auto castMaskType =
VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
- maskType.getElementType());
+ maskType.getElementType(),
+ ArrayRef<bool>{false, maskType.getScalableDims().back()});
castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 22808aa7d6acc3..a5fe83118a0cf2 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -253,6 +253,7 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
%0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
return %0 : vector<8x[4]xf32>
}
+
// CHECK-LABEL: func.func private @scalable_dims(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
@@ -281,6 +282,23 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
// CHECK: %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
// CHECK: return %[[VAL_163]] : vector<8x[4]xf32>
+// Check that OneDimMultiReductionToTwoDim handles scalable dim
+func.func private @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
+ %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+ return %0 : f32
+}
+
+// CHECK-LABEL: func.func private @scalable_dim_1d(
+// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
+// CHECK-SAME: %[[ARG_1:.*]]: f32,
+// CHECK-SAME: %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
+// CHECK-DAG: %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK: %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK: %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][%[[VAL_0]] : index] : vector<1xf32>
+// CHECK: %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
+// CHECK: return %[[VAL_4]] : f32
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
To correctly lower multi_reduction of 1-dim scalable vector, e.g.: <[4]xf32>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
couple of minor nits but otherwise LGTM, cheers!
To correctly lower multi_reduction of 1-dim scalable vector, e.g.: <[4]xf32>
To correctly lower multi_reduction of 1-dim scalable vector, e.g.: <[4]xf32>