@@ -42,19 +42,62 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
42
42
AffineMap indexingMap = consumerOp.getTiedIndexingMap (consumerOperand);
43
43
44
44
// Search the slice dimensions tiled by a tile loop dimension.
45
- DenseSet<int64_t > tiledSliceDims ;
45
+ DenseSet<int64_t > tiledSliceDimIndices ;
46
46
for (auto en : enumerate(indexingMap.getResults ())) {
47
47
for (auto tiledLoopDim : tiledLoopDims) {
48
48
if (en.value ().isFunctionOfDim (tiledLoopDim))
49
- tiledSliceDims .insert (en.index ());
49
+ tiledSliceDimIndices .insert (en.index ());
50
50
}
51
51
}
52
- return {tiledSliceDims.begin (), tiledSliceDims.end ()};
52
+ return {tiledSliceDimIndices.begin (), tiledSliceDimIndices.end ()};
53
+ }
54
+
55
+ // / Given a vector of `tiledSliceDimIndices` that represent the tiled dimensions
56
+ // / of the producer result slice returns the tiled producer loop dimensions.
57
+ // / Example:
58
+ // / ```
59
+ // / %res = linalg.fill(%cst, %input)
60
+ // / scf.for %i
61
+ // / scf.for %j
62
+ // / %slice = tensor.extract_slice %res[%i, %j]
63
+ // / ```
64
+ // / getTiledProducerLoops(%res, [0, 1]) returns the loop indices [0, 1].
65
+ static SmallVector<int64_t >
66
+ getTiledProducerLoops (OpResult producerResult,
67
+ ArrayRef<int64_t > tiledSliceDimIndices) {
68
+ LinalgOp producerOp = producerResult.getOwner ();
69
+
70
+ // Get the indexing map of the `producerOp` output operand that matches
71
+ // ´producerResult´.
72
+ AffineMap producerIndexingMap = producerOp.getTiedIndexingMap (
73
+ producerOp.getOutputOperand (producerResult.getResultNumber ()));
74
+
75
+ // Keep only the tiled result slice dimensions of `producerIndexingMap`.
76
+ AffineMap tiledProducerIndexingSubMap =
77
+ producerIndexingMap.getSubMap (SmallVector<unsigned >(
78
+ tiledSliceDimIndices.begin (), tiledSliceDimIndices.end ()));
79
+
80
+ // Compute the producer loop indices mapped to the tiled result slice
81
+ // dimensions. As the output indexing map of structured operations are
82
+ // projected permutations, `tiledProducerIndexingSubMap` has to be a
83
+ // projected permutation as well. We can thus obtain the producer loop indices
84
+ // by getting the positions of the result dimensions.
85
+ // Example:
86
+ // (d0, d1, d2) -> (d0, d2) has the result positions [0, 2].
87
+ assert (tiledProducerIndexingSubMap.isProjectedPermutation () &&
88
+ " expect slice and producer loop dimensions map one-to-one" );
89
+ SmallVector<int64_t > tiledProducerLoopIndices;
90
+ transform (llvm::seq<unsigned >(0 , tiledProducerIndexingSubMap.getNumResults ()),
91
+ std::back_inserter (tiledProducerLoopIndices), [&](unsigned idx) {
92
+ return tiledProducerIndexingSubMap.getDimPosition (idx);
93
+ });
94
+
95
+ return tiledProducerLoopIndices;
53
96
}
54
97
55
98
// / Returns the producer fused in place of `sliceOp`. Tile the producer operands
56
- // / along the `tiledSliceDims ` and clone the producer. Consider the case of
57
- // / fusion of an output tensor:
99
+ // / along the `tiledSliceDimIndices ` and clone the producer. Consider the case
100
+ // / of fusion of an output tensor:
58
101
// / ```
59
102
// / %1 = producer ins(...) outs(%0)
60
103
// / %2 = consumer ins(...) outs(%1)
@@ -84,7 +127,8 @@ static SmallVector<int64_t> getTiledSliceDims(OpOperand *consumerOperand,
84
127
// / producer is fused into a consumer and fold away unused iter_args.
85
128
static LinalgOp getTiledProducer (OpBuilder &b, OpResult producerResult,
86
129
tensor::ExtractSliceOp sliceOp,
87
- ArrayRef<int64_t > tiledSliceDims,
130
+ ArrayRef<int64_t > tiledSliceDimIndices,
131
+ ArrayRef<int64_t > tiledProducerLoopIndices,
88
132
OpOperand *iterArg) {
89
133
// Clone the producer after `sliceOp` since the slice may be reused to pass in
90
134
// the producer result.
@@ -102,23 +146,16 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
102
146
[](Range range) { return range.size ; });
103
147
SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges (b, loc);
104
148
105
- // Get the producer result indexing map.
106
- AffineMap producerIndexingMap = producerOp.getTiedIndexingMap (
107
- producerOp.getOutputOperand (producerResult.getResultNumber ()));
108
-
109
149
// Tile the producer operands given the `sliceOp` ranges. Iterate the
110
- // `tiledSliceDims` and store the tile offset and size for the tiled slice
111
- // dimension. Assumes the mapping from slice dimensions to producer loops is a
112
- // permutation.
150
+ // `tiledSliceDimIndices` and store the tile offset and size for the tiled
151
+ // slice dimension.
113
152
auto zero = b.create <arith::ConstantIndexOp>(loc, 0 );
114
153
SmallVector<Value> tileIvs (producerOp.getNumLoops (), nullptr );
115
154
SmallVector<Value> tileSizes (producerOp.getNumLoops (), zero);
116
155
SmallVector<Value> allIvs (producerOp.getNumLoops (), nullptr );
117
- for (int64_t tiledSliceDim : tiledSliceDims) {
118
- AffineExpr result = producerIndexingMap.getResults ()[tiledSliceDim];
119
- assert (result.isa <AffineDimExpr>() &&
120
- " expect producer indexing map is a projected permutation" );
121
- int64_t tiledProducerLoop = result.cast <AffineDimExpr>().getPosition ();
156
+ for (auto it : zip (tiledSliceDimIndices, tiledProducerLoopIndices)) {
157
+ int64_t tiledSliceDim = std::get<0 >(it);
158
+ int64_t tiledProducerLoop = std::get<1 >(it);
122
159
tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset ;
123
160
tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size ;
124
161
allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
@@ -156,30 +193,34 @@ static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
156
193
// TileLoopNest specific helpers.
157
194
// ===----------------------------------------------------------------------===//
158
195
159
- bool TileLoopNest::isEmpty () { return loopOps .empty (); }
196
+ bool TileLoopNest::isEmpty () { return tileLoopOps .empty (); }
160
197
161
198
bool TileLoopNest::isValid () {
162
- // Check if the number of `tileLoopOps` and `tileLoopDims` match.
163
- if (loopOps.size () != loopDims.size ())
199
+ // Check if `rootOp` has been tiled at least once.
200
+ if (isEmpty () || tiledRootAndFusedOpsLoops.count (rootOp) == 0 )
201
+ return false ;
202
+
203
+ // Check if the number of loop operations and dimensions match.
204
+ if (tileLoopOps.size () != tiledRootAndFusedOpsLoops[rootOp].size ())
164
205
return false ;
165
206
166
207
// Check if the innermost tile loop is the parent of `tiledOp`.
167
- if (rootOp->getParentOp () != loopOps .back ())
208
+ if (rootOp->getParentOp () != tileLoopOps .back ())
168
209
return false ;
169
210
170
211
// Check if the tile loops are directly nested.
171
- return std::adjacent_find (loopOps .begin (), loopOps .end (),
212
+ return std::adjacent_find (tileLoopOps .begin (), tileLoopOps .end (),
172
213
[](Operation *op1, Operation *op2) {
173
214
return op1 != op2->getParentOp ();
174
- }) == loopOps .end ();
215
+ }) == tileLoopOps .end ();
175
216
}
176
217
177
218
SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs (BlockArgument bbArg) {
178
219
assert (bbArg && " expect the block argument to be non-zero" );
179
220
SmallVector<BlockArgument> bbArgs;
180
221
181
222
// Search all tile loop block arguments from inner to outer.
182
- for (auto tileLoop : reverse (loopOps )) {
223
+ for (auto tileLoop : reverse (tileLoopOps )) {
183
224
if (bbArg.getOwner ()->getParentOp () != tileLoop)
184
225
return {};
185
226
bbArgs.push_back (bbArg);
@@ -194,9 +235,9 @@ SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
194
235
OpOperand *TileLoopNest::getTiedIterArg (BlockArgument bbArg) {
195
236
// Search all block arguments and return the matching iteration argument.
196
237
SmallVector<BlockArgument> bbArgs = getTiedBBArgs (bbArg);
197
- if (bbArgs.size () != loopOps .size ())
238
+ if (bbArgs.size () != tileLoopOps .size ())
198
239
return nullptr ;
199
- return &loopOps .front ().getOpOperandForRegionIterArg (bbArgs.front ());
240
+ return &tileLoopOps .front ().getOpOperandForRegionIterArg (bbArgs.front ());
200
241
}
201
242
202
243
bool TileLoopNest::hasOtherUses (BlockArgument bbArg,
@@ -255,38 +296,46 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
255
296
if (!isEmpty ())
256
297
rootOp->replaceAllUsesWith (tiledRootOp->tensorResults );
257
298
299
+ // Transfer the stored `rootOp` loop dimensions if it has been tiled before.
300
+ if (tiledRootAndFusedOpsLoops.count (rootOp) != 0 ) {
301
+ tiledRootAndFusedOpsLoops[tiledRootOp->op ] =
302
+ tiledRootAndFusedOpsLoops[rootOp];
303
+ }
304
+
258
305
// Update the root operation and append the loops and tile loop dimensions.
259
306
rootOp = tiledRootOp->op ;
260
- loopOps .append (tiledRootOp->loops .begin (), tiledRootOp->loops .end ());
307
+ tileLoopOps .append (tiledRootOp->loops .begin (), tiledRootOp->loops .end ());
261
308
for (auto en : enumerate(tileSizes)) {
262
309
// Copy only the tiled loop dimensions with non-zero tile size.
263
310
if (en.value () == 0 )
264
311
continue ;
265
- loopDims .push_back (tileInterchange[en.index ()]);
312
+ tiledRootAndFusedOpsLoops[rootOp] .push_back (tileInterchange[en.index ()]);
266
313
}
267
314
assert (isValid () && " expect tile loop nest to be valid after tiling" );
268
-
269
315
return success ();
270
316
}
271
317
272
318
FailureOr<LinalgOp> TileLoopNest::fuseProducer (OpBuilder &b,
273
- OpOperand *rootOpOperand ) {
274
- assert (rootOpOperand ->getOwner () == rootOp &&
275
- " expect the root op to be the owner of the operand to fuse " );
319
+ OpOperand *consumerOpOperand ) {
320
+ assert (tiledRootAndFusedOpsLoops. count (consumerOpOperand ->getOwner ()) != 0 &&
321
+ " expect the operand owner is the root operation or a fused producer " );
276
322
assert (this ->isValid () &&
277
323
" expect the tile loop nest to satisfy all invariants" );
278
324
279
325
// Check the tile loop nest is non-empty.
280
326
if (isEmpty ())
281
327
return failure ();
282
328
283
- // Check `rootOpOperand` is defined by an ExtractSliceOp.
284
- auto sliceOp = rootOpOperand->get ().getDefiningOp <tensor::ExtractSliceOp>();
329
+ // Check `consumerOpOperand` is defined by an ExtractSliceOp.
330
+ auto sliceOp =
331
+ consumerOpOperand->get ().getDefiningOp <tensor::ExtractSliceOp>();
285
332
if (!sliceOp)
286
333
return failure ();
287
334
288
- // Check `sliceOp` is tiled by the tile loop nest.
289
- if (sliceOp->getParentOp () != rootOp->getParentOp ())
335
+ // Check `sliceOp` and `consumerOp` are in the same block.
336
+ LinalgOp consumerOp = consumerOpOperand->getOwner ();
337
+ if (sliceOp->getBlock () != rootOp->getBlock () ||
338
+ consumerOp->getBlock () != rootOp->getBlock ())
290
339
return failure ();
291
340
292
341
// Check if the producer is a LinalgOp possibly passed by iteration argument.
@@ -302,19 +351,24 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
302
351
if (!producerResult || !isa<LinalgOp>(producerResult.getOwner ()))
303
352
return failure ();
304
353
305
- // Compute the tiled producer slice dimensions given the tiled root operation
306
- // loop dimensions `loopDims`.
307
- SmallVector<int64_t > tiledSliceDims =
308
- getTiledSliceDims (rootOpOperand, loopDims);
309
- if (tiledSliceDims.empty ())
354
+ // Compute the tiled producer slice dimensions given the tiled consumer loops.
355
+ SmallVector<int64_t > tiledSliceDimIndices = getTiledSliceDims (
356
+ consumerOpOperand, tiledRootAndFusedOpsLoops[consumerOp]);
357
+ if (tiledSliceDimIndices.empty ())
310
358
return failure ();
311
359
360
+ // Compute the tiled producer loop indices.
361
+ SmallVector<int64_t > tiledProducerLoopIndices =
362
+ getTiledProducerLoops (producerResult, tiledSliceDimIndices);
363
+
312
364
// Tile the producer operands and clone the producer in place of `sliceOp`.
313
365
LinalgOp clonedOp =
314
- getTiledProducer (b, producerResult, sliceOp, tiledSliceDims, iterArg);
366
+ getTiledProducer (b, producerResult, sliceOp, tiledSliceDimIndices,
367
+ tiledProducerLoopIndices, iterArg);
368
+ tiledRootAndFusedOpsLoops[clonedOp] = tiledProducerLoopIndices;
315
369
316
370
// Cast the `clonedOp` result to gap type mismatches before canonicalization.
317
- Type consumerOperandType = rootOpOperand ->get ().getType ();
371
+ Type consumerOperandType = consumerOpOperand ->get ().getType ();
318
372
Value newResult = clonedOp->getResult (producerResult.getResultNumber ());
319
373
if (newResult.getType () != consumerOperandType) {
320
374
OpBuilder::InsertionGuard guard (b);
@@ -330,7 +384,7 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
330
384
331
385
ValueRange TileLoopNest::getRootOpReplacementResults () {
332
386
assert (!isEmpty () && " expect tile loop nest to be non-empty" );
333
- return loopOps .front ()->getOpResults ();
387
+ return tileLoopOps .front ()->getOpResults ();
334
388
}
335
389
336
390
// ===----------------------------------------------------------------------===//
@@ -359,25 +413,33 @@ mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
359
413
});
360
414
int64_t split = std::distance (iterTypes.begin (), it);
361
415
416
+ // Helper to fuse the producers greedily using a queue of fusion candidates.
417
+ auto fuseProducersGreedily = [&](ArrayRef<OpOperand *> operands) {
418
+ SmallVector<OpOperand *> candidates (operands.begin (), operands.end ());
419
+ while (!candidates.empty ()) {
420
+ FailureOr<LinalgOp> fusedProducer =
421
+ tileLoopNest.fuseProducer (b, candidates.pop_back_val ());
422
+ if (failed (fusedProducer))
423
+ continue ;
424
+ candidates.append (fusedProducer->getInputAndOutputOperands ());
425
+ }
426
+ };
427
+
362
428
// Tile the outer parallel loops and fuse the output operands.
363
429
SmallVector<int64_t > outerTileSizes;
364
430
outerTileSizes.append (tileSizes.begin (), tileSizes.begin () + split);
365
431
outerTileSizes.append (tileSizes.size () - split, 0 );
366
432
if (failed (tileLoopNest.tileRootOp (b, outerTileSizes, tileInterchange)))
367
433
return failure ();
368
- for (OpOperand *opOperand : tileLoopNest.getRootOp ().getOutputOperands ())
369
- (void )tileLoopNest.fuseProducer (b, opOperand);
434
+ fuseProducersGreedily (tileLoopNest.getRootOp ().getOutputOperands ());
370
435
371
436
// Tile the remaining loops and fuse the input operands.
372
437
SmallVector<int64_t > innerTileSizes;
373
438
innerTileSizes.append (split, 0 );
374
439
innerTileSizes.append (tileSizes.begin () + split, tileSizes.end ());
375
440
if (failed (tileLoopNest.tileRootOp (b, innerTileSizes, tileInterchange)))
376
441
return failure ();
377
- SmallVector<OpOperand *> inputOperands =
378
- tileLoopNest.getRootOp ().getInputOperands ();
379
- for (OpOperand *opOperand : tileLoopNest.getRootOp ().getInputOperands ())
380
- (void )tileLoopNest.fuseProducer (b, opOperand);
442
+ fuseProducersGreedily (tileLoopNest.getRootOp ().getInputOperands ());
381
443
382
444
return tileLoopNest;
383
445
}
0 commit comments