Skip to content

Commit 29c31cb

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Add support for transitive fusion.
Extend fusion on tensors to fuse producers greedily. Reviewed By: nicolasvasilache, hanchung Differential Revision: https://reviews.llvm.org/D110262
1 parent c92de29 commit 29c31cb

File tree

3 files changed

+199
-53
lines changed

3 files changed

+199
-53
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ class TileLoopNest {
212212
bool isEmpty();
213213

214214
/// Returns true if the tile loop nest invariants are satisfied:
215+
/// - The `rootOp` has been tiled at least once.
215216
/// - The number of tile loop operations and dimensions match.
216217
/// - The innermost tile loop is the parent of `tiledOp`.
217218
/// - The tile loops are directly nested.
@@ -233,8 +234,8 @@ class TileLoopNest {
233234
bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
234235

235236
LinalgOp rootOp;
236-
SmallVector<scf::ForOp> loopOps;
237-
SmallVector<int64_t> loopDims;
237+
SmallVector<scf::ForOp> tileLoopOps;
238+
DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
238239
};
239240

240241
/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp

Lines changed: 113 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,62 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
4242
AffineMap indexingMap = consumerOp.getTiedIndexingMap(consumerOperand);
4343

4444
// Search the slice dimensions tiled by a tile loop dimension.
45-
DenseSet<int64_t> tiledSliceDims;
45+
DenseSet<int64_t> tiledSliceDimIndices;
4646
for (auto en : enumerate(indexingMap.getResults())) {
4747
for (auto tiledLoopDim : tiledLoopDims) {
4848
if (en.value().isFunctionOfDim(tiledLoopDim))
49-
tiledSliceDims.insert(en.index());
49+
tiledSliceDimIndices.insert(en.index());
5050
}
5151
}
52-
return {tiledSliceDims.begin(), tiledSliceDims.end()};
52+
return {tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()};
53+
}
54+
55+
/// Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions
56+
/// of the producer result slice returns the tiled producer loop dimensions.
57+
/// Example:
58+
/// ```
59+
/// %res = linalg.fill(%cst, %input)
60+
/// scf.for %i
61+
/// scf.for %j
62+
/// %slice = tensor.extract_slice %res[%i, %j]
63+
/// ```
64+
/// getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1].
65+
static SmallVector<int64_t>
66+
getTiledProducerLoops(OpResult producerResult,
67+
ArrayRef<int64_t> tiledSliceDimIndices) {
68+
LinalgOp producerOp = producerResult.getOwner();
69+
70+
// Get the indexing map of the `producerOp` output operand that matches
71+
// ´producerResult´.
72+
AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
73+
producerOp.getOutputOperand(producerResult.getResultNumber()));
74+
75+
// Keep only the tiled result slice dimensions of `producerIndexingMap`.
76+
AffineMap tiledProducerIndexingSubMap =
77+
producerIndexingMap.getSubMap(SmallVector<unsigned>(
78+
tiledSliceDimIndices.begin(), tiledSliceDimIndices.end()));
79+
80+
// Compute the producer loop indices mapped to the tiled result slice
81+
// dimensions. As the output indexing map of structured operations are
82+
// projected permutations, `tiledProducerIndexingSubMap` has to be a
83+
// projected permutation as well. We can thus obtain the producer loop indices
84+
// by getting the positions of the result dimensions.
85+
// Example:
86+
// (d0, d1, d2) -> (d0, d2) has the result positions [0, 2].
87+
assert(tiledProducerIndexingSubMap.isProjectedPermutation() &&
88+
"expect slice and producer loop dimensions map one-to-one");
89+
SmallVector<int64_t> tiledProducerLoopIndices;
90+
transform(llvm::seq<unsigned>(0, tiledProducerIndexingSubMap.getNumResults()),
91+
std::back_inserter(tiledProducerLoopIndices), [&](unsigned idx) {
92+
return tiledProducerIndexingSubMap.getDimPosition(idx);
93+
});
94+
95+
return tiledProducerLoopIndices;
5396
}
5497

5598
/// Returns the producer fused in place of `sliceOp`. Tile the producer operands
56-
/// along the `tiledSliceDims` and clone the producer. Consider the case of
57-
/// fusion of an output tensor:
99+
/// along the `tiledSliceDimIndices` and clone the producer. Consider the case
100+
/// of fusion of an output tensor:
58101
/// ```
59102
/// %1 = producer ins(...) outs(%0)
60103
/// %2 = consumer ins(...) outs(%1)
@@ -84,7 +127,8 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
84127
/// producer is fused into a consumer and fold away unused iter_args.
85128
static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
86129
tensor::ExtractSliceOp sliceOp,
87-
ArrayRef<int64_t> tiledSliceDims,
130+
ArrayRef<int64_t> tiledSliceDimIndices,
131+
ArrayRef<int64_t> tiledProducerLoopIndices,
88132
OpOperand *iterArg) {
89133
// Clone the producer after `sliceOp` since the slice may be reused to pass in
90134
// the producer result.
@@ -102,23 +146,16 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
102146
[](Range range) { return range.size; });
103147
SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
104148

