Skip to content

Commit 8b5236d

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Simplify slice dim computation for fusion on tensors (NFC).
Compute the tiled producer slice dimensions directly starting from the consumer not using the producer at all. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D110147
1 parent 9072f1b commit 8b5236d

File tree

2 files changed

+55
-53
lines changed

2 files changed

+55
-53
lines changed

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 19 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -30,61 +30,26 @@ using namespace linalg;
3030
// StructuredOp specific helpers.
3131
//===----------------------------------------------------------------------===//
3232

33-
/// Relate the producer to the consumer loop iterations that access the same
34-
/// producer result element:
35-
/// consumerToProducerLoops =
36-
/// inverse(producerIndexingMap).compose(consumerIndexingMap).
37-
/// Return `consumerToProducerLoops` or none if the inversion fails.
38-
static Optional<AffineMap>
39-
getConsumerToProducerLoopsMap(AffineMap producerIndexingMap,
40-
AffineMap consumerIndexingMap) {
41-
assert(consumerIndexingMap.getNumResults() ==
42-
producerIndexingMap.getNumResults() &&
43-
"expect the number of indexing map results to match");
44-
// Ensure the producer indexing map is a projected permutation.
45-
if (!producerIndexingMap.isProjectedPermutation())
46-
return None;
47-
AffineMap inverseIndexingMap =
48-
inverseAndBroadcastProjectedPermuation(producerIndexingMap);
49-
return inverseIndexingMap.compose(consumerIndexingMap);
50-
}
51-
52-
/// Returns the producer result slice dimensions tiled by the tile loop nest or
53-
/// an empty vector if `getConsumerToProducerLoopsMap` returns none.
54-
// TODO: replace by Fourier-Motzkin and/or compute starting from consumer.
55-
SmallVector<int64_t> getTiledSliceDims(OpResult producerResult,
56-
OpOperand *consumerOperand,
33+
/// Returns the tiled slice dimensions given the tiled consumer loop dimensions.
34+
/// The slice defines a hyper rectangular iteration space and fusing the
35+
/// producer is always possible. However, depending on the consumer indexing
36+
/// map, not all slice elements may be consumed and the tiles may overlap. In
37+
/// these cases, fusion introduces redundant computation.
38+
SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
5739
ArrayRef<int64_t> tiledLoopDims) {
40+
// Get the consumer operand indexing map.
5841
LinalgOp consumerOp = consumerOperand->getOwner();
59-
LinalgOp producerOp = producerResult.getOwner();
60-
OpOperand *opOperand =
61-
producerOp.getOutputOperand(producerResult.getResultNumber());
62-
63-
// Compute the `consumerToProducerLoopsMap` and exit if the computation fails.
64-
AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(opOperand);
65-
Optional<AffineMap> consumerToProducerLoopsMap =
66-
getConsumerToProducerLoopsMap(
67-
producerIndexingMap, consumerOp.getTiedIndexingMap(consumerOperand));
68-
if (!consumerToProducerLoopsMap.hasValue())
69-
return {};
70-
71-
// Compute the set of tiled producer loops.
72-
DenseSet<int64_t> tiledProducerLoops;
73-
for (auto en : enumerate(consumerToProducerLoopsMap->getResults())) {
74-
for (int64_t dim : tiledLoopDims) {
75-
if (en.value().isFunctionOfDim(dim))
76-
tiledProducerLoops.insert(en.index());
42+
AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
43+
44+
// Search the slice dimensions tiled by a tile loop dimension.
45+
DenseSet<int64_t> tiledSliceDims;
46+
for (auto en : enumerate(indexingMap.getResults())) {
47+
for (auto tiledLoopDim : tiledLoopDims) {
48+
if (en.value().isFunctionOfDim(tiledLoopDim))
49+
tiledSliceDims.insert(en.index());
7750
}
7851
}
79-
80-
// Compute the slice dimensions for the tiled producer loops.
81-
SmallVector<int64_t> tiledSliceDims;
82-
for (auto en : enumerate(producerIndexingMap.getResults())) {
83-
auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
84-
if (dimExpr && tiledProducerLoops.count(dimExpr.getPosition()) != 0)
85-
tiledSliceDims.push_back(en.index());
86-
}
87-
return tiledSliceDims;
52+
return {tiledSliceDims.begin(), tiledSliceDims.end()};
8853
}
8954

9055
/// Returns the producer fused in place of `sliceOp`. Tile the producer operands
@@ -332,9 +297,10 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
332297
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
333298
return failure();
334299

335-
// Compute the slice dimensions tiled by `tileLoopNest`.
300+
// Compute the tiled producer slice dimensions given the tiled root operation
301+
// loop dimensions `loopDims`.
336302
SmallVector<int64_t> tiledSliceDims =
337-
getTiledSliceDims(producerResult, rootOpOperand, loopDims);
303+
getTiledSliceDims(rootOpOperand, loopDims);
338304
if (tiledSliceDims.empty())
339305
return failure();
340306

mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,39 @@ builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>,
230230
return %1 : tensor<24x25xi32>
231231
}
232232

233+
// -----
234+
235+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
236+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 18)>
237+
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 18)>
238+
#map0 = affine_map<(d0, d1) -> (d0, d0 + d1)>
239+
#map1 = affine_map<(d0, d1) -> (d0, d1)>
240+
241+
// CHECK: fuse_non_rectangular
242+
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x18xf32>
243+
func @fuse_non_rectangular(%arg0: tensor<10x18xf32>,
244+
%arg1: tensor<10x8xf32>) -> tensor<10x8xf32> {
245+
%cst = constant 0.000000e+00 : f32
246+
%0 = linalg.fill(%cst, %arg0) : f32, tensor<10x18xf32> -> tensor<10x18xf32>
247+
248+
// CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
249+
// CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
250+
251+
// Compute producer on a hyper rectangular bounding box. Along the second dimenson,
252+
// the offset is set to the sum of the induction variables and the upper bound
253+
// to either eight (sum of the tile sizes) or eighteen (sum of the domain sizes)
254+
// minus the induction variables.
255+
// CHECK: %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]]
256+
// CHECK: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]]
257+
// CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]]
258+
// CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
259+
// CHECK-SAME: %[[IV1]], %[[SUM]]
260+
// CHECK-SAME: , %[[UB1]]
261+
// CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
262+
%1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x18xf32>) outs(%arg1 : tensor<10x8xf32>) {
263+
^bb0(%arg2: f32, %arg3: f32): // no predecessors
264+
%2 = addf %arg2, %arg3 : f32
265+
linalg.yield %2 : f32
266+
} -> tensor<10x8xf32>
267+
return %1 : tensor<10x8xf32>
268+
}

0 commit comments

Comments
 (0)