@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
55
55
return filledVector;
56
56
}
57
57
58
+ /// Convert a list of ops of type `SrcOpTy` to a list of `Operation *`.
59
+ template <typename SrcOpTy>
60
+ static SmallVector<Operation *> getAsOperations (ArrayRef<SrcOpTy> ops) {
61
+ return llvm::to_vector (
62
+ llvm::map_range (ops, [](auto op) -> Operation * { return op; })); // Each op is returned as an `Operation *`.
63
+ }
64
+ template <typename SrcOpTy>
65
+ static SmallVector<Operation *>
66
+ getAsOperations (const SmallVector<SrcOpTy> &ops) {
67
+ return getAsOperations (ArrayRef<SrcOpTy>(ops)); // Forward to the `ArrayRef` overload.
68
+ }
69
+
70
+ /// Convert a list of `Operation *` to a list of `DstOpTy`.
71
+ template <typename DstOpTy>
72
+ static SmallVector<DstOpTy> castToTypedOperations (ArrayRef<Operation *> ops) {
73
+ return llvm::to_vector (
74
+ llvm::map_range (ops, [](Operation *op) { return cast<DstOpTy>(op); })); // Each element goes through `cast<DstOpTy>`.
75
+ }
76
+ template <typename DstOpTy>
77
+ static SmallVector<DstOpTy>
78
+ castToTypedOperations (const SmallVector<Operation *> &ops) {
79
+ return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops)); // Forward to the `ArrayRef` overload.
80
+ }
81
+
58
82
// ===----------------------------------------------------------------------===//
59
83
// tileUsingSCFForOp implementation.
60
84
// ===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
77
101
// / `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
78
102
static OpFoldResult getBoundedTileSize (OpBuilder &b, Location loc,
79
103
Range loopRange, Value iv,
80
- Value tileSize) {
81
- std::optional<int64_t > ts = getConstantIntValue (tileSize);
82
- if (ts && ts.value () == 1 )
83
- return getAsOpFoldResult (tileSize);
104
+ OpFoldResult tileSize) {
105
+ if (isConstantIntValue (tileSize, 1 ))
106
+ return tileSize;
84
107
85
108
if (tileDividesIterationDomain (
86
109
Range{loopRange.offset , loopRange.size , tileSize}))
@@ -295,8 +318,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
295
318
tileSizeVector.append (numLoops - tileSizeVector.size (), zero);
296
319
}
297
320
298
- scf::SCFTilingResult tilingResult;
299
321
SmallVector<OpFoldResult> offsets, sizes;
322
+ SmallVector<scf::ForOp> forLoops;
300
323
{
301
324
// If there is an interchange specified, permute the iteration domain and
302
325
// the tile sizes.
@@ -319,8 +342,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
319
342
// 3. Materialize an empty loop nest that iterates over the tiles. These
320
343
// loops for now do not return any values even if the original operation has
321
344
// results.
322
- tilingResult. loops = generateTileLoopNest (
323
- rewriter, op. getLoc (), iterationDomain, tileSizeVector, offsets, sizes);
345
+ forLoops = generateTileLoopNest (rewriter, op. getLoc (), iterationDomain,
346
+ tileSizeVector, offsets, sizes);
324
347
325
348
if (!interchangeVector.empty ()) {
326
349
auto inversePermutation = invertPermutationVector (interchangeVector);
@@ -330,30 +353,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
330
353
}
331
354
332
355
LLVM_DEBUG ({
333
- if (!tilingResult. loops .empty ()) {
356
+ if (!forLoops .empty ()) {
334
357
llvm::dbgs () << " LoopNest shell :\n " ;
335
- tilingResult. loops .front ().dump ();
358
+ forLoops .front ().dump ();
336
359
llvm::dbgs () << " \n " ;
337
360
}
338
361
});
339
362
340
363
// 4. Generate the tiled implementation within the inner most loop.
341
- if (!tilingResult.loops .empty ())
342
- rewriter.setInsertionPoint (
343
- tilingResult.loops .back ().getBody ()->getTerminator ());
364
+ if (!forLoops.empty ())
365
+ rewriter.setInsertionPoint (forLoops.back ().getBody ()->getTerminator ());
344
366
FailureOr<TilingResult> tiledImplementation =
345
367
op.getTiledImplementation (rewriter, offsets, sizes);
346
- tilingResult. tiledOps . append (tiledImplementation-> tiledOps );
368
+
347
369
if (op->getNumResults () == 0 ) {
348
- // nothing more to do.
349
- return tilingResult ;
370
+ return scf::SCFTilingResult{
371
+ tiledImplementation-> tiledOps , getAsOperations (forLoops), {}} ;
350
372
}
351
373
352
374
// If loops are empty, the tiled op is used as the replacement for the untiled
353
375
// op.
354
- if (tilingResult.loops .empty ()) {
355
- tilingResult.replacements = tiledImplementation->tiledValues ;
356
- return tilingResult;
376
+ if (forLoops.empty ()) {
377
+ return scf::SCFTilingResult{tiledImplementation->tiledOps ,
378
+ getAsOperations (forLoops),
379
+ tiledImplementation->tiledValues };
357
380
}
358
381
359
382
// 5. Yield all the results of the tiled operation. The surrounding loop
@@ -377,18 +400,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
377
400
destinationTensors)))
378
401
return rewriter.notifyMatchFailure (op, " failed to get destinations" );
379
402
380
- tilingResult. replacements = yieldTiledValues (
403
+ SmallVector<Value> replacements = yieldTiledValues (
381
404
rewriter, destinationTensors, tiledImplementation.value (),
382
- resultOffsetsList, resultSizesList, tilingResult.loops );
383
-
405
+ resultOffsetsList, resultSizesList, forLoops);
384
406
LLVM_DEBUG ({
385
- if (!tilingResult. loops .empty ()) {
407
+ if (!forLoops .empty ()) {
386
408
llvm::dbgs () << " After tiled implementation :\n " ;
387
- tilingResult. loops .front ().dump ();
409
+ forLoops .front ().dump ();
388
410
llvm::dbgs () << " \n " ;
389
411
}
390
412
});
391
- return tilingResult;
413
+ return scf::SCFTilingResult{tiledImplementation->tiledOps ,
414
+ getAsOperations (forLoops), replacements};
392
415
}
393
416
394
417
FailureOr<scf::SCFReductionTilingResult>
@@ -466,6 +489,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
466
489
results.mergeOp = mergeOp;
467
490
return results;
468
491
}
492
+
469
493
// ===----------------------------------------------------------------------===//
470
494
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
471
495
// ===----------------------------------------------------------------------===//
@@ -636,28 +660,31 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
636
660
}
637
661
638
662
// 1. First tile the consumer.
639
- scf::SCFTileAndFuseResult tileAndFuseResult;
663
+ SmallVector<scf::ForOp> forLoops;
664
+ SetVector<Operation *> fusedProducers, tiledAndFusedOps;
665
+ DenseMap<Value, Value> replacements;
640
666
llvm::SmallDenseMap<Value, int64_t > yieldedValueToResultNumber;
641
667
{
642
668
FailureOr<scf::SCFTilingResult> tilingResult =
643
669
tileUsingSCFForOp (rewriter, consumer, options.tilingOptions );
644
670
if (failed (tilingResult))
645
671
return rewriter.notifyMatchFailure (consumer, " failed to tile consumer" );
646
672
for (auto *tiledOp : tilingResult->tiledOps )
647
- tileAndFuseResult.tiledAndFusedOps .insert (tiledOp);
648
- tileAndFuseResult.loops = std::move (tilingResult->loops );
649
- for (const auto &result : llvm::enumerate (
650
- llvm::zip (consumer->getResults (), tilingResult->replacements ))) {
651
- tileAndFuseResult.replacements [std::get<0 >(result.value ())] =
652
- std::get<1 >(result.value ());
673
+ tiledAndFusedOps.insert (tiledOp);
674
+ forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops );
675
+ for (auto [index, origValue, replacement] :
676
+ llvm::enumerate (consumer->getResults (), tilingResult->replacements )) {
677
+ replacements[origValue] = replacement;
653
678
yieldedValueToResultNumber[tilingResult->tiledOps .back ()->getResult (
654
- result. index ()) ] = result. index () ;
679
+ index) ] = index;
655
680
}
656
681
}
657
682
658
683
// If there are no loops generated, fusion is immaterial.
659
- if (tileAndFuseResult.loops .empty ())
660
- return tileAndFuseResult;
684
+ if (forLoops.empty ()) {
685
+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
686
+ getAsOperations (forLoops), replacements};
687
+ }
661
688
662
689
// 2. Typically, the operands of the tiled operation are slices of the
663
690
// operands of the untiled operation. These are expressed in IR using
@@ -674,7 +701,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
674
701
};
675
702
676
703
std::deque<tensor::ExtractSliceOp> candidates;
677
- addCandidateSlices (tileAndFuseResult. tiledAndFusedOps .back (), candidates);
704
+ addCandidateSlices (tiledAndFusedOps.back (), candidates);
678
705
OpBuilder::InsertionGuard g (rewriter);
679
706
while (!candidates.empty ()) {
680
707
// Traverse the slices in BFS fashion.
@@ -684,19 +711,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
684
711
// The operands of the fused producer might themselves be slices of
685
712
// values produced by operations that implement the `TilingInterface`.
686
713
// Add these operations to the worklist.
687
- std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
688
- tileAndFuseProducerOfSlice (rewriter, candidateSliceOp,
689
- tileAndFuseResult.loops );
690
- if (!fusedProducer)
714
+ std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
715
+ tileAndFuseProducerOfSlice (rewriter, candidateSliceOp, forLoops);
716
+ if (!fusedResult)
691
717
continue ;
692
718
693
719
if (Operation *tiledAndFusedOp =
694
- fusedProducer->tiledAndFusedProducer .getDefiningOp ()) {
695
- tileAndFuseResult.tiledAndFusedOps .insert (tiledAndFusedOp);
720
+ fusedResult->tiledAndFusedProducer .getDefiningOp ()) {
721
+ fusedProducers.insert (fusedResult->origProducer .getDefiningOp ());
722
+ tiledAndFusedOps.insert (tiledAndFusedOp);
696
723
addCandidateSlices (tiledAndFusedOp, candidates);
697
724
}
698
725
}
699
- return tileAndFuseResult;
726
+ return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
727
+ getAsOperations (forLoops), replacements};
700
728
}
701
729
702
730
// ===----------------------------------------------------------------------===//
0 commit comments