105-
// Get the producer result indexing map.
106-
AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
107-
producerOp.getOutputOperand(producerResult.getResultNumber()));
108-
109149
// Tile the producer operands given the `sliceOp` ranges. Iterate the
110-
// `tiledSliceDims` and store the tile offset and size for the tiled slice
111-
// dimension. Assumes the mapping from slice dimensions to producer loops is a
112-
// permutation.
150+
// `tiledSliceDimIndices` and store the tile offset and size for the tiled
151+
// slice dimension.
113152
auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
114153
SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
115154
SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
116155
SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
117-
for (int64_t tiledSliceDim : tiledSliceDims) {
118-
AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim];
119-
assert(result.isa<AffineDimExpr>() &&
120-
"expect producer indexing map is a projected permutation");
121-
int64_t tiledProducerLoop = result.cast<AffineDimExpr>().getPosition();
156+
for (auto it : zip(tiledSliceDimIndices, tiledProducerLoopIndices)) {
157+
int64_t tiledSliceDim = std::get<0>(it);
158+
int64_t tiledProducerLoop = std::get<1>(it);
122159
tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
123160
tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
124161
allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
@@ -156,30 +193,34 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
156193
// TileLoopNest specific helpers.
157194
//===----------------------------------------------------------------------===//
158195

159-
bool TileLoopNest::isEmpty() { return loopOps.empty(); }
196+
bool TileLoopNest::isEmpty() { return tileLoopOps.empty(); }
160197

161198
bool TileLoopNest::isValid() {
162-
// Check if the number of `tileLoopOps` and `tileLoopDims` match.
163-
if (loopOps.size() != loopDims.size())
199+
// Check if `rootOp` has been tiled at least once.
200+
if (isEmpty() || tiledRootAndFusedOpsLoops.count(rootOp) == 0)
201+
return false;
202+
203+
// Check if the number of loop operations and dimensions match.
204+
if (tileLoopOps.size() != tiledRootAndFusedOpsLoops[rootOp].size())
164205
return false;
165206

166207
// Check if the innermost tile loop is the parent of `tiledOp`.
167-
if (rootOp->getParentOp() != loopOps.back())
208+
if (rootOp->getParentOp() != tileLoopOps.back())
168209
return false;
169210

170211
// Check if the tile loops are directly nested.
171-
return std::adjacent_find(loopOps.begin(), loopOps.end(),
212+
return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(),
172213
[](Operation *op1, Operation *op2) {
173214
return op1 != op2->getParentOp();
174-
}) == loopOps.end();
215+
}) == tileLoopOps.end();
175216
}
176217

177218
SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
178219
assert(bbArg && "expect the block argument to be non-zero");
179220
SmallVector<BlockArgument> bbArgs;
180221

181222
// Search all tile loop block arguments from inner to outer.
182-
for (auto tileLoop : reverse(loopOps)) {
223+
for (auto tileLoop : reverse(tileLoopOps)) {
183224
if (bbArg.getOwner()->getParentOp() != tileLoop)
184225
return {};
185226
bbArgs.push_back(bbArg);
@@ -194,9 +235,9 @@ SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
194235
OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
195236
// Search all block arguments and return the matching iteration argument.
196237
SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
197-
if (bbArgs.size() != loopOps.size())
238+
if (bbArgs.size() != tileLoopOps.size())
198239
return nullptr;
199-
return &loopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
240+
return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
200241
}
201242

202243
bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
@@ -255,38 +296,46 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
255296
if (!isEmpty())
256297
rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);
257298

299+
// Transfer the stored `rootOp` loop dimensions if it has been tiled before.
300+
if (tiledRootAndFusedOpsLoops.count(rootOp) != 0) {
301+
tiledRootAndFusedOpsLoops[tiledRootOp->op] =
302+
tiledRootAndFusedOpsLoops[rootOp];
303+
}
304+
258305
// Update the root operation and append the loops and tile loop dimensions.
259306
rootOp = tiledRootOp->op;
260-
loopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
307+
tileLoopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
261308
for (auto en : enumerate(tileSizes)) {
262309
// Copy only the tiled loop dimensions with non-zero tile size.
263310
if (en.value() == 0)
264311
continue;
265-
loopDims.push_back(tileInterchange[en.index()]);
312+
tiledRootAndFusedOpsLoops[rootOp].push_back(tileInterchange[en.index()]);
266313
}
267314
assert(isValid() && "expect tile loop nest to be valid after tiling");
268-
269315
return success();
270316
}
271317

