[MLIR][Vector] Add support for inner-parallel masked multi-reductions #126722


Merged: 1 commit merged into llvm:main on Feb 14, 2025

Conversation

@manupak (Contributor) commented on Feb 11, 2025

This commit adds support to lower the inner-parallel flavor of masked vector multi-reductions.
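
For reference, the masked multi-reduction form targeted here looks like the test case added in this PR (shown again below); the inner-parallel lowering path first transposes the source so the parallel dimension ends up innermost, and the new code threads the equally transposed mask through the element-wise reduction:

// Masked multi-reduction over a 2-D vector (taken from the test added in
// this PR); the reduction runs over dimension [1] and the mask has the same
// shape as the source.
func.func @vector_masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>, %mask: vector<2x4xi1>) -> vector<2xf32> {
  %0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> } : vector<2x4xi1> -> vector<2xf32>
  return %0 : vector<2xf32>
}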

@llvmbot (Member) commented on Feb 11, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Manupa Karunaratne (manupak)

Changes

This commit adds support to lower the inner-parallel flavor of masked vector multi-reductions.


Full diff: https://github.com/llvm/llvm-project/pull/126722.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp (+16-5)
  • (modified) mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir (+20)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 0cafc9cd3551746..63897c8ed36e5fd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -310,9 +310,16 @@ struct TwoDimMultiReductionToElementWise
                                 PatternRewriter &rewriter) const override {
     auto maskableOp =
         cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
-    if (maskableOp.isMasked())
-      // TODO: Support masking.
-      return failure();
+
+    Operation *rootOp;
+    Value mask = nullptr;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = multiReductionOp;
+    }
 
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-2 ["parallel", "reduce"] or bail.
@@ -334,11 +341,15 @@ struct TwoDimMultiReductionToElementWise
     for (int64_t i = 0; i < srcShape[0]; i++) {
       auto operand = rewriter.create<vector::ExtractOp>(
           loc, multiReductionOp.getSource(), i);
+      Value extractMask = nullptr;
+      if (mask) {
+        extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
+      }
       result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
-                                  operand, result);
+                                  operand, result, nullptr, extractMask);
     }
 
-    rewriter.replaceOp(multiReductionOp, result);
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index 68621ffaac3d20d..e7b8697f554a1f6 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -41,3 +41,23 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
 //        ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
 // INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
 //  INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
+
+// -----
+
+func.func @vector_masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>, %mask: vector<2x4xi1>) -> vector<2xf32> {
+    %0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> } : vector<2x4xi1> -> vector<2xf32>
+    return %0 : vector<2xf32>
+}
+
+//       ALL-LABEL: func @vector_masked_multi_reduction
+//        ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<2x4xi1>
+// INNER-REDUCTION: %[[INNERVEC:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[INNERACC:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[INNERMASK:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: vector.mask %[[INNERMASK]] { vector.reduction <mul>, %[[INNERVEC]], %[[INNERACC]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+//  INNER-PARALLEL: %[[TPMASK:.+]] = vector.transpose %[[MASK]], [1, 0] : vector<2x4xi1> to vector<4x2xi1>
+//  INNER-PARALLEL: %[[TPINPUT:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+//  INNER-PARALLEL: %[[INNERVEC:.+]] = vector.extract %[[TPINPUT]][0] : vector<2xf32> from vector<4x2xf32>
+//  INNER-PARALLEL: %[[INNERMASK:.+]] = vector.extract %[[TPMASK]][0] : vector<2xi1> from vector<4x2xi1>
+//  INNER-PARALLEL: %[[REDUCED:.+]] = arith.mulf %[[INNERVEC]], %[[ACC]] : vector<2xf32>
+//  INNER-PARALLEL: arith.select %[[INNERMASK]], %[[REDUCED]], %[[ACC]] : vector<2xi1>, vector<2xf32>
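
For readers less used to FileCheck output, the INNER-PARALLEL expectations above correspond to lowered IR roughly like the following sketch (reconstructed from the check lines; only the first of the four reduction steps is shown, the remaining rows thread the accumulator through in the same way):

%tpmask = vector.transpose %mask, [1, 0] : vector<2x4xi1> to vector<4x2xi1>
%tpin   = vector.transpose %arg0, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
// Row 0 of the reduction dimension: combine element-wise, then keep the old
// accumulator value in masked-off lanes.
%row0  = vector.extract %tpin[0] : vector<2xf32> from vector<4x2xf32>
%mask0 = vector.extract %tpmask[0] : vector<2xi1> from vector<4x2xi1>
%mul0  = arith.mulf %row0, %acc : vector<2xf32>
%acc0  = arith.select %mask0, %mul0, %acc : vector<2xi1>, vector<2xf32>
// ...rows 1..3 repeat the extract/mulf/select sequence, accumulating into
// %acc1, %acc2, %acc3...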

@Groverkss (Member) left a comment:

LGTM, Thanks!


// -----

func.func @vector_masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>, %mask: vector<2x4xi1>) -> vector<2xf32> {
Reviewer (Member): nit: can you rename this to vector_multi_reduction_masked

@manupak (Contributor, Author): done

Operation *rootOp;
Value mask = nullptr;
if (maskableOp.isMasked()) {
rewriter.setInsertionPoint(maskableOp.getMaskingOp());
Reviewer (Member): Can we move this lower, to the point where we start creating operations, and just add rewriter.setInsertionPoint(rootOp)?

Reviewer (Contributor): We should also add an insertion guard to restore the previous insertion point.

@manupak (Contributor, Author): Moved it a bit lower, but I'd like to obtain the mask inside the same conditional, so it can't be at the point where we start creating operations -- this is consistent with all the other masked implementations in this file.

Thanks for catching the insertion guard; fixed now.
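
For context, OpBuilder::InsertionGuard is MLIR's RAII helper that records the current insertion point and restores it when the guard goes out of scope. A minimal sketch of how it slots into this pattern, based on the diff above rather than the exact merged code:

LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                              PatternRewriter &rewriter) const override {
  auto maskableOp =
      cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());

  // RAII guard: whatever insertion point is set below is undone automatically
  // when this function returns, so the rewriter is left in its original state.
  OpBuilder::InsertionGuard guard(rewriter);

  Operation *rootOp = multiReductionOp;
  Value mask = nullptr;
  if (maskableOp.isMasked()) {
    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
    rootOp = maskableOp.getMaskingOp();
    mask = maskableOp.getMaskingOp().getMask();
  }

  // ...rest of the lowering as in the diff above, replacing rootOp at the end...
  return success();
}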

@dcaballe (Contributor) left a comment:

Thanks!


      result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
-                                 operand, result);
+                                 operand, result, nullptr, extractMask);
Reviewer (Contributor): -> /* nameOfTheFuncParam=nullptr */ for readability

@manupak (Contributor, Author): done
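
The suggestion refers to the LLVM/MLIR convention of annotating otherwise opaque arguments with the parameter name at the call site. Assuming the defaulted argument in question is makeArithReduction's fast-math flags attribute (the parameter name here is an assumption, not copied from the merged change), the call would then read along these lines:

// Annotating the defaulted argument makes the call self-documenting:
// readers can tell which parameter the nullptr corresponds to without
// opening the header.
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
                            operand, result, /*fastmath=*/nullptr, extractMask);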

@manupak force-pushed the masked-vector-multi-reduce branch from ee5193c to 6c94d5e on February 12, 2025, at 10:49
github-actions bot commented on Feb 12, 2025:

✅ With the latest revision this PR passed the C/C++ code formatter.

@manupak force-pushed the masked-vector-multi-reduce branch from 6c94d5e to bea5b92 on February 12, 2025, at 10:58
Commit message: This commit adds support to lower the inner-parallel flavor of masked vector multi-reductions.
@manupak (Contributor, Author) commented on Feb 14, 2025:

If there are no more comments, I'd appreciate someone hitting the button :).

@hanhanW merged commit db1e15a into llvm:main on Feb 14, 2025. 8 checks passed.
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request on Feb 14, 2025:

…llvm#126722)

This commit adds support to lower the inner-parallel flavor of masked vector multi-reductions.

sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request on Feb 24, 2025:

…llvm#126722)

This commit adds support to lower the inner-parallel flavor of masked vector multi-reductions.