
[MLIR][Linalg] Fix insert_slice fusion with rank reduction #130961


Merged
merged 9 commits into llvm:main
May 23, 2025

Conversation

RoboTux
Contributor

@RoboTux RoboTux commented Mar 12, 2025

Insert_slice fusion with a linalg producer does not account for possible
rank reduction in the insert_slice return type. When that happens, a
tensor.cast gets generated due to the type mismatch, which is invalid for
tensors of different ranks. This later trips up other passes.
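For context, a minimal sketch of the mismatch, using the shapes from the test added in this patch (%fused is a hypothetical value standing in for the fused producer's result):

// Invalid: tensor.cast cannot change the rank of a ranked tensor.
%bad = tensor.cast %fused : tensor<6x2x1x1x1x1xf32> to tensor<6x2xf32>
// Valid: collapse the unit dims dropped by the rank-reducing extract_slice.
%ok = tensor.collapse_shape %fused [[0], [1, 2, 3, 4, 5]]
    : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>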

@llvmbot
Member

llvmbot commented Mar 12, 2025

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Thomas Preud'homme (RoboTux)

Changes

Insert_slice fusion with a linalg producer does not account for possible
rank reduction in the insert_slice return type. When that happens, a
tensor.cast gets generated due to the type mismatch, which is invalid for
tensors of different ranks. This later trips up other passes.


Full diff: https://github.com/llvm/llvm-project/pull/130961.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp (+34-2)
  • (modified) mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir (+63)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 223d728b0b27d..81b204df5a0aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -235,6 +237,31 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
 }
 
+/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
+/// `from`.
+tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
+                                   const llvm::SmallBitVector &dropDims) {
+  auto fromType = cast<ShapedType>(from.getType());
+  assert(fromType.getRank() == dropDims.size());
+  SmallVector<ReassociationIndices, 2> reassocIdxsVec;
+  ReassociationIndices reassocIdxs;
+
+  bool foundKeptDim = false;
+  for (int dim = 0; dim < fromType.getRank(); dim++) {
+    if (!dropDims.test(dim)) {
+      if (foundKeptDim) {
+        reassocIdxsVec.push_back(reassocIdxs);
+        reassocIdxs.clear();
+      }
+      foundKeptDim = true;
+    }
+    reassocIdxs.push_back(dim);
+  }
+  if (!reassocIdxs.empty())
+    reassocIdxsVec.push_back(reassocIdxs);
+  return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
+}
+
 FailureOr<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    OpOperand &consumerOpOperand) {
@@ -255,6 +282,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                << "\nNot fusable, not an extract_slice op: " << inputTensor);
     return failure();
   }
+  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
 
   // If producer is already in the same block as consumer, we are done.
   if (consumerOpOperand.get().getParentBlock() ==
@@ -272,12 +300,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
            consumerOpOperand);
 
   // Replace use.
+  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+  Type consumerType = consumerOpOperand.get().getType();
+  // Rank-reduction occured as part of the extract_slice.
+  if (cast<ShapedType>(consumerType).getRank() !=
+      cast<ShapedType>(def.getType()).getRank())
+    def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
-  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
-  Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
   consumerOpOperand.set(def);
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 0f27a92c119cf..b4fbdfacde899 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -318,3 +318,66 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
   }
   return %for0 : tensor<64x128xf32>
 }
+
+// -----
+
+func.func @rank_reduced_extract_slice(%arg0: tensor<6x6x1x1x1x1xf32>, %arg1: tensor<6x6x1x1xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+  %cst = arith.constant 0.0 : f32
+  %init1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d6, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<6x6x1x1x1x1xf32>, tensor<6x6x1x1xf32>) outs(%fill1 : tensor<6x6x1x1x1x1xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<6x6x1x1x1x1xf32>
+  %init2 = tensor.empty() : tensor<4x6xf32>
+  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+    %2 = tensor.extract_slice %0[0, %arg4, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+    %init3 = tensor.empty() : tensor<4x2xf32>
+    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+    %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1]  : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
+  }
+  return %1 : tensor<4x6xf32>
+}
+
+//       CHECK: func @rank_reduced_extract_slice(
+//  CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<6x6x1x1x1x1xf32>
+//  CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<6x6x1x1xf32>
+//  CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+
+//   CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+//   CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+//       CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+//       CHECK: %[[FILL_PROD:.*]] = linalg.fill ins({{%.*}} : f32)
+//  CHECK-SAME:     outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+//       CHECK: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
+//       CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
+//       CHECK: %[[FILL_CONS:.*]] = linalg.fill ins({{%.*}} : f32)
+//  CHECK-SAME:     outs(%[[EMPTY_CONS]] : tensor<4x2xf32>) -> tensor<4x2xf32>
+//       CHECK: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+//   CHECK-DAG:   %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]  : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+//   CHECK-DAG:   %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]], 0, 0] [6, 2, 1, 1] [1, 1, 1, 1]  : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
+//   CHECK-DAG:   %[[FILL_PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+
+//       CHECK:    %[[MMUL_PROD:.*]] = linalg.generic
+//  CHECK-SAME:        ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : tensor<6x2x1x1x1x1xf32>, tensor<6x2x1x1xf32>)
+//  CHECK-SAME:        outs(%[[FILL_PROD_SLICE]] : tensor<6x2x1x1x1x1xf32>)
+//       CHECK:    %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
+//       CHECK:    %[[MMUL_CONS:.*]] = linalg.generic
+//  CHECK-SAME:        ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+//  CHECK-SAME:        outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+//       CHECK:   %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+//       CHECK:   scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+//       CHECK: return %[[FOR]] : tensor<4x6xf32>

@javedabsar1 javedabsar1 self-requested a review March 14, 2025 11:08
@RoboTux
Contributor Author

RoboTux commented Mar 17, 2025

Ping?

@RoboTux
Contributor Author

RoboTux commented Mar 24, 2025

Ping?

Contributor

@banach-space banach-space left a comment


Looks good overall, I've made some minor suggestions. I am surprised we didn't hit this before.

Thanks for the contribution Thomas and apologies for the delay reviewing 🙏🏻

Contributor

@banach-space banach-space left a comment


Thanks for the updates, Thomas!

I've left a few smallish comments, but before addressing those, it would be good to make sure that we do indeed need tensor.collapse_shape (as opposed to a rank reducing tensor.extract_slice). I am sure that you thought it through and there's a good reason to insert collapse_shape, but it's not obvious to me just yet 😅

@@ -272,12 +308,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
consumerOpOperand);

// Replace use.
Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
Type consumerType = consumerOpOperand.get().getType();
// Rank-reduction occured as part of the extract_slice.
Contributor


I have just realised ... Why not re-use a rank-reducing tensor.extract_slice instead? It's a bit counter-intuitive that a rank-reducing tensor.extract_slice is replaced with a pair of tensor.extract_slice + tensor.collapse_shape.

Contributor Author


The extract_slice is done on the producer's arguments while the collapse_shape is done on its result. I did look into doing the rank reduction on the arguments, since the result type is derived from the DPS init, but the pass applies to both named and generic linalg ops, which makes this tricky. In the case of a generic, one of the inputs might be using the unit dimension in a non-dim affine expression, which would require updating the affine maps; but since the pass works on the LinalgOp interface there is no way to change the affine maps (named ops have implicit maps). Of course one could special-case this with an if, but the code becomes quite complex.

My feeling is that this is work for another pass. With this one, the producing linalg is moved inside the loop, alongside the consuming generic. Then it's a matter of folding the collapse_shape / extract_slice, which sounds like a simple linalg fusion that probably already exists.

Contributor Author


Happy for me to close this?

@RoboTux
Contributor Author

RoboTux commented Mar 26, 2025

Thanks for the review @banach-space, I hope the new revision addresses all your comments in a satisfactory way.

@RoboTux
Contributor Author

RoboTux commented Mar 26, 2025

> Thanks for the updates, Thomas!
>
> I've left a few smallish comments, but before addressing those, it would be good to make sure that we do indeed need tensor.collapse_shape (as opposed to a rank reducing tensor.extract_slice). I am sure that you thought it through and there's a good reason to insert collapse_shape, but it's not obvious to me just yet 😅

That's a very good question and I'll need to get back to you before I can answer fully. To be clear, the fusion does not insert a new extract_slice; this is done by another transform. I don't even think it was there in the earlier example, and I don't know why there's a new linalg as well, so I will look into understanding that. Normally there is an extract_slice that might happen to collapse some unit dimensions, which then causes a type mismatch. This patch ensures we do the proper collapse_shape after the extract_slice is fused, to reproduce that unit-dimension collapsing and make the types match. By the way, the if is because the pass requires the producing linalg and the consuming linalg (the one that consumes the extract_slice) to be in different blocks; the bug I was facing involved a loop, but I used an if here to make the code simpler.

@RoboTux
Contributor Author

RoboTux commented Mar 26, 2025

My bad, the pass does indeed create an extract_slice (or a subview if working on memrefs). The code that does that is in Linalg Utils and doesn't deal with dropped unit dimensions. I'll rework the patch along those lines instead.

@RoboTux RoboTux marked this pull request as draft March 26, 2025 15:48
@RoboTux
Contributor Author

RoboTux commented May 4, 2025

Sorry for the long pause. I've restarted working on it and should have a patch soon.

RoboTux added 2 commits May 7, 2025 09:31
- rename collapseTo to better reflect its usage
- assert it only collapses unit dimensions
- rename ReassociationIndices-using variables to reassocGroup and
  reassocMaps, the same terminology used in tensor.collapse_shape
  documentation
- use more representative test with comments to better explain what the
  patch does
@RoboTux
Contributor Author

RoboTux commented May 7, 2025

> > Thanks for the updates, Thomas!
> > I've left a few smallish comments, but before addressing those, it would be good to make sure that we do indeed need tensor.collapse_shape (as opposed to a rank reducing tensor.extract_slice). I am sure that you thought it through and there's a good reason to insert collapse_shape, but it's not obvious to me just yet 😅
>
> That's a very good question and I'll need to get back to you before I can answer fully. To be clear, the fusion does not insert a new extract_slice; this is done by another transform. I don't even think it was there in the earlier example, and I don't know why there's a new linalg as well, so I will look into understanding that. Normally there is an extract_slice that might happen to collapse some unit dimensions, which then causes a type mismatch. This patch ensures we do the proper collapse_shape after the extract_slice is fused, to reproduce that unit-dimension collapsing and make the types match. By the way, the if is because the pass requires the producing linalg and the consuming linalg (the one that consumes the extract_slice) to be in different blocks; the bug I was facing involved a loop, but I used an if here to make the code simpler.

This is answered in my other comment: #130961 (comment) but I'll rephrase it here in case that comment wasn't clear enough.

First of all, the fusion is worth doing in the case of a rank-reducing extract_slice even if a collapse_shape is involved, because the producer is moved inside the loop and operates directly on rank-reduced buffers. The only case where it wouldn't be worth it is if the extract_slice were only used to do rank reduction, but that would be strange IR in my opinion; a collapse_shape should have been used instead in that case.

Second, the collapse_shape is used because it would add a lot of complexity to fold the rank reduction into the extract_slice operations on the producer's operands. This is because the producing linalg could be something like:

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0, d0 + d1)>
%res = linalg.generic
    {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]}
    ins(%0, %1 : tensor<1x4xf32>, tensor<1x4xf32>) outs(%2 : tensor<1x4xf32>) {
^bb0(%in : f32, %in1 : f32, %out : f32):
    %result = arith.addf %in, %in1 : f32
    linalg.yield %result : f32
} -> tensor<1x4xf32>
// followed by the loop containing the consumer

If dimension #1 was iterated by a loop of step 2 and dimension #0 was rank-reduced, one would at the very least need to remove the dropped dimension from the indexing_maps, as follows:

#map0 = affine_map<(d1) -> (d1)>
#map1 = affine_map<(d1) -> (0 + d1)>

// The following is inside the loop where the consumer is.
%2 = tensor.empty() : tensor<2xf32>
%3 = tensor.extract_slice %0[0, %iv] [1, 2] [1, 1] : tensor<1x4xf32> to tensor<2xf32>
%5 = tensor.extract_slice %1[0, %iv] [1, 2] [1, 1] : tensor<1x4xf32> to tensor<2xf32>
%res = linalg.generic
    {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel"]}
    ins(%3, %5 : tensor<2xf32>, tensor<2xf32>) outs(%2 : tensor<2xf32>) {
^bb0(%in : f32, %in1 : f32, %out : f32):
    %result = arith.addf %in, %in1 : f32
    linalg.yield %result : f32
} -> tensor<2xf32>

However the pass does not operate on GenericOp but on LinalgOp, which does not allow modifying indexing maps since it can be an interface to a named linalg op with implicit indexing maps. So one would need to distinguish between GenericOp and other LinalgOps. I also feel that folding the rank reduction earlier is a further optimization, better suited to a dedicated pass if there isn't already one.
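For illustration only (this snippet is not from the PR; %lhs, %rhs and %acc are made-up values): a named op such as linalg.matmul carries its indexing maps implicitly in the op definition, so there is nothing written in the IR that code working purely through the LinalgOp interface could rewrite, unlike a generic's explicit indexing_maps attribute.

// Named op: the matmul access maps are implicit in the op definition;
// nothing here can be rewritten the way a generic's indexing_maps could.
%prod = linalg.matmul ins(%lhs, %rhs : tensor<4x8xf32>, tensor<8x16xf32>)
                      outs(%acc : tensor<4x16xf32>) -> tensor<4x16xf32>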

Does that make sense to you?

@RoboTux RoboTux marked this pull request as ready for review May 7, 2025 11:51
Contributor

@banach-space banach-space left a comment


To me this makes sense, thanks for the discussion and apologies for the delay!

I've left a few comments inline, but I also have two higher-level suggestions:

  • Move dropGivenUnitDims to mlir/lib/Dialect/Tensor/Utils.cpp -> this hook is a fairly generic helper.
  • The test that's added here is quite complex - could you try simplifying it? In principle, the only thing that matters is "rank reduction", everything else (e.g. other ops) is creating noise. I've made some suggestions inline.

Let me also ping @qedawkins and @javedabsar, who are experts in this area :)

@@ -271,12 +309,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
consumerOpOperand);

// Replace use.
Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
Type consumerType = consumerOpOperand.get().getType();
// Rank-reduction occured as part of the extract_slice.
Contributor


// Rank-reduction occured as part of the extract_slice.

  • occured -> occurred
  • shouldn't we verify that it's indeed coming from tensor.extract_slice?

Contributor Author


That function only fuses ExtractSliceOp, see a little above:

Value inputTensor = consumerOpOperand.get();

// Must be an extract_slice op to guarantee there are loops we can fuse into.
auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
if (!sliceOp) {
  LLVM_DEBUG(llvm::dbgs()
             << "\nNot fusable, not an extract_slice op: " << inputTensor);
  return failure();
}

dropGivenUnitDims():
- move assert out of loop
- rework algorithm to make grouping more explicit and avoid complex
  nested ifs
- fix occured typo

Test: remove all tensor.empty and linalg.fill

github-actions bot commented May 20, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@RoboTux
Contributor Author

RoboTux commented May 20, 2025

Sorry, I forgot to move the function into Utils.cpp; will do.

@RoboTux
Contributor Author

RoboTux commented May 20, 2025

> * Move `dropGivenUnitDims` to mlir/lib/Dialect/Tensor/Utils.cpp -> this hook is a fairly generic helper.

Done

Contributor

@banach-space banach-space left a comment


Thanks Thomas, final suggestions from me.

Utils:
- drop comments on implementation
- rename from into src

Fusion:
- restrict live range of droppedDims
- clarify comment for rank-reduction check

Test:
- Use more descriptive SSA and FileCheck variables
- Emphasize the rank-reducing extract_slice in the input IR as the key
  aspect of the test.
Contributor

@banach-space banach-space left a comment


Great, thanks for all the updates, LGTM!

With no comments from other reviewers, I think that this is safe to land. You may want to wait till tomorrow - just in case somebody changes their mind and decides to review :) (given that this has been in review for a while, I suspect that you are not in a rush)

I really enjoyed the discussion, Thomas! Send more PRs :)

@RoboTux RoboTux merged commit 2e12bad into llvm:main May 23, 2025
11 checks passed
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025