272318
FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
273-
OpOperand *rootOpOperand) {
274-
assert(rootOpOperand->getOwner() == rootOp &&
275-
"expect the root op to be the owner of the operand to fuse");
319+
OpOperand *consumerOpOperand) {
320+
assert(tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) != 0 &&
321+
"expect the operand owner is the root operation or a fused producer");
276322
assert(this->isValid() &&
277323
"expect the tile loop nest to satisfy all invariants");
278324

279325
// Check the tile loop nest is non-empty.
280326
if (isEmpty())
281327
return failure();
282328

283-
// Check `rootOpOperand` is defined by an ExtractSliceOp.
284-
auto sliceOp = rootOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
329+
// Check `consumerOpOperand` is defined by an ExtractSliceOp.
330+
auto sliceOp =
331+
consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
285332
if (!sliceOp)
286333
return failure();
287334

288-
// Check `sliceOp` is tiled by the tile loop nest.
289-
if (sliceOp->getParentOp() != rootOp->getParentOp())
335+
// Check `sliceOp` and `consumerOp` are in the same block.
336+
LinalgOp consumerOp = consumerOpOperand->getOwner();
337+
if (sliceOp->getBlock() != rootOp->getBlock() ||
338+
consumerOp->getBlock() != rootOp->getBlock())
290339
return failure();
291340

292341
// Check if the producer is a LinalgOp possibly passed by iteration argument.
@@ -302,19 +351,24 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
302351
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner()))
303352
return failure();
304353

305-
// Compute the tiled producer slice dimensions given the tiled root operation
306-
// loop dimensions `loopDims`.
307-
SmallVector<int64_t> tiledSliceDims =
308-
getTiledSliceDims(rootOpOperand, loopDims);
309-
if (tiledSliceDims.empty())
354+
// Compute the tiled producer slice dimensions given the tiled consumer loops.
355+
SmallVector<int64_t> tiledSliceDimIndices = getTiledSliceDims(
356+
consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]);
357+
if (tiledSliceDimIndices.empty())
310358
return failure();
311359

360+
// Compute the tiled producer loop indices.
361+
SmallVector<int64_t> tiledProducerLoopIndices =
362+
getTiledProducerLoops(producerResult, tiledSliceDimIndices);
363+
312364
// Tile the producer operands and clone the producer in place of `sliceOp`.
313365
LinalgOp clonedOp =
314-
getTiledProducer(b, producerResult, sliceOp, tiledSliceDims, iterArg);
366+
getTiledProducer(b, producerResult, sliceOp, tiledSliceDimIndices,
367+
tiledProducerLoopIndices, iterArg);
368+
tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;
315369

316370
// Cast the `clonedOp` result to gap type mismatches before canonicalization.
317-
Type consumerOperandType = rootOpOperand->get().getType();
371+
Type consumerOperandType = consumerOpOperand->get().getType();
318372
Value newResult = clonedOp->getResult(producerResult.getResultNumber());
319373
if (newResult.getType() != consumerOperandType) {
320374
OpBuilder::InsertionGuard guard(b);
@@ -330,7 +384,7 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
330384

331385
ValueRange TileLoopNest::getRootOpReplacementResults() {
332386
assert(!isEmpty() && "expect tile loop nest to be non-empty");
333-
return loopOps.front()->getOpResults();
387+
return tileLoopOps.front()->getOpResults();
334388
}
335389

336390
//===----------------------------------------------------------------------===//
@@ -359,25 +413,33 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
359413
});
360414
int64_t split = std::distance(iterTypes.begin(), it);
361415

416+
// Helper to fuse the producers greedily using a queue of fusion candidates.
417+
auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
418+
SmallVector<OpOperand *> candidates(operands.begin(), operands.end());
419+
while (!candidates.empty()) {
420+
FailureOr<LinalgOp> fusedProducer =
421+
tileLoopNest.fuseProducer(b, candidates.pop_back_val());
422+
if (failed(fusedProducer))
423+
continue;
424+
candidates.append(fusedProducer->getInputAndOutputOperands());
425+
}
426+
};
427+
362428
// Tile the outer parallel loops and fuse the output operands.
363429
SmallVector<int64_t> outerTileSizes;
364430
outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
365431
outerTileSizes.append(tileSizes.size() - split, 0);
366432
if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
367433
return failure();
368-
for (OpOperand *opOperand : tileLoopNest.getRootOp().getOutputOperands())
369-
(void)tileLoopNest.fuseProducer(b, opOperand);
434+
fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
370435

371436
// Tile the remaining loops and fuse the input operands.
372437
SmallVector<int64_t> innerTileSizes;
373438
innerTileSizes.append(split, 0);
374439
innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
375440
if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
376441
return failure();
377-
SmallVector<OpOperand *> inputOperands =
378-
tileLoopNest.getRootOp().getInputOperands();
379-
for (OpOperand *opOperand : tileLoopNest.getRootOp().getInputOperands())
380-
(void)tileLoopNest.fuseProducer(b, opOperand);
442+
fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
381443

382444
return tileLoopNest;
383445
}

0 commit comments

Comments
 (0)