Skip to content

Commit 06514c5

Browse files
IanWood1qedawkins
andauthored
[MLIR][Linalg] Fix linalg crash during elementwise op fusion (#117667)
`isOpOperandCanBeDroppedAfterFusedLinalgs` crashes when `indexingMaps` is empty. This can occur when `producer` only has DPS init operands and `consumer ` only has a single DPS input operand (all operands are ignored and nothing gets added to `indexingMaps`). This is because `concatAffineMaps` wasn't handling the maps being empty properly. Similar to `canOpOperandsBeDroppedImpl`, I added an early return when the maps are of size zero. Additionally, `concatAffineMaps`'s declaration comment says it returns an empty map when `maps` is empty but it has no way to get the `MLIRContext` needed to construct the empty affine map when the array is empty. So, I changed this to take the context. __NOTE: concatAffineMaps now takes an MLIRContext to be able to construct an empty map in the case where `maps` is empty.__ --------- Signed-off-by: Ian Wood <[email protected]> Co-authored-by: Quinn Dawkins <[email protected]>
1 parent 43b6b78 commit 06514c5

File tree

7 files changed

+53
-10
lines changed

7 files changed

+53
-10
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def LinalgStructuredInterface
741741
/*methodBody=*/"",
742742
/*defaultImplementation=*/[{
743743
auto maps = $_op.getIndexingMapsArray();
744-
return concatAffineMaps(maps);
744+
return concatAffineMaps(maps, $_op.getContext());
745745
}]
746746
>,
747747
InterfaceMethod<

mlir/include/mlir/IR/AffineMap.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map);
613613
/// ```mlir
614614
/// (i, j, k) -> (i, k, k, j, i, j)
615615
/// ```
616-
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
616+
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps, MLIRContext *context);
617617

618618
/// Returns the map that results from projecting out the dimensions specified in
619619
/// `projectedDimensions`. The projected dimensions are set to 0.

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
5454
// if the op has no loops.
5555
return linalgOp.getNumLoops() == 0;
5656
}
57-
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
57+
return inversePermutation(concatAffineMaps(
58+
indexingMaps, linalgOp.getContext())) != AffineMap();
5859
}
5960

6061
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
392392
// 1. Check if any of the iteration dimensions are unit-trip count. They will
393393
// end up being unit-trip count if they are used to index into a unit-dim
394394
// tensor/memref.
395-
AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
395+
AffineMap invertedMap =
396+
inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
396397
if (!invertedMap) {
397398
return rewriter.notifyMatchFailure(genericOp,
398399
"invalid indexing maps for operation");
@@ -486,7 +487,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
486487
// Abort if the indexing maps of the result operation are not invertible
487488
// (i.e. not legal) or if no dimension was reduced.
488489
if (newIndexingMaps == indexingMaps ||
489-
!inversePermutation(concatAffineMaps(newIndexingMaps)))
490+
!inversePermutation(
491+
concatAffineMaps(newIndexingMaps, rewriter.getContext())))
490492
return failure();
491493

492494
Location loc = genericOp.getLoc();

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,18 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
8888
indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
8989
}
9090
}
91+
if (indexingMaps.empty()) {
92+
// If there are no indexing maps, the operand can only be dropped
93+
// if neither op has loops.
94+
return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
95+
}
9196

9297
// The concatanation of the remained indexing maps must be invertible, so
9398
// the bounds of the op can be still computed after dropping the selected
9499
// operand. inversePermutation returns an empty AffineMap in case the
95100
// concatanated indexing maps are not invertible.
96-
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
101+
return inversePermutation(concatAffineMaps(
102+
indexingMaps, producer.getContext())) != AffineMap();
97103
}
98104

99105
/// Returns a set of indices of the producer's results which would
@@ -1995,7 +2001,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
19952001
genericOp.getMatchingIndexingMap(&outputOperand));
19962002

19972003
// Check if the operation shapes to loops map is computable.
1998-
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
2004+
if (!inversePermutation(
2005+
concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
19992006
return rewriter.notifyMatchFailure(
20002007
genericOp, "fused op loop bound computation failed");
20012008
}

mlir/lib/IR/AffineMap.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,10 @@ AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
833833
return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
834834
}
835835

836-
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
836+
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps,
837+
MLIRContext *context) {
838+
if (maps.empty())
839+
return AffineMap::get(context);
837840
unsigned numResults = 0, numDims = 0, numSymbols = 0;
838841
for (auto m : maps)
839842
numResults += m.getNumResults();
@@ -846,8 +849,7 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
846849
numSymbols += m.getNumSymbols();
847850
numDims = std::max(m.getNumDims(), numDims);
848851
}
849-
return AffineMap::get(numDims, numSymbols, results,
850-
maps.front().getContext());
852+
return AffineMap::get(numDims, numSymbols, results, context);
851853
}
852854

853855
/// Common implementation to project out dimensions or symbols from an affine

mlir/test/Dialect/Linalg/fusion-elementwise.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,34 @@ func.func @drop_unused_producer_result(%arg0 : tensor<?x?xf32>,
2828
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
2929
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
3030
// CHECK: return %[[FUSED_OP]]
31+
32+
// -----
33+
34+
#map = affine_map<(d0) -> (d0)>
35+
func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
36+
%cst_0 = arith.constant 0.000000e+00 : f32
37+
%cst_1 = arith.constant 1.000000e+00 : f32
38+
%0:2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} outs(%arg0, %arg1 : tensor<8xf32>, tensor<8xf32>) {
39+
^bb0(%out: f32, %out_2: f32):
40+
%1 = linalg.index 0 : index
41+
%2 = arith.index_cast %1 : index to i64
42+
%3 = arith.sitofp %2 : i64 to f32
43+
%4 = arith.divf %3, %cst_0 : f32
44+
linalg.yield %3, %4 : f32, f32
45+
} -> (tensor<8xf32>, tensor<8xf32>)
46+
linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%0#1 : tensor<8xf32>) {
47+
^bb0(%in: f32):
48+
%2 = arith.cmpf one, %in, %cst_1 : f32
49+
cf.assert %2, "Side effect op"
50+
linalg.yield
51+
}
52+
func.return %arg1 : tensor<8xf32>
53+
}
54+
55+
// CHECK-LABEL: func @handle_unused_operands
56+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
57+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
58+
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
59+
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
60+
// CHECK-SAME: outs(%[[EMPTY]] :
61+
// CHECK-NOT: linalg.generic

0 commit comments

Comments
 (0)