[MLIR][Linalg] Fix insert_slice fusion with rank reduction #130961
Conversation
Insert_slice fusion with a linalg producer does not account for possible rank-reduction in the insert_slice return type. When that happens, a tensor.cast gets generated due to the type mismatch, which is invalid for tensors of different ranks. This later trips other passes.
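For illustration, a minimal hypothetical example (not taken from the PR) of the pattern in question - a rank-reducing tensor.extract_slice whose result drops unit dimensions, so the fused producer's result has a higher rank than the consumer expects, and tensor.cast cannot bridge a rank change:

func.func @rank_reducing_slice(%t: tensor<6x6x1x1xf32>, %i: index) -> tensor<6x2xf32> {
  // Rank-reducing slice: the two unit dims are dropped (rank 4 -> rank 2).
  %s = tensor.extract_slice %t[0, %i, 0, 0] [6, 2, 1, 1] [1, 1, 1, 1]
      : tensor<6x6x1x1xf32> to tensor<6x2xf32>
  return %s : tensor<6x2xf32>
}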
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: Thomas Preud'homme (RoboTux)

Changes: Insert_slice fusion with a linalg producer does not account for possible rank-reduction in the insert_slice return type. When that happens, a tensor.cast gets generated due to the type mismatch, which is invalid for tensors of different ranks. This later trips other passes.

Full diff: https://github.com/llvm/llvm-project/pull/130961.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 223d728b0b27d..81b204df5a0aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -18,6 +18,7 @@
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallBitVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
@@ -235,6 +237,31 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
}
+/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
+/// `from`.
+tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
+ const llvm::SmallBitVector &dropDims) {
+ auto fromType = cast<ShapedType>(from.getType());
+ assert(fromType.getRank() == dropDims.size());
+ SmallVector<ReassociationIndices, 2> reassocIdxsVec;
+ ReassociationIndices reassocIdxs;
+
+ bool foundKeptDim = false;
+ for (int dim = 0; dim < fromType.getRank(); dim++) {
+ if (!dropDims.test(dim)) {
+ if (foundKeptDim) {
+ reassocIdxsVec.push_back(reassocIdxs);
+ reassocIdxs.clear();
+ }
+ foundKeptDim = true;
+ }
+ reassocIdxs.push_back(dim);
+ }
+ if (!reassocIdxs.empty())
+ reassocIdxsVec.push_back(reassocIdxs);
+ return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
+}
+
FailureOr<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
OpOperand &consumerOpOperand) {
@@ -255,6 +282,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
<< "\nNot fusable, not an extract_slice op: " << inputTensor);
return failure();
}
+ llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
// If producer is already in the same block as consumer, we are done.
if (consumerOpOperand.get().getParentBlock() ==
@@ -272,12 +300,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
consumerOpOperand);
// Replace use.
+ Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+ Type consumerType = consumerOpOperand.get().getType();
+ // Rank-reduction occured as part of the extract_slice.
+ if (cast<ShapedType>(consumerType).getRank() !=
+ cast<ShapedType>(def.getType()).getRank())
+ def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
// Canonicalizations are not guaranteed to have happened before constructing
// `fusedProducer`. In the tensor case this can result in temporary type
// mismatches. Insert a `tensor.cast` op to propagate the transformation
// invariant that types are compatible.
- Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
- Type consumerType = consumerOpOperand.get().getType();
if (consumerType != def.getType())
def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
consumerOpOperand.set(def);
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 0f27a92c119cf..b4fbdfacde899 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -318,3 +318,66 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
}
return %for0 : tensor<64x128xf32>
}
+
+// -----
+
+func.func @rank_reduced_extract_slice(%arg0: tensor<6x6x1x1x1x1xf32>, %arg1: tensor<6x6x1x1xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c6 = arith.constant 6 : index
+ %cst = arith.constant 0.0 : f32
+ %init1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+ %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+ %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d6, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<6x6x1x1x1x1xf32>, tensor<6x6x1x1xf32>) outs(%fill1 : tensor<6x6x1x1x1x1xf32>) {
+ ^bb0(%in: f32, %in_1: f32, %out: f32):
+ %10 = arith.mulf %in, %in_1 : f32
+ %11 = arith.addf %out, %10 : f32
+ linalg.yield %11 : f32
+ } -> tensor<6x6x1x1x1x1xf32>
+ %init2 = tensor.empty() : tensor<4x6xf32>
+ %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+ %2 = tensor.extract_slice %0[0, %arg4, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+ %init3 = tensor.empty() : tensor<4x2xf32>
+ %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+ %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+ ^bb0(%in: f32, %in_1: f32, %out: f32):
+ %20 = arith.mulf %in, %in_1 : f32
+ %21 = arith.addf %out, %20 : f32
+ linalg.yield %21 : f32
+ } -> tensor<4x2xf32>
+ %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+ scf.yield %4 : tensor<4x6xf32>
+ }
+ return %1 : tensor<4x6xf32>
+}
+
+// CHECK: func @rank_reduced_extract_slice(
+// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<6x6x1x1x1x1xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<6x6x1x1xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[FILL_PROD:.*]] = linalg.fill ins({{%.*}} : f32)
+// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK: %[[FILL_CONS:.*]] = linalg.fill ins({{%.*}} : f32)
+// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<4x2xf32>) -> tensor<4x2xf32>
+// CHECK: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+// CHECK-DAG: %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]], 0, 0] [6, 2, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
+// CHECK-DAG: %[[FILL_PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+
+// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : tensor<6x2x1x1x1x1xf32>, tensor<6x2x1x1xf32>)
+// CHECK-SAME: outs(%[[FILL_PROD_SLICE]] : tensor<6x2x1x1x1x1xf32>)
+// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
+// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK: return %[[FOR]] : tensor<4x6xf32>
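For context, the checks above are driven by the file's existing RUN line (not part of this hunk), which is roughly of the form:

// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s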
Ping?
Looks good overall, I've made some minor suggestions. I am surprised we didn't hit this before.
Thanks for the contribution Thomas and apologies for the delay reviewing 🙏🏻
Thanks for the updates, Thomas!
I've left a few smallish comments, but before addressing those, it would be good to make sure that we do indeed need tensor.collapse_shape (as opposed to a rank-reducing tensor.extract_slice). I am sure that you thought it through and there's a good reason to insert collapse_shape, but it's not obvious to me just yet 😅
@@ -272,12 +308,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                      consumerOpOperand);

  // Replace use.
  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
  Type consumerType = consumerOpOperand.get().getType();
  // Rank-reduction occured as part of the extract_slice.
I have just realised ... Why not re-use a rank-reducing tensor.extract_slice instead? It's a bit counter-intuitive that a rank-reducing tensor.extract_slice is replaced with a pair of tensor.extract_slice + tensor.collapse_shape.
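For reference, the two forms under discussion side by side (a hypothetical sketch; the types are made up and not taken from the patch):

func.func @two_forms(%t: tensor<6x6x1x1xf32>) -> (tensor<6x2xf32>, tensor<6x2xf32>) {
  // (a) A single rank-reducing extract_slice drops the unit dims directly.
  %a = tensor.extract_slice %t[0, 0, 0, 0] [6, 2, 1, 1] [1, 1, 1, 1]
      : tensor<6x6x1x1xf32> to tensor<6x2xf32>
  // (b) A non-rank-reducing extract_slice followed by a collapse_shape.
  %b0 = tensor.extract_slice %t[0, 0, 0, 0] [6, 2, 1, 1] [1, 1, 1, 1]
      : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
  %b = tensor.collapse_shape %b0 [[0], [1, 2, 3]]
      : tensor<6x2x1x1xf32> into tensor<6x2xf32>
  return %a, %b : tensor<6x2xf32>, tensor<6x2xf32>
}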
The extract_slice is done on the producer's arguments while the collapse_shape is done on its result. I did look into doing the rank-reduction on the arguments, since the result type is derived from the DPS init, but the pass applies to both named and generic linalg ops, which makes this tricky. In the case of a generic, one of the inputs might be using the unit dimension in a non-dim affine expression, which would require updating the affine maps; but since the pass works on the LinalgOp interface there is no way to change the affine map (named ops have an implicit map). Of course one could add an if, but the code becomes quite complex.
My feeling is that this is work for another pass. With this one, the producing linalg is moved inside the loop, alongside the consuming generic. Then it's a matter of folding the collapse_shape / extract_slice, which sounds like a simple linalg fusion that probably already exists.
Happy for me to close this?
Thanks for the review @banach-space, I hope the new revision addresses all your comments in a satisfactory way.
That's a very good question and I'll need to get back to you before I can answer fully. To be clear, the fusion does not insert a new extract_slice; this is done by another transform. I don't even think it was there in the earlier example, and I don't know why there's a new linalg as well, so I will look into understanding that. Normally there is an extract_slice that might happen to collapse some unit dimensions, which then causes a type mismatch. This patch ensures we do the proper collapse_shape after the extract_slice is fused, to reproduce that unit-dimension collapsing and make the types match. By the way, the if is because the pass requires the producing linalg and the consuming linalg (the one that consumes the extract_slice) to be in different blocks; the bug I was facing involved a loop, but I used an if here to make the code simpler.
My bad, the pass does indeed create an extract_slice (or a subview if working on memrefs). The code that does that is in Linalg Utils and doesn't deal with dropped unit dimensions. I'll rework the patch along those lines instead.
Sorry for the long pause. I've restarted working on it and should have a patch soon. |
…sert_slice_fusion
- rename collapseTo to better reflect its usage
- assert it only collapses unit dimensions
- rename ReassociationIndices-using variables to reassocGroup and reassocMaps, the same terminology used in the tensor.collapse_shape documentation
- use a more representative test with comments to better explain what the patch does
This is answered in my other comment: #130961 (comment), but I'll rephrase it here in case that comment wasn't clear enough. First of all, the fusion is worth doing in the case of a rank-reducing extract_slice even if a collapse_shape is involved, because the producer is moved inside the loop and operates directly on rank-reduced buffers. The only case where it wouldn't be worth it is if the extract_slice was only used to do rank-reduction, but that would be strange IR in my opinion; a collapse_shape should have been used instead in that case. Second, the collapse_shape is used because it would add a lot of complexity to fold the rank-reduction into the extract_slice operations of the producer's operands. This is because the producing linalg could be something like:
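A minimal hypothetical sketch of such a producer (types and names are made up for illustration): dimension #0 is a unit dimension and dimension #1 indexes the input through a non-dim affine expression, here with a step of 2:

func.func @producer(%in: tensor<1x8xf32>, %out: tensor<1x4xf32>) -> tensor<1x4xf32> {
  // d1 appears as d1 * 2 in the input map: a non-dim affine expression.
  %res = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1 * 2)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%in : tensor<1x8xf32>) outs(%out : tensor<1x4xf32>) {
    ^bb0(%i: f32, %o: f32):
      linalg.yield %i : f32
  } -> tensor<1x4xf32>
  return %res : tensor<1x4xf32>
}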
If dimension #1 was used in a loop of step 2 and dimension #0 was rank-reduced, one would at the very least need to remove the dropped dimension from the indexing_maps, as follows:
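Continuing the hypothetical sketch, with unit dimension #0 dropped and the maps rewritten over a single loop dimension:

func.func @producer_rank_reduced(%in: tensor<8xf32>, %out: tensor<4xf32>) -> tensor<4xf32> {
  // The former d0 is gone; the strided access expression must be rebuilt.
  %res = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0 * 2)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%in : tensor<8xf32>) outs(%out : tensor<4xf32>) {
    ^bb0(%i: f32, %o: f32):
      linalg.yield %i : f32
  } -> tensor<4xf32>
  return %res : tensor<4xf32>
}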
However, the pass does not operate on GenericOp but on LinalgOp, which does not allow modifying the indexing maps, since it can be an interface to a named linalg op with implicit indexing maps. So one would need to distinguish between GenericOp and other LinalgOps. I also feel that folding the rank-reduction earlier is a further optimization, better suited to a dedicated pass if there isn't already one. Does that make sense to you?
To me this makes sense, thanks for the discussion and apologies for the delay!
I've left a few comments inline, but I also have two higher-level suggestions:

- Move dropGivenUnitDims to mlir/lib/Dialect/Tensor/Utils.cpp -> this hook is a fairly generic helper.
- The test that's added here is quite complex - could you try simplifying it? In principle, the only thing that matters is "rank reduction"; everything else (e.g. other ops) is creating noise. I've made some suggestions inline.

Let me also ping @qedawkins and @javedabsar, who are experts in this area :)
@@ -271,12 +309,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                      consumerOpOperand);

  // Replace use.
  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
  Type consumerType = consumerOpOperand.get().getType();
  // Rank-reduction occured as part of the extract_slice.
// Rank-reduction occured as part of the extract_slice.

- occured -> occurred
- shouldn't we verify that it's indeed coming from tensor.extract_slice?
That function only fuses ExtractSliceOp, see a little above:

  Value inputTensor = consumerOpOperand.get();
  // Must be an extract_slice op to guarantee there are loops we can fuse into.
  auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(llvm::dbgs()
               << "\nNot fusable, not an extract_slice op: " << inputTensor);
    return failure();
  }
dropGivenUnitDims():
- move assert out of loop
- rework algorithm to make grouping more explicit and avoid complex nested ifs
- fix "occured" typo

Test: remove all tensor.empty and linalg.fill
✅ With the latest revision this PR passed the C/C++ code formatter.
Sorry I forgot to move the function into Utils.cpp, will do.
Done
Thanks Thomas, final suggestions from me.
Utils:
- drop comments on implementation
- rename from into src

Fusion:
- restrict live range of droppedDims
- clarify comment for rank-reduction check

Test:
- Use more descriptive SSA and FileCheck variables
- Emphasize the rank-reducing extract_slice in the input IR as the key aspect of the test.
Great, thanks for all the updates, LGTM!
With no comments from other reviewers, I think that this is safe to land. You may want to wait till tomorrow - just in case somebody changes their mind and decides to review :) (given that this has been in review for a while, I suspect that you are not in a rush)
I really enjoyed the discussion, Thomas! Send more PRs :)
Insert_slice fusion with a linalg producer does not account for possible rank-reduction in the insert_slice return type. When that happens, a tensor.cast gets generated due to the type mismatch, which is invalid for tensors of different ranks. This later trips other passes.