26
26
#include " mlir/Interfaces/TilingInterface.h"
27
27
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
28
28
#include " mlir/Transforms/RegionUtils.h"
29
-
30
29
#include < llvm/Support/Debug.h>
31
-
32
30
#include < memory>
33
31
34
32
#include " TilingUsingInterfaceX.h"
@@ -306,37 +304,38 @@ struct CandidateSliceFilterPipeLine
306
304
}
307
305
};
308
306
307
+ static FailureOr<int64_t >
308
+ computeTileSizeProductOfCandidate (OffsetSizeAndStrideOpInterface candidate) {
309
+ SmallVector<OpFoldResult> tileSizes = candidate.getMixedSizes ();
310
+ int64_t totalSize = 1 ;
311
+ for (auto &tile : tileSizes) {
312
+ FailureOr<int64_t > cstSize = ValueBoundsConstraintSet::computeConstantBound (
313
+ presburger::BoundType::UB, tile,
314
+ /* stopCondition=*/ nullptr , /* closedUB=*/ true );
315
+ if (failed (cstSize)) {
316
+ return failure ();
317
+ }
318
+ totalSize *= *cstSize;
319
+ };
320
+ return totalSize;
321
+ }
322
+
309
323
static int TilingSizeComparer (RewriterBase &rewriter,
310
324
OffsetSizeAndStrideOpInterface candidateA,
311
325
OffsetSizeAndStrideOpInterface candidateB,
312
326
CandidateDefOrUse defOrUse) {
313
- auto computeTotalSize =
314
- [](OffsetSizeAndStrideOpInterface candidate) -> FailureOr<int64_t > {
315
- SmallVector<OpFoldResult> tileSizes = candidate.getMixedSizes ();
316
- int64_t totalSize = 1 ;
317
- for (auto &tile : tileSizes) {
318
- FailureOr<int64_t > cstSize =
319
- ValueBoundsConstraintSet::computeConstantBound (
320
- presburger::BoundType::UB, tile,
321
- /* stopCondition=*/ nullptr , /* closedUB=*/ true );
322
- if (failed (cstSize)) {
323
- return failure ();
324
- }
325
- totalSize *= *cstSize;
326
- };
327
- return totalSize;
328
- };
329
-
330
- FailureOr<int64_t > totalSizeA = computeTotalSize (candidateA),
331
- totalSizeB = computeTotalSize (candidateB);
332
- if (failed (totalSizeA) || failed (totalSizeB)) {
327
+ FailureOr<int64_t > sizeProductA =
328
+ computeTileSizeProductOfCandidate (candidateA),
329
+ sizeProductB =
330
+ computeTileSizeProductOfCandidate (candidateB);
331
+ if (failed (sizeProductA) || failed (sizeProductB)) {
333
332
return 0 ;
334
333
}
335
334
// deal with equality
336
- if (*totalSizeA == *totalSizeB ) {
335
+ if (*sizeProductA == *sizeProductB ) {
337
336
return 0 ;
338
337
} else {
339
- return *totalSizeA < *totalSizeB ? -1 : 1 ;
338
+ return *sizeProductA < *sizeProductB ? -1 : 1 ;
340
339
}
341
340
}
342
341
@@ -525,22 +524,8 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
525
524
// / Support iterative producer and consumer fusion in BFS fashion.
526
525
LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp (
527
526
RewriterBase &rewriter, Operation *tiledOp,
528
- TargetSystemSpecInterface targetSpec) {
529
- // Flexible options to control which candidate slice would be selected from
530
- // the view of both validity and performance.
531
- CandidateSliceOptions options;
532
- // User-defined filter to control whether to fuse or not. For instance, the
533
- // maximum amount of fused ops is limited to 20(only used for example).
534
- int64_t numTiledOps = 0 ;
535
- CandidateSliceFilter customizedFilter =
536
- [&numTiledOps](RewriterBase &rewriter,
537
- OffsetSizeAndStrideOpInterface candidate,
538
- CandidateDefOrUse defOrUse) -> LogicalResult {
539
- return success (numTiledOps < 20 );
540
- };
541
- // If more than one filters need given, please use filter list instead.
542
- options.addFilter (customizedFilter);
543
-
527
+ const CandidateSliceOptions &options) {
528
+ int numTiledOps = 0 ;
544
529
std::deque<Operation *> tiledOpList = {tiledOp};
545
530
while (!tiledOpList.empty ()) {
546
531
tiledOp = tiledOpList.front ();
@@ -564,7 +549,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
564
549
}
565
550
}
566
551
}
567
- return success (numTiledOps);
552
+ return success (numTiledOps > 1 );
568
553
}
569
554
570
555
// / What is single tiled op in loop?
@@ -647,10 +632,68 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op) {
647
632
}
648
633
}
649
634
650
- void iterativeTilingAndFusion (RewriterBase &rewriter, func::FuncOp &f) {
651
- // Get target descriptor
652
- TargetSystemSpecInterface targetSpec =
653
- mlir::impl::getTargetSystemSpec (f->getParentOfType <ModuleOp>());
635
/// Knobs controlling `iterativeTilingAndFusion`.
struct IterativeFusionOptions {
  /// When true, a cost-model based filter is added to the candidate slice
  /// options used during fusion. Disabled by default.
  bool useCostModel{false};
};
638
+
639
+ struct SystemDesc {
640
+ // get runtime OMP_NUM_THREADS
641
+ uint32_t getNumThreads () {
642
+ std::optional<Attribute> numThreads = layout.getDevicePropertyValue (
643
+ Builder (ctx).getStringAttr (" CPU" /* device ID*/ ),
644
+ Builder (ctx).getStringAttr (" num_threads" ));
645
+ if (numThreads && isa<IntegerAttr>(*numThreads)) {
646
+ return dyn_cast<IntegerAttr>(*numThreads).getInt ();
647
+ }
648
+ return 1 ;
649
+ }
650
+ // get cache size by cacheLevel
651
+ size_t getCacheSize (uint8_t cacheLevel) {
652
+ if (cacheLevel == 1 ) {
653
+ std::optional<Attribute> cacheSize = layout.getDevicePropertyValue (
654
+ Builder (ctx).getStringAttr (" CPU" /* device ID*/ ),
655
+ Builder (ctx).getStringAttr (" L1_cache_size_in_bytes" ));
656
+ if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
657
+ return dyn_cast<IntegerAttr>(*cacheSize).getInt ();
658
+ }
659
+ } else if (cacheLevel == 2 ) {
660
+ std::optional<Attribute> cacheSize = layout.getDevicePropertyValue (
661
+ Builder (ctx).getStringAttr (" CPU" /* device ID*/ ),
662
+ Builder (ctx).getStringAttr (" L2_cache_size_in_bytes" ));
663
+ if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
664
+ return dyn_cast<IntegerAttr>(*cacheSize).getInt ();
665
+ }
666
+ } else if (cacheLevel == 3 ) {
667
+ std::optional<Attribute> cacheSize = layout.getDevicePropertyValue (
668
+ Builder (ctx).getStringAttr (" CPU" /* device ID*/ ),
669
+ Builder (ctx).getStringAttr (" L3_cache_size_in_bytes" ));
670
+ if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
671
+ return dyn_cast<IntegerAttr>(*cacheSize).getInt ();
672
+ }
673
+ }
674
+ return 0 ;
675
+ }
676
+
677
+ // get the maximum vector length in bits
678
+ size_t getMaxVectorLength () {
679
+ std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue (
680
+ Builder (ctx).getStringAttr (" CPU" /* device ID*/ ),
681
+ Builder (ctx).getStringAttr (" max_vector_width" ));
682
+ if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) {
683
+ return dyn_cast<IntegerAttr>(*maxVectorLength).getInt ();
684
+ }
685
+ return 512 ;
686
+ }
687
+
688
+ SystemDesc (ModuleOp m) : layout(m), ctx(m->getContext ()) {}
689
+
690
+ private:
691
+ DataLayout layout;
692
+ MLIRContext *ctx;
693
+ };
694
+
695
+ void iterativeTilingAndFusion (RewriterBase &rewriter, func::FuncOp &f,
696
+ const IterativeFusionOptions &fuseOptions) {
654
697
// Collect untiled and tiled ops respectively
655
698
llvm::SetVector<Operation *> singleTiledOpInLoop, unTiledOps;
656
699
@@ -687,6 +730,30 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f) {
687
730
});
688
731
return !singleTiledOpInLoop.empty ();
689
732
};
733
+
734
+ SystemDesc sysDesc (f->getParentOfType <ModuleOp>());
735
+ // Flexible options to control which candidate slice would be selected from
736
+ // the view of both validity and performance.
737
+ CandidateSliceOptions sliceOptions;
738
+ // Since most filters regarding to validity have already been built-in
739
+ // enabled. Users could focus on performance related filters, a.k.a. cost
740
+ // model.
741
+ if (fuseOptions.useCostModel ) {
742
+ // Customized filter by cost model.
743
+ CandidateSliceFilter costModelFilter =
744
+ [&sysDesc](RewriterBase &rewriter,
745
+ OffsetSizeAndStrideOpInterface candidate,
746
+ CandidateDefOrUse defOrUse) -> LogicalResult {
747
+ // Get cache size
748
+ size_t l2CacheSize = sysDesc.getCacheSize (2 );
749
+ FailureOr<int64_t > tileSizeProduct =
750
+ computeTileSizeProductOfCandidate (candidate);
751
+ return success (succeeded (tileSizeProduct) &&
752
+ (*tileSizeProduct <= (int64_t )l2CacheSize));
753
+ };
754
+ sliceOptions.addFilter (costModelFilter);
755
+ }
756
+
690
757
// Iterative tiling and fusion until exhaustion.
691
758
while (collectUnTiledOps ()) {
692
759
// If existing tiled op before tiling.
@@ -696,10 +763,10 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f) {
696
763
// Record if any fusion happens
697
764
bool changed = false ;
698
765
// Iteratively fuse in forward and backward fashion.
699
- llvm::for_each (singleTiledOpInLoop, [&rewriter, &targetSpec ,
766
+ llvm::for_each (singleTiledOpInLoop, [&rewriter, &sliceOptions ,
700
767
&changed](Operation *tiledOp) {
701
768
changed |= succeeded (iterativelyFuseProducerAndConsumerOfTiledOp (
702
- rewriter, tiledOp, targetSpec ));
769
+ rewriter, tiledOp, sliceOptions ));
703
770
});
704
771
if (!changed) {
705
772
// If no new fusion happens, terminate iteration.
@@ -709,26 +776,21 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f) {
709
776
(void )mlir::runRegionDCE (rewriter, {f.getRegion ()});
710
777
}
711
778
} else {
712
- // Auto tiling with default tile size if no tiled op found.
713
- auto defaultTileContractionOp = [&rewriter](Operation *op) -> bool {
714
- return defaultTilingOfType<mlir::linalg::ContractionOpInterface>(
715
- rewriter, op);
716
- };
717
- auto defaultTileReductionOp = [&rewriter](Operation *op) -> bool {
718
- return defaultTilingOfType<mlir::linalg::ReduceOp>(rewriter, op);
719
- };
720
- auto defaultTileLinalgOp = [&rewriter](Operation *op) -> bool {
721
- return defaultTilingOfType<mlir::linalg::LinalgOp>(rewriter, op);
722
- };
723
- // Follow tiling priority based on OpTy:
724
- // `Contraction`->`Reduction`->`Elementwise`
725
- SmallVector<std::function<bool (Operation *)>> priorityTilingFns = {
726
- defaultTileContractionOp, defaultTileReductionOp,
727
- defaultTileLinalgOp};
728
- if (llvm::all_of (priorityTilingFns,
729
- [&unTiledOps](function_ref<bool (Operation *)> fn) {
730
- return !llvm::any_of (unTiledOps, fn);
731
- })) {
779
+ // Auto tiling with default tile size if no tiled op found. Follow tiling
780
+ // priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
781
+ SmallVector<std::function<bool (RewriterBase &, Operation *)>>
782
+ priorityTilingPipeLine = {
783
+ defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
784
+ defaultTilingOfType<mlir::linalg::ReduceOp>,
785
+ defaultTilingOfType<mlir::linalg::LinalgOp>};
786
+ if (llvm::all_of (
787
+ priorityTilingPipeLine,
788
+ [&rewriter, &unTiledOps](
789
+ function_ref<bool (RewriterBase &, Operation *)> tilingFn) {
790
+ return !llvm::any_of (unTiledOps,
791
+ std::bind (tilingFn, std::ref (rewriter),
792
+ std::placeholders::_1));
793
+ })) {
732
794
// If no op can be tiled
733
795
break ;
734
796
}
@@ -738,6 +800,7 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f) {
738
800
739
801
struct FineGrainedFusion
740
802
: public impl::FineGrainedFusionBase<FineGrainedFusion> {
803
+ using FineGrainedFusionBase::FineGrainedFusionBase;
741
804
742
805
public:
743
806
void runOnOperation () final {
@@ -748,7 +811,8 @@ struct FineGrainedFusion
748
811
// Get rewriter
749
812
IRRewriter rewriter (&ctx);
750
813
// Run iterative fusion
751
- iterativeTilingAndFusion (rewriter, func);
814
+ iterativeTilingAndFusion (rewriter, func,
815
+ IterativeFusionOptions{useCostModel});
752
816
}
753
817
754
818
{
0 commit comments