-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Add support for masks in castAwayContractionLeadingOneDim #81906
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesPartial fix for #78787 Full diff: https://github.com/llvm/llvm-project/pull/81906.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 74382b027c2f48..f7f2b934056185 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -332,9 +332,12 @@ struct CastAwayTransferWriteLeadingOneDim
LogicalResult
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
RewriterBase &rewriter) {
- // TODO(#78787): Not supported masked op yet.
- if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
- return failure();
+ // Specifically for masked Ops for which we need to update the insertion
+ // point
+ PatternRewriter::InsertionGuard guard(rewriter);
+
+ auto isMasked =
+ cast<MaskableOpInterface>(contractOp.getOperation()).isMasked();
VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
if (oldAccType == nullptr)
return failure();
@@ -346,6 +349,12 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
// greedily to drop more.
int64_t dropDim = 1;
+ if (isMasked) {
+ // Update the insertion point to avoid adding more ops to the vector.mask
+ // region corresponding to `mask`
+ rewriter.setInsertionPointAfter(contractOp->getParentOp());
+ }
+
auto oldIndexingMaps = contractOp.getIndexingMapsArray();
SmallVector<AffineMap> newIndexingMaps;
@@ -368,6 +377,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
SmallVector<Value> operands = {contractOp.getLhs(), contractOp.getRhs(),
contractOp.getAcc()};
SmallVector<Value> newOperands;
+ auto loc = contractOp.getLoc();
for (const auto &it : llvm::enumerate(oldIndexingMaps)) {
// Check if the dim to be dropped exists as a leading dim in the operand
@@ -405,7 +415,7 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
operands[it.index()] = rewriter.create<vector::TransposeOp>(
- contractOp.getLoc(), operands[it.index()], perm);
+ loc, operands[it.index()], perm);
}
}
// We have taken care to have the dim to be dropped be
@@ -429,17 +439,30 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
// Extract if its a valid extraction, otherwise use the operand
// without extraction.
newOperands.push_back(
- validExtract ? rewriter.create<vector::ExtractOp>(contractOp.getLoc(),
- operands[it.index()],
- splatZero(dropDim))
+ validExtract ? rewriter.create<vector::ExtractOp>(
+ loc, operands[it.index()], splatZero(dropDim))
: operands[it.index()]);
}
- auto newContractOp = rewriter.create<vector::ContractionOp>(
- contractOp.getLoc(), newOperands[0], newOperands[1], newOperands[2],
+ Operation *newContractOp = rewriter.create<vector::ContractionOp>(
+ loc, newOperands[0], newOperands[1], newOperands[2],
rewriter.getAffineMapArrayAttr(newIndexingMaps),
rewriter.getArrayAttr(newIteratorTypes), contractOp.getKind());
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- contractOp, contractOp->getResultTypes()[0], newContractOp);
+
+ if (isMasked) {
+ auto mask = contractOp.getMaskingOp();
+ auto newMask = rewriter.create<vector::ExtractOp>(loc, mask.getMask(),
+ splatZero(dropDim));
+
+ newContractOp =
+ mlir::vector::maskOperation(rewriter, newContractOp, newMask);
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ mask, contractOp->getResultTypes()[0], newContractOp->getResults()[0]);
+ } else {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ contractOp, contractOp->getResultTypes()[0],
+ newContractOp->getResults()[0]);
+ }
+
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
index af6e636245b04e..4ba51c5953d13c 100644
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -30,6 +30,80 @@ func.func @cast_away_contraction_leading_one_dims(%arg0: vector<1x16x8xf32>, %ar
}
// -----
+// CHECK: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK: #[[$MAP_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_const_mask
+// CHECK: %[[MASK:.*]] = vector.constant_mask [15, 15, 8] : vector<16x16x8xi1>
+// CHECK: %[[R0:.*]] = vector.extract %{{.*}}[0] : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK: %[[R1:.*]] = vector.extract %{{.*}}[0] : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK: %[[R2:.*]] = vector.extract %{{.*}}[0] : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK: %[[CONTRACT:.*]] = vector.mask %[[MASK]] {
+// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]], #[[$MAP_2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK: return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+ affine_map<(l, i, j, k) -> (l, i, k)>,
+ affine_map<(l, i, j, k) -> (l, k, j)>,
+ affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+ indexing_maps = #contraction_accesses0,
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_const_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
+ %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
+ %0 = vector.mask %mask {
+ vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+ } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+ return %0 : vector<1x16x16xf32>
+}
+
+// -----
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2, d1)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func.func @cast_away_contraction_leading_one_dim_under_mask
+// CHECK: %[[R0:.*]] = vector.extract %{{.*}} : vector<16x8xf32> from vector<1x16x8xf32>
+// CHECK: %[[R1:.*]] = vector.extract %{{.*}} : vector<8x16xf32> from vector<1x8x16xf32>
+// CHECK: %[[R2:.*]] = vector.extract %{{.*}} : vector<16x16xf32> from vector<1x16x16xf32>
+// CHECK: %[[M:.*]] = vector.extract %{{.*}} : vector<16x16x8xi1> from vector<1x16x16x8xi1>
+// CHECK: %[[CONTRACT:.*]] = vector.mask %[[M]] {
+// CHECK-SAME: vector.contract {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+// CHECK-SAME: %[[R0]], %[[R1]], %[[R2]] : vector<16x8xf32>, vector<8x16xf32> into vector<16x16xf32>
+// CHECK-SAME: } : vector<16x16x8xi1> -> vector<16x16xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[CONTRACT]] : vector<16x16xf32> to vector<1x16x16xf32>
+// CHECK-NEXT: return %[[RES]] : vector<1x16x16xf32>
+
+#contraction_accesses0 = [
+ affine_map<(l, i, j, k) -> (l, i, k)>,
+ affine_map<(l, i, j, k) -> (l, k, j)>,
+ affine_map<(l, i, j, k) -> (l, i, j)>
+]
+#contraction_trait0 = {
+ indexing_maps = #contraction_accesses0,
+ iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+func.func @cast_away_contraction_leading_one_dim_under_mask(
+ %arg0: vector<1x16x8xf32>,
+ %arg1: vector<1x8x16xf32>,
+ %arg2: vector<1x16x16xf32>,
+ %mask: vector<1x16x16x8xi1>) -> vector<1x16x16xf32> {
+ %0 = vector.mask %mask {
+ vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
+ } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
+ return %0: vector<1x16x16xf32>
+}
+
+// -----
+
// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1) -> (d1, d0)>
// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0)>
@@ -164,36 +238,6 @@ func.func @cast_away_contraction_leading_one_dims_nonleadingunitdim_rank4_acctra
return %0: vector<1x1x2x16xf32>
}
-// -----
-
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
-
-// CHECK-LABEL: not_insert_cast_for_contraction_under_mask
-// CHECK: %[[MASK:.+]] = vector.constant_mask
-// CHECK: %[[CASTED_MASK:.+]] = vector.broadcast %[[MASK]]
-// CHECK: %[[RET:.+]] = vector.mask %[[CASTED_MASK]] {
-// CHECK-SAME: vector.contract {{.*}} : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32> }
-// CHECK: return %[[RET]] : vector<1x16x16xf32>
-
-#contraction_accesses0 = [
- affine_map<(l, i, j, k) -> (l, i, k)>,
- affine_map<(l, i, j, k) -> (l, k, j)>,
- affine_map<(l, i, j, k) -> (l, i, j)>
-]
-#contraction_trait0 = {
- indexing_maps = #contraction_accesses0,
- iterator_types = ["parallel", "parallel", "parallel", "reduction"]
-}
-
-func.func @not_insert_cast_for_contraction_under_mask(%arg0: vector<1x16x8xf32>, %arg1: vector<1x8x16xf32>, %arg2: vector<1x16x16xf32>) -> vector<1x16x16xf32> {
- %mask = vector.constant_mask [1, 15, 15, 8] : vector<1x16x16x8xi1>
- %0 = vector.mask %mask {
- vector.contract #contraction_trait0 %arg0, %arg1, %arg2 : vector<1x16x8xf32>, vector<1x8x16xf32> into vector<1x16x16xf32>
- } : vector<1x16x16x8xi1> -> vector<1x16x16xf32>
- return %0 : vector<1x16x16xf32>
-}
// -----
// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
|
@@ -346,6 +349,12 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, | |||
// greedily to drop more. | |||
int64_t dropDim = 1; | |||
|
|||
if (isMasked) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some snippets under the comment Vector mask setup
or similar that doing something like this. I think we should perhaps create a utility for those and define some kind of canonical form to do this at pattern rewrite level. For conversions, the goal when we introduced the mask op was to try to reduce this kind of conditional code and provide some infra to make it as transparent as possible, hence the mask conversion pattern class that we use for LLVM conversion. Let's brainstorm a bit about the different options and see what can be improved in this regard.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dcaballe I was implementing a general purpose implementation of this, something like tanmaysachan@242f901
Does this look okay to you? Can refactor for common code between the masked vs non-masked rewrite classes to be a function, and the pattern rewrites handle the entry points.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, returning to this after two weeks of being OOO.
@tanmaysachan , from a quick scan, you have adopted VectorMaskOpConversionBase which matches vector.mask
. IIUC, that's going to be insufficient in this case - we need something that would work both for:
vector.mask {vector.contract}
, andvector.contract
.
VectorMaskOpConversionBase
would only work for the first case. Perhaps @dcaballe had something else in mind, but IMHO we need another "base" class to accomodate for that. This a blocker for me, so I'm proposing one here (*):
I think we should perhaps create a utility for those and define some kind of canonical form to do this at pattern rewrite level.
Does #83827 make sense? Happy to try something else :) Also, regardless of the long-term approach that we take here, would you be OK with me landing this to unblock linalg.mmt4d
investigation?
(*) Apologies @tanmaysachan if you are also actively working on this, but this is quite urgent for me. Mindful of delays due to time difference, I decided to go ahead and draft something quickly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No worries, I was unsure about what the go ahead way for this was.
b85869a
to
8d1a121
Compare
Rebased on top of #83827 |
8d1a121
to
1e9c7a1
Compare
Stopped while trying to get the added example to compile using masks. * Depends on llvm#81906, which solves one issue. * Not sure what happens after that. * Compare against non-masked version and "reverse engineer" a solution/pipeline.
8d650b3
to
f6f3665
Compare
@dcaballe Done :) I have rewritten this to build on top of #83827, wdyt? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
f6f3665
to
01d87b5
Compare
Updates `castAwayContractionLeadingOneDim` to inherit from `MaskableOpRewritePattern` so that this pattern can support masking. Builds on top of llvm#83827
01d87b5
to
734fbd3
Compare
…Dim (llvm#81906) Updates `castAwayContractionLeadingOneDim` to inherit from `MaskableOpRewritePattern` so that this pattern can support masking. Builds on top of llvm#83827
Updates
castAwayContractionLeadingOneDim
to inherit fromMaskableOpRewritePattern
so that this pattern can support masking.Builds on top of #83827