Skip to content

[MLIR][Linalg] Fix linalg crash during elementwise op fusion #117667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ def LinalgStructuredInterface
/*methodBody=*/"",
/*defaultImplementation=*/[{
auto maps = $_op.getIndexingMapsArray();
return concatAffineMaps(maps);
return concatAffineMaps(maps, $_op.getContext());
}]
>,
InterfaceMethod<
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/AffineMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ AffineMap inverseAndBroadcastProjectedPermutation(AffineMap map);
/// ```mlir
/// (i, j, k) -> (i, k, k, j, i, j)
/// ```
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
AffineMap concatAffineMaps(ArrayRef<AffineMap> maps, MLIRContext *context);

/// Returns the map that results from projecting out the dimensions specified in
/// `projectedDimensions`. The projected dimensions are set to 0.
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
// if the op has no loops.
return linalgOp.getNumLoops() == 0;
}
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
return inversePermutation(concatAffineMaps(
indexingMaps, linalgOp.getContext())) != AffineMap();
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
// 1. Check if any of the iteration dimensions are unit-trip count. They will
// end up being unit-trip count if they are used to index into a unit-dim
// tensor/memref.
AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps));
AffineMap invertedMap =
inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
if (!invertedMap) {
return rewriter.notifyMatchFailure(genericOp,
"invalid indexing maps for operation");
Expand Down Expand Up @@ -486,7 +487,8 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
// Abort if the indexing maps of the result operation are not invertible
// (i.e. not legal) or if no dimension was reduced.
if (newIndexingMaps == indexingMaps ||
!inversePermutation(concatAffineMaps(newIndexingMaps)))
!inversePermutation(
concatAffineMaps(newIndexingMaps, rewriter.getContext())))
return failure();

Location loc = genericOp.getLoc();
Expand Down
11 changes: 9 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,18 @@ static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
}
}
if (indexingMaps.empty()) {
// If there are no indexing maps, the operand can only be dropped
// if neither op has loops.
return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
}

// The concatenation of the remaining indexing maps must be invertible, so
// the bounds of the op can still be computed after dropping the selected
// operand. inversePermutation returns an empty AffineMap in case the
// concatenated indexing maps are not invertible.
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
return inversePermutation(concatAffineMaps(
indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
Expand Down Expand Up @@ -1995,7 +2001,8 @@ class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
genericOp.getMatchingIndexingMap(&outputOperand));

// Check if the operation shapes to loops map is computable.
if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
if (!inversePermutation(
concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
return rewriter.notifyMatchFailure(
genericOp, "fused op loop bound computation failed");
}
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/IR/AffineMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,10 @@ AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
}

AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps,
MLIRContext *context) {
if (maps.empty())
return AffineMap::get(context);
unsigned numResults = 0, numDims = 0, numSymbols = 0;
for (auto m : maps)
numResults += m.getNumResults();
Expand All @@ -846,8 +849,7 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
numSymbols += m.getNumSymbols();
numDims = std::max(m.getNumDims(), numDims);
}
return AffineMap::get(numDims, numSymbols, results,
maps.front().getContext());
return AffineMap::get(numDims, numSymbols, results, context);
}

/// Common implementation to project out dimensions or symbols from an affine
Expand Down
31 changes: 31 additions & 0 deletions mlir/test/Dialect/Linalg/fusion-elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,34 @@ func.func @drop_unused_producer_result(%arg0 : tensor<?x?xf32>,
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK: return %[[FUSED_OP]]

// -----

#map = affine_map<(d0) -> (d0)>
// Test input: a two-result producer `linalg.generic` whose second result
// (%0#1) is the only one consumed, feeding a consumer `linalg.generic` that
// has no outputs and contains a `cf.assert`. The function itself returns
// %arg1, not a result of either generic op.
func.func @handle_unused_operands(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> {
%cst_0 = arith.constant 0.000000e+00 : f32
%cst_1 = arith.constant 1.000000e+00 : f32
// Producer: both operands are outs; the body ignores the incoming %out
// values and computes its yields from the iteration index.
%0:2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} outs(%arg0, %arg1 : tensor<8xf32>, tensor<8xf32>) {
^bb0(%out: f32, %out_2: f32):
%1 = linalg.index 0 : index
%2 = arith.index_cast %1 : index to i64
%3 = arith.sitofp %2 : i64 to f32
%4 = arith.divf %3, %cst_0 : f32
linalg.yield %3, %4 : f32, f32
} -> (tensor<8xf32>, tensor<8xf32>)
// Consumer: reads only %0#1, has no outs and yields nothing; %0#0 is unused.
linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} ins(%0#1 : tensor<8xf32>) {
^bb0(%in: f32):
%2 = arith.cmpf one, %in, %cst_1 : f32
cf.assert %2, "Side effect op"
linalg.yield
}
func.return %arg1 : tensor<8xf32>
}

// CHECK-LABEL: func @handle_unused_operands
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<8xf32>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: outs(%[[EMPTY]] :
// CHECK-NOT: linalg.generic