Commit 4b59b7b

[mlir][Linalg] Fix fusing of indexed linalg consumer with different axes (#140892)
When fusing two `linalg.generic` ops where the producer has index semantics, invalid `affine.apply` ops can be generated whose number of index operands does not match the number of loops in the fused op. This patch fixes the issue by taking the loop count directly from the generated fused op.
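For concreteness, a minimal sketch of the failure mode, assuming a two-loop indexed producer fused into a one-loop consumer (illustrative IR, not taken from the patch): the old count `std::max(producer.getNumLoops(), consumer.getNumLoops())` evaluates to 2, even though the fused op is built in the consumer's iteration space and has only one loop, so the rewrite materializes an out-of-range `linalg.index` and feeds too many operands to the remapping `affine.apply`:

  // Inside the fused one-loop linalg.generic:
  %i0 = linalg.index 0 : index
  %i1 = linalg.index 1 : index  // invalid: dim 1 >= number of loops (1)
  // With the fix, numFusedOpLoops = fusedOp.getNumLoops() = 1, so only
  // linalg.index 0 is created and the remapping affine.apply receives a
  // matching number of index operands.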
Parent: 2d49bc0

2 files changed: +38 −2 lines

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (1 addition, 2 deletions)

@@ -231,8 +231,7 @@ static void generateFusedElementwiseOpRegion(
   // `consumerToProducerLoopsMap` to map the producer indices.
   if (producer.hasIndexSemantics()) {
     // Add an index operation for every fused loop dimension.
-    unsigned numFusedOpLoops =
-        std::max(producer.getNumLoops(), consumer.getNumLoops());
+    unsigned numFusedOpLoops = fusedOp.getNumLoops();
     SmallVector<Value> fusedIndices;
     fusedIndices.reserve(numFusedOpLoops);
     llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir (37 additions, 0 deletions)

@@ -860,6 +860,43 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
 
 // -----
 
+func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) -> tensor<2xi32> {
+  %0 = tensor.empty() : tensor<2x2xi32>
+  %1 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<2x2xi32>) outs(%0 : tensor<2x2xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    %2 = linalg.index 1 : index
+    %3 = arith.index_cast %2 : index to i32
+    linalg.yield %3 : i32
+  } -> tensor<2x2xi32>
+  %4 = tensor.empty() : tensor<2xi32>
+  %5 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0, 1)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%1 : tensor<2x2xi32>) outs(%4 : tensor<2xi32>) {
+  ^bb0(%in: i32, %out: i32):
+    linalg.yield %in : i32
+  } -> tensor<2xi32>
+  return %5 : tensor<2xi32>
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @fusion_different_axes_indexed(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x2xi32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 1 : i32
+// CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<2xi32>
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]]]
+// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-NEXT: ^bb0(
+// CHECK-SAME: %[[B0:.+]]: i32
+// CHECK: linalg.yield %[[CST]] : i32
+// CHECK: return %[[RESULT]]
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_generic_basic
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
 // CHECK-NOT: linalg.fill
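For readers decoding the CHECK lines: the fused op the new test expects looks roughly like the following (reconstructed from the CHECK patterns, with illustrative SSA names). The producer's `linalg.index 1` folds to the constant 1 because the consumer's access map `(d0) -> (d0, 1)` pins the producer's second dimension, so no index remapping survives in the fused body:

  func.func @fusion_different_axes_indexed(%arg0: tensor<2x2xi32>) -> tensor<2xi32> {
    %cst = arith.constant 1 : i32
    %init = tensor.empty() : tensor<2xi32>
    %result = linalg.generic {
        indexing_maps = [affine_map<(d0) -> (d0)>],
        iterator_types = ["parallel"]}
        outs(%init : tensor<2xi32>) {
    ^bb0(%b0: i32):
      // The index-derived value from the producer is now a plain constant.
      linalg.yield %cst : i32
    } -> tensor<2xi32>
    return %result : tensor<2xi32>
  }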
