
Commit 813bbe0

[mlir][linalg] Allow fusing reshapes with non-parallel operands (#130148)
Removes the condition that checks that the operand is not indexed by reduction iterators, which allows for more fine-grained control via the reshape fusion control function. For example, users can allow fusing reshapes that expand the M/N dims of a matmul but not the K dim (or preserve the previous behavior by not fusing such operands at all).

Signed-off-by: Ian Wood <[email protected]>
1 parent 2619c2e commit 813bbe0
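The parallel-iterator check removed below used to reject these candidates before the user-supplied control function was ever consulted; with it gone, the control function alone decides whether a reshape may fuse into an operand that touches reduction loops. As a minimal sketch of the kind of policy the commit message describes (refuse fusion whenever the reshaped operand indexes a reduction dimension, e.g. a matmul's K dim, so only M/N-style expansions fuse), assuming the upstream linalg::ControlFusionFn callback and LinalgOp interface helpers; everything beyond those documented names is illustrative:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"

using namespace mlir;

// Control-function sketch: allow fusing a reshape into a linalg op operand
// only when every dimension that operand indexes is a parallel loop. This
// reproduces the old hard-coded restriction, but as an opt-in policy.
static linalg::ControlFusionFn onlyFuseParallelOperands =
    [](OpOperand *fusedOperand) {
      auto linalgOp = dyn_cast<linalg::LinalgOp>(fusedOperand->getOwner());
      // The operand may belong to the reshape rather than the linalg op
      // (consumer-fusion case); in this sketch we simply allow that case.
      if (!linalgOp)
        return true;
      AffineMap map = linalgOp.getMatchingIndexingMap(fusedOperand);
      SmallVector<utils::IteratorType> iterators =
          linalgOp.getIteratorTypesArray();
      // Reject if any indexed dimension is a reduction loop (a matmul's K);
      // purely parallel operands (M/N expansions) may still fuse.
      return llvm::all_of(map.getResults(), [&](AffineExpr expr) {
        auto dimExpr = dyn_cast<AffineDimExpr>(expr);
        return dimExpr &&
               linalg::isParallelIterator(iterators[dimExpr.getPosition()]);
      });
    };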

File tree: 2 files changed, +26 / -6 lines

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 1 addition & 6 deletions
@@ -566,7 +566,6 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
   // - All the indexing maps for operands and results are projected
   //   permutations.
   // - The fused tensor is not a scalar.
-  // - All the loops for the reshaped operand are parallel loops.
   SmallVector<utils::IteratorType> iteratorTypes =
       linalgOp.getIteratorTypesArray();
   AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
@@ -577,11 +576,7 @@ static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                       .getValue()
                       .isProjectedPermutation();
                 }) &&
-         operandMap.getNumResults() > 0 &&
-         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
-           return isParallelIterator(
-               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
-         });
+         operandMap.getNumResults() > 0;
 }

 namespace {
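With the hard-coded restriction gone, a reduction-indexed operand (such as the collapsed input in the new test below) is structurally fusable, and whether it actually fuses is decided by the control function supplied when the patterns are populated. A rough usage sketch, assuming the existing linalg::populateFoldReshapeOpsByExpansionPatterns entry point; the always-true control shown here is only an illustration:

// Register the reshape-by-expansion fusion patterns with a permissive
// control function; candidates indexed by reduction loops are no longer
// filtered out before this callback runs.
// (Requires mlir/Dialect/Linalg/Transforms/Transforms.h, as above.)
void addReshapeExpansionFusionPatterns(RewritePatternSet &patterns) {
  linalg::ControlFusionFn fuseEverything = [](OpOperand *) { return true; };
  linalg::populateFoldReshapeOpsByExpansionPatterns(patterns, fuseEverything);
}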

mlir/test/Dialect/Linalg/reshape_fusion.mlir

Lines changed: 25 additions & 0 deletions
@@ -482,6 +482,31 @@ func.func @generic_op_reshape_consumer_fusion_projected(%arg0 : tensor<?x?xf32>,

 // -----

+func.func @fuse_collapse_reduction(%arg0: tensor<10x10x20xf32>) -> tensor<100xf32> {
+  %c0 = arith.constant 0 : index
+  %c_0 = arith.constant 0.0 : f32
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<10x10x20xf32> into tensor<100x20xf32>
+  %2 = tensor.empty() : tensor<100xf32>
+  %3 = linalg.fill ins(%c_0 : f32) outs(%2 : tensor<100xf32>) -> tensor<100xf32>
+  %4 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%0 : tensor<100x20xf32>) outs(%3 : tensor<100xf32>) {
+  ^bb0(%arg1 : f32, %arg2: f32):
+    %4 = arith.addf %arg1, %arg2 : f32
+    linalg.yield %4 : f32
+  } -> tensor<100xf32>
+  return %4 : tensor<100xf32>
+}
+
+// CHECK: func @fuse_collapse_reduction
+// CHECK-SAME: %[[ARG0:.+]]: tensor<10x10x20xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<10x10x20xf32>)
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[GENERIC]]
+// CHECK: return %[[COLLAPSE]]
+// -----
+
 func.func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
   %c0 = arith.constant 0 : index
   %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
