-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Vector] Add support for inner-parallel masked multi-reductions #126722
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Manupa Karunaratne (manupak) ChangesThis commit adds suppot to lower inner-parallel flavor of masked vector multi-reductions. Full diff: https://github.com/llvm/llvm-project/pull/126722.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 0cafc9cd3551746..63897c8ed36e5fd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -310,9 +310,16 @@ struct TwoDimMultiReductionToElementWise
PatternRewriter &rewriter) const override {
auto maskableOp =
cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
- if (maskableOp.isMasked())
- // TODO: Support masking.
- return failure();
+
+ Operation *rootOp;
+ Value mask = nullptr;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = multiReductionOp;
+ }
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Rank-2 ["parallel", "reduce"] or bail.
@@ -334,11 +341,15 @@ struct TwoDimMultiReductionToElementWise
for (int64_t i = 0; i < srcShape[0]; i++) {
auto operand = rewriter.create<vector::ExtractOp>(
loc, multiReductionOp.getSource(), i);
+ Value extractMask = nullptr;
+ if (mask) {
+ extractMask = rewriter.create<vector::ExtractOp>(loc, mask, i);
+ }
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(),
- operand, result);
+ operand, result, nullptr, extractMask);
}
- rewriter.replaceOp(multiReductionOp, result);
+ rewriter.replaceOp(rootOp, result);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
index 68621ffaac3d20d..e7b8697f554a1f6 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-pass-lowering.mlir
@@ -41,3 +41,23 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc
// ALL-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
// INNER-REDUCTION: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
// INNER-PARALLEL: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
+
+// -----
+
+func.func @vector_masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>, %mask: vector<2x4xi1>) -> vector<2xf32> {
+ %0 = vector.mask %mask { vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> } : vector<2x4xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// ALL-LABEL: func @vector_masked_multi_reduction
+// ALL-SAME: %[[INPUT:.+]]: vector<2x4xf32>, %[[ACC:.+]]: vector<2xf32>, %[[MASK:.+]]: vector<2x4xi1>
+// INNER-REDUCTION: %[[INNERVEC:.+]] = vector.extract %[[INPUT]][0] : vector<4xf32> from vector<2x4xf32>
+// INNER-REDUCTION: %[[INNERACC:.+]] = vector.extract %[[ACC]][0] : f32 from vector<2xf32>
+// INNER-REDUCTION: %[[INNERMASK:.+]] = vector.extract %[[MASK]][0] : vector<4xi1> from vector<2x4xi1>
+// INNER-REDUCTION: vector.mask %[[INNERMASK]] { vector.reduction <mul>, %[[INNERVEC]], %[[INNERACC]] : vector<4xf32> into f32 } : vector<4xi1> -> f32
+// INNER-PARALLEL: %[[TPMASK:.+]] = vector.transpose %[[MASK]], [1, 0] : vector<2x4xi1> to vector<4x2xi1>
+// INNER-PARALLEL: %[[TPINPUT:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32>
+// INNER-PARALLEL: %[[INNERVEC:.+]] = vector.extract %[[TPINPUT]][0] : vector<2xf32> from vector<4x2xf32>
+// INNER-PARALLEL: %[[INNERMASK:.+]] = vector.extract %[[TPMASK]][0] : vector<2xi1> from vector<4x2xi1>
+// INNER-PARALLEL: %[[REDUCED:.+]] = arith.mulf %[[INNERVEC]], %[[ACC]] : vector<2xf32>
+// INNER-PARALLEL: arith.select %[[INNERMASK]], %[[REDUCED]], %[[ACC]] : vector<2xi1>, vector<2xf32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, Thanks!
|
||
// ----- | ||
|
||
func.func @vector_masked_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>, %mask: vector<2x4xi1>) -> vector<2xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: can you rename this to vector_multi_reduction_masked
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Operation *rootOp; | ||
Value mask = nullptr; | ||
if (maskableOp.isMasked()) { | ||
rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move this lower to when we start creating operations and just add rewriter.setInsertionPoint(rootOp) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should also add an insertion guard to restore the prev insertion point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved it a bit lower but I d like to obtain the mask inside the same conditional.. hence it cant be at where we start creating operations -- consistent with all other masked implementation in this file.
thanks for catching insertion guard. fixed now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Operation *rootOp; | ||
Value mask = nullptr; | ||
if (maskableOp.isMasked()) { | ||
rewriter.setInsertionPoint(maskableOp.getMaskingOp()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should also add an insertion guard to restore the prev insertion point.
result = makeArithReduction(rewriter, loc, multiReductionOp.getKind(), | ||
operand, result); | ||
operand, result, nullptr, extractMask); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-> /* nameOfTheFuncParam=nullptr */ for readability
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
ee5193c
to
6c94d5e
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
6c94d5e
to
bea5b92
Compare
This commit adds suppot to lower inner-parallel flavor of masked vector multi-reductions.
If there are not anymore comments, I d appreciate someone hitting the button :). |
…llvm#126722) This commit adds support to lower inner-parallel flavor of masked vector multi-reductions.
…llvm#126722) This commit adds support to lower inner-parallel flavor of masked vector multi-reductions.
This commit adds support to lower inner-parallel flavor of masked vector multi-reductions.