@@ -311,80 +311,50 @@ struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp
311
311
// Collect list of operations that can be tiled and fused.
312
312
llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
313
313
collectTiledAndFusedOps (rootOp);
314
- auto isIgnoredUser = [&]( Operation *user, scf::ForOp outerMostTiledLoop) {
315
- return tiledAndFusedOps. count (user) || isa<tensor::DimOp>( user) ||
316
- outerMostTiledLoop-> isAncestor (user);
314
+ llvm::SmallDenseMap< Operation *, bool > yielded;
315
+ auto isIgnoredUser = [&](Operation * user) {
316
+ return tiledAndFusedOps. count (user) || isa<tensor::DimOp> (user);
317
317
};
318
-
319
- // The rest of this method is similar to
320
- // scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp, except that also
321
- // yields replacements for values of the fused producer.
322
-
323
- // 1. Tile the consumer.
324
- SmallVector<OpResult> yieldedValuesToOrigValues;
325
- FailureOr<scf::SCFTilingResult> tilingResult =
326
- scf::tileUsingSCFForOp (rewriter, rootOp, options);
327
- if (failed (tilingResult)) {
328
- return rewriter.notifyMatchFailure (rootOp,
329
- " failed to tile base operation" );
318
+ for (Operation *op : tiledAndFusedOps) {
319
+ yielded[op] = llvm::any_of (op->getUsers (), [&](Operation *user) {
320
+ return !isIgnoredUser (user);
321
+ });
330
322
}
331
- yieldedValuesToOrigValues.append (rootOp->result_begin (),
332
- rootOp->result_end ());
333
-
334
- // 2. Tiling each operation results in generation of slices. The source of
335
- // these slices could be producers that can be fused into the tiled loops by
336
- // computing the slices of these producers in-place. This results in more
337
- // slices created for operands of the "fused producer". This open up more
338
- // opportunities for fusion. Use a worklist to fuse greedily.
339
- auto addCandidateSlices =
340
- [](Operation *fusedOp, std::deque<tensor::ExtractSliceOp> &candidates) {
341
- for (Value operand : fusedOp->getOperands ())
342
- if (auto sliceOp = operand.getDefiningOp <tensor::ExtractSliceOp>())
343
- candidates.push_back (sliceOp);
344
- };
345
323
346
- std::deque<tensor::ExtractSliceOp> candidates;
347
- addCandidateSlices (tilingResult->tiledOps .back (), candidates);
348
- OpBuilder::InsertionGuard g (rewriter);
349
- auto forLoops = llvm::to_vector (llvm::map_range (
350
- tilingResult->loops , [](auto op) { return cast<scf::ForOp>(op); }));
351
- while (!candidates.empty ()) {
352
- // Traverse the slices in BFS fashion.
353
- tensor::ExtractSliceOp candidateSliceOp = candidates.front ();
354
- candidates.pop_front ();
355
-
356
- // Materialize the slice of the producer in place.
357
- std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
358
- tileAndFuseProducerOfSlice (rewriter, candidateSliceOp, forLoops);
359
- if (!fusedProducer)
360
- continue ;
361
-
362
- // Check if the fused producer has other uses that require the value
363
- // to be yielded from within the tiled loop.
364
- OpResult untiledProducer = fusedProducer->origProducer ;
365
- if (llvm::any_of (untiledProducer.getUsers (), [&](Operation *user) {
366
- return !isIgnoredUser (user, forLoops.front ());
367
- })) {
368
- yieldReplacementForFusedProducer (rewriter, candidateSliceOp,
369
- fusedProducer.value (), forLoops);
370
- yieldedValuesToOrigValues.push_back (untiledProducer);
371
- }
324
+ scf::SCFTileAndFuseOptions tileAndFuseOptions;
325
+ tileAndFuseOptions.setTilingOptions (options);
326
+ scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
327
+ [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
328
+ bool isDestinationOperand) {
329
+ Operation *owner = originalProducer.getOwner ();
330
+ return std::make_tuple (true ,
331
+ yielded.contains (owner) && yielded[owner]);
332
+ };
333
+ tileAndFuseOptions.setFusionControlFn (controlFn);
372
334
373
- // Add more fusion candidates to the worklist.
374
- if (auto fusedProducerOp =
375
- fusedProducer->tiledAndFusedProducer .getDefiningOp ())
376
- addCandidateSlices (fusedProducerOp, candidates);
335
+ FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
336
+ scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp (
337
+ rewriter, rootOp, tileAndFuseOptions);
338
+ if (failed (tileAndFuseResult)) {
339
+ return rewriter.notifyMatchFailure (
340
+ rootOp, " failed to tile and fuse with op as root" );
377
341
}
378
342
379
- scf::ForOp outermostLoop = forLoops. front ();
380
- for ( auto [index, origVal] : llvm::enumerate (yieldedValuesToOrigValues)) {
381
- Value replacement = outermostLoop. getResult (index) ;
343
+ for ( auto it : tileAndFuseResult-> replacements ) {
344
+ Value origVal = it. first ;
345
+ Value replacement = it. second ;
382
346
rewriter.replaceUsesWithIf (origVal, replacement, [&](OpOperand &use) {
383
- return !isIgnoredUser (use.getOwner (), outermostLoop);
347
+ Operation *user = use.getOwner ();
348
+ return !isIgnoredUser (user) &&
349
+ !tileAndFuseResult->loops .front ()->isAncestor (user);
384
350
});
385
351
}
352
+
386
353
rewriter.eraseOp (rootOp);
387
- filter.replaceTransformationFilter (rewriter, tilingResult->tiledOps .back ());
354
+ for (auto tiledAndFusedOp : tileAndFuseResult->tiledAndFusedOps )
355
+ if (tiledAndFusedOp->hasAttr (kTransformMarker ))
356
+ filter.replaceTransformationFilter (rewriter, tiledAndFusedOp);
357
+
388
358
return success ();
389
359
}
390
360
0 commit comments