Skip to content

[MLIR][Vector] Allow Scalable Dim in OneDimMultiReductionToTwoDim #89978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 25, 2024

Conversation

zhaoshiz
Copy link
Contributor

To correctly lower multi_reduction of 1-dim scalable vector, e.g.: <[4]xf32>

zhaoshiz and others added 2 commits April 24, 2024 12:04
To correctly lower multi_reduction of 1-dim scalable vector, e.g.:
<[4]xf32>

(cherry picked from commit 35f95c21e7895181084cc3a14c4e70eb8d1e6eee)
@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Zhaoshi Zheng (zhaoshiz)

Changes

To correctly lower multi_reduction of 1-dim scalable vector, e.g.: <[4]xf32>


Full diff: https://github.com/llvm/llvm-project/pull/89978.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+6-3)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir (+18)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2f21c50c63473b..240fedf4ccb651 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -438,7 +438,9 @@ struct OneDimMultiReductionToTwoDim
     auto srcVectorType = multiReductionOp.getSourceVectorType();
     auto srcShape = srcVectorType.getShape();
     auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
-                                      srcVectorType.getElementType());
+                                      srcVectorType.getElementType(),
+        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
+
     auto accType =
         VectorType::get(ArrayRef<int64_t>{1}, srcVectorType.getElementType());
     assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
@@ -455,10 +457,11 @@ struct OneDimMultiReductionToTwoDim
         loc, accType, multiReductionOp.getAcc());
     Value castMask;
     if (maskableOp.isMasked()) {
-      auto maskType = llvm::cast<ShapedType>(mask.getType());
+      auto maskType = llvm::cast<VectorType>(mask.getType());
       auto castMaskType =
           VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
-                          maskType.getElementType());
+                          maskType.getElementType(),
+          ArrayRef<bool>{false, maskType.getScalableDims().back()});
       castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
     }
 
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 22808aa7d6acc3..a5fe83118a0cf2 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -253,6 +253,7 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
   %0 = vector.multi_reduction <add>, %A, %B [2] : vector<8x[4]x2xf32> to vector<8x[4]xf32>
   return %0 : vector<8x[4]xf32>
 }
+
 // CHECK-LABEL:   func.func private @scalable_dims(
 // CHECK-SAME:                                     %[[VAL_0:.*]]: vector<8x[4]x2xf32>,
 // CHECK-SAME:                                     %[[VAL_1:.*]]: vector<8x[4]xf32>) -> vector<8x[4]xf32> {
@@ -281,6 +282,23 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
 // CHECK:           %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
 // CHECK:           return %[[VAL_163]] : vector<8x[4]xf32>
 
+// Check that OneDimMultiReductionToTwoDim handles scalable dim
+func.func private @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
+    %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+    return %0 : f32
+}
+
+// CHECK-LABEL:  func.func private @scalable_dim_1d(
+// CHECK-SAME:                                      %[[ARG_0:.*]]: vector<[4]xf32>,
+// CHECK-SAME:                                      %[[ARG_1:.*]]: f32,
+// CHECK-SAME:                                      %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
+// CHECK-DAG:      %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:      %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:          %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:          %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][%[[VAL_0]] : index] : vector<1xf32>
+// CHECK:          %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
+// CHECK:          return %[[VAL_4]] : f32
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

Copy link

github-actions bot commented Apr 24, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@zhaoshiz zhaoshiz requested a review from banach-space April 24, 2024 19:33
To correctly lower multi_reduction of 1-dim scalable vector, e.g.:
<[4]xf32>
@zhaoshiz zhaoshiz requested a review from dcaballe April 25, 2024 04:42
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

couple of minor nits but otherwise LGTM, cheers!

To correctly lower multi_reduction of 1-dim scalable vector, e.g.:
<[4]xf32>
@zhaoshiz zhaoshiz merged commit dbcc454 into llvm:main Apr 25, 2024
@zhaoshiz zhaoshiz deleted the mlir-scalable-vector-reduce branch April 25, 2024 20:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants