Skip to content

Commit b548633

Browse files
committed
Add naive cost model support
1 parent 010b77a commit b548633

File tree

3 files changed

+178
-73
lines changed

3 files changed

+178
-73
lines changed

include/gc/Transforms/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ def FineGrainedFusion : Pass<"fine-grained-fusion",
7777
}];
7878
let dependentDialects = ["func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",
7979
"tensor::TensorDialect"];
80+
81+
let options = [
82+
Option<"fusionLevel", "fusion-level", "int64_t",
83+
/*default=*/"1",
84+
"Control the granularity of fusion.">,
85+
Option<"useCostModel", "use-cost-model", "bool",
86+
/*default=*/"false",
87+
"Decide whether to enable the cost model to control iterative fusion.">,
88+
];
8089
}
8190

8291
#endif // GC_DIALECT_GC_PASSES

lib/gc/Transforms/FineGrainedFusion.cpp

Lines changed: 132 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
#include "mlir/Interfaces/TilingInterface.h"
2727
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2828
#include "mlir/Transforms/RegionUtils.h"
29-
3029
#include <llvm/Support/Debug.h>
31-
3230
#include <memory>
3331

3432
#include "TilingUsingInterfaceX.h"
@@ -306,37 +304,38 @@ struct CandidateSliceFilterPipeLine
306304
}
307305
};
308306

307+
static FailureOr<int64_t>
308+
computeTileSizeProductOfCandidate(OffsetSizeAndStrideOpInterface candidate) {
309+
SmallVector<OpFoldResult> tileSizes = candidate.getMixedSizes();
310+
int64_t totalSize = 1;
311+
for (auto &tile : tileSizes) {
312+
FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
313+
presburger::BoundType::UB, tile,
314+
/*stopCondition=*/nullptr, /*closedUB=*/true);
315+
if (failed(cstSize)) {
316+
return failure();
317+
}
318+
totalSize *= *cstSize;
319+
};
320+
return totalSize;
321+
}
322+
309323
static int TilingSizeComparer(RewriterBase &rewriter,
310324
OffsetSizeAndStrideOpInterface candidateA,
311325
OffsetSizeAndStrideOpInterface candidateB,
312326
CandidateDefOrUse defOrUse) {
313-
auto computeTotalSize =
314-
[](OffsetSizeAndStrideOpInterface candidate) -> FailureOr<int64_t> {
315-
SmallVector<OpFoldResult> tileSizes = candidate.getMixedSizes();
316-
int64_t totalSize = 1;
317-
for (auto &tile : tileSizes) {
318-
FailureOr<int64_t> cstSize =
319-
ValueBoundsConstraintSet::computeConstantBound(
320-
presburger::BoundType::UB, tile,
321-
/*stopCondition=*/nullptr, /*closedUB=*/true);
322-
if (failed(cstSize)) {
323-
return failure();
324-
}
325-
totalSize *= *cstSize;
326-
};
327-
return totalSize;
328-
};
329-
330-
FailureOr<int64_t> totalSizeA = computeTotalSize(candidateA),
331-
totalSizeB = computeTotalSize(candidateB);
332-
if (failed(totalSizeA) || failed(totalSizeB)) {
327+
FailureOr<int64_t> sizeProductA =
328+
computeTileSizeProductOfCandidate(candidateA),
329+
sizeProductB =
330+
computeTileSizeProductOfCandidate(candidateB);
331+
if (failed(sizeProductA) || failed(sizeProductB)) {
333332
return 0;
334333
}
335334
// deal with equality
336-
if (*totalSizeA == *totalSizeB) {
335+
if (*sizeProductA == *sizeProductB) {
337336
return 0;
338337
} else {
339-
return *totalSizeA < *totalSizeB ? -1 : 1;
338+
return *sizeProductA < *sizeProductB ? -1 : 1;
340339
}
341340
}
342341

@@ -525,22 +524,8 @@ tileAndFuseConsumerOfOpResult(RewriterBase &rewriter, OpResult result,
525524
/// Support iterative producer and consumer fusion in BFS fashion.
526525
LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
527526
RewriterBase &rewriter, Operation *tiledOp,
528-
TargetSystemSpecInterface targetSpec) {
529-
// Flexible options to control which candidate slice would be selected from
530-
// the view of both validity and performance.
531-
CandidateSliceOptions options;
532-
// User-defined filter to control whether to fuse or not. For instance, the
533-
// maximum amount of fused ops is limited to 20(only used for example).
534-
int64_t numTiledOps = 0;
535-
CandidateSliceFilter customizedFilter =
536-
[&numTiledOps](RewriterBase &rewriter,
537-
OffsetSizeAndStrideOpInterface candidate,
538-
CandidateDefOrUse defOrUse) -> LogicalResult {
539-
return success(numTiledOps < 20);
540-
};
541-
// If more than one filters need given, please use filter list instead.
542-
options.addFilter(customizedFilter);
543-
527+
const CandidateSliceOptions &options) {
528+
int numTiledOps = 0;
544529
std::deque<Operation *> tiledOpList = {tiledOp};
545530
while (!tiledOpList.empty()) {
546531
tiledOp = tiledOpList.front();
@@ -564,7 +549,7 @@ LogicalResult iterativelyFuseProducerAndConsumerOfTiledOp(
564549
}
565550
}
566551
}
567-
return success(numTiledOps);
552+
return success(numTiledOps > 1);
568553
}
569554

570555
/// What is single tiled op in loop?
@@ -647,10 +632,68 @@ static bool defaultTilingOfType(RewriterBase &rewriter, Operation *op) {
647632
}
648633
}
649634

650-
void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f) {
651-
// Get target descriptor
652-
TargetSystemSpecInterface targetSpec =
653-
mlir::impl::getTargetSystemSpec(f->getParentOfType<ModuleOp>());
635+
struct IterativeFusionOptions {
636+
bool useCostModel = false;
637+
};
638+
639+
struct SystemDesc {
640+
// get runtime OMP_NUM_THREADS
641+
uint32_t getNumThreads() {
642+
std::optional<Attribute> numThreads = layout.getDevicePropertyValue(
643+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
644+
Builder(ctx).getStringAttr("num_threads"));
645+
if (numThreads && isa<IntegerAttr>(*numThreads)) {
646+
return dyn_cast<IntegerAttr>(*numThreads).getInt();
647+
}
648+
return 1;
649+
}
650+
// get cache size by cacheLevel
651+
size_t getCacheSize(uint8_t cacheLevel) {
652+
if (cacheLevel == 1) {
653+
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
654+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
655+
Builder(ctx).getStringAttr("L1_cache_size_in_bytes"));
656+
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
657+
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
658+
}
659+
} else if (cacheLevel == 2) {
660+
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
661+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
662+
Builder(ctx).getStringAttr("L2_cache_size_in_bytes"));
663+
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
664+
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
665+
}
666+
} else if (cacheLevel == 3) {
667+
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
668+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
669+
Builder(ctx).getStringAttr("L3_cache_size_in_bytes"));
670+
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
671+
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
672+
}
673+
}
674+
return 0;
675+
}
676+
677+
// get the maximum vector length in bits
678+
size_t getMaxVectorLength() {
679+
std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue(
680+
Builder(ctx).getStringAttr("CPU" /* device ID*/),
681+
Builder(ctx).getStringAttr("max_vector_width"));
682+
if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) {
683+
return dyn_cast<IntegerAttr>(*maxVectorLength).getInt();
684+
}
685+
return 512;
686+
}
687+
688+
SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {}
689+
690+
private:
691+
DataLayout layout;
692+
MLIRContext *ctx;
693+
};
694+
695+
void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f,
696+
const IterativeFusionOptions &fuseOptions) {
654697
// Collect untiled and tiled ops respectively
655698
llvm::SetVector<Operation *> singleTiledOpInLoop, unTiledOps;
656699

@@ -687,6 +730,30 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f) {
687730
});
688731
return !singleTiledOpInLoop.empty();
689732
};
733+
734+
SystemDesc sysDesc(f->getParentOfType<ModuleOp>());
735+
// Flexible options to control which candidate slice would be selected from
736+
// the view of both validity and performance.
737+
CandidateSliceOptions sliceOptions;
738+
// Since most validity-related filters are already enabled as built-ins,
739+
// users can focus on performance-related filters, a.k.a. the cost
740+
// model.
741+
if (fuseOptions.useCostModel) {
742+
// Customized filter by cost model.
743+
CandidateSliceFilter costModelFilter =
744+
[&sysDesc](RewriterBase &rewriter,
745+
OffsetSizeAndStrideOpInterface candidate,
746+
CandidateDefOrUse defOrUse) -> LogicalResult {
747+
// Get cache size
748+
size_t l2CacheSize = sysDesc.getCacheSize(2);
749+
FailureOr<int64_t> tileSizeProduct =
750+
computeTileSizeProductOfCandidate(candidate);
751+
return success(succeeded(tileSizeProduct) &&
752+
(*tileSizeProduct <= (int64_t)l2CacheSize));
753+
};
754+
sliceOptions.addFilter(costModelFilter);
755+
}
756+
690757
// Iterative tiling and fusion until exhaustion.
691758
while (collectUnTiledOps()) {
692759
// If existing tiled op before tiling.
@@ -696,10 +763,10 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f) {
696763
// Record if any fusion happens
697764
bool changed = false;
698765
// Iteratively fuse in forward and backward fashion.
699-
llvm::for_each(singleTiledOpInLoop, [&rewriter, &targetSpec,
766+
llvm::for_each(singleTiledOpInLoop, [&rewriter, &sliceOptions,
700767
&changed](Operation *tiledOp) {
701768
changed |= succeeded(iterativelyFuseProducerAndConsumerOfTiledOp(
702-
rewriter, tiledOp, targetSpec));
769+
rewriter, tiledOp, sliceOptions));
703770
});
704771
if (!changed) {
705772
// If no new fusion happens, terminate iteration.
@@ -709,26 +776,21 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f) {
709776
(void)mlir::runRegionDCE(rewriter, {f.getRegion()});
710777
}
711778
} else {
712-
// Auto tiling with default tile size if no tiled op found.
713-
auto defaultTileContractionOp = [&rewriter](Operation *op) -> bool {
714-
return defaultTilingOfType<mlir::linalg::ContractionOpInterface>(
715-
rewriter, op);
716-
};
717-
auto defaultTileReductionOp = [&rewriter](Operation *op) -> bool {
718-
return defaultTilingOfType<mlir::linalg::ReduceOp>(rewriter, op);
719-
};
720-
auto defaultTileLinalgOp = [&rewriter](Operation *op) -> bool {
721-
return defaultTilingOfType<mlir::linalg::LinalgOp>(rewriter, op);
722-
};
723-
// Follow tiling priority based on OpTy:
724-
// `Contraction`->`Reduction`->`Elementwise`
725-
SmallVector<std::function<bool(Operation *)>> priorityTilingFns = {
726-
defaultTileContractionOp, defaultTileReductionOp,
727-
defaultTileLinalgOp};
728-
if (llvm::all_of(priorityTilingFns,
729-
[&unTiledOps](function_ref<bool(Operation *)> fn) {
730-
return !llvm::any_of(unTiledOps, fn);
731-
})) {
779+
// Auto-tile with default tile sizes if no tiled op is found. Follow tiling
780+
// priority based on OpTy: `Contraction`->`Reduction`->`Elementwise`.
781+
SmallVector<std::function<bool(RewriterBase &, Operation *)>>
782+
priorityTilingPipeLine = {
783+
defaultTilingOfType<mlir::linalg::ContractionOpInterface>,
784+
defaultTilingOfType<mlir::linalg::ReduceOp>,
785+
defaultTilingOfType<mlir::linalg::LinalgOp>};
786+
if (llvm::all_of(
787+
priorityTilingPipeLine,
788+
[&rewriter, &unTiledOps](
789+
function_ref<bool(RewriterBase &, Operation *)> tilingFn) {
790+
return !llvm::any_of(unTiledOps,
791+
std::bind(tilingFn, std::ref(rewriter),
792+
std::placeholders::_1));
793+
})) {
732794
// If no op can be tiled
733795
break;
734796
}
@@ -738,6 +800,7 @@ void iterativeTilingAndFusion(RewriterBase &rewriter, func::FuncOp &f) {
738800

739801
struct FineGrainedFusion
740802
: public impl::FineGrainedFusionBase<FineGrainedFusion> {
803+
using FineGrainedFusionBase::FineGrainedFusionBase;
741804

742805
public:
743806
void runOnOperation() final {
@@ -748,7 +811,8 @@ struct FineGrainedFusion
748811
// Get rewriter
749812
IRRewriter rewriter(&ctx);
750813
// Run iterative fusion
751-
iterativeTilingAndFusion(rewriter, func);
814+
iterativeTilingAndFusion(rewriter, func,
815+
IterativeFusionOptions{useCostModel});
752816
}
753817

754818
{

test/gc/Transform/fine-grained-fusion.mlir

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1-
// RUN: gc-opt --split-input-file -fine-grained-fusion %s
1+
// RUN: gc-opt --split-input-file -fine-grained-fusion %s --cse
22

3-
module {
3+
module attributes {
4+
dlti.target_system_spec = #dlti.target_system_spec<
5+
"CPU": #dlti.target_device_spec<
6+
#dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
7+
#dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
8+
#dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
9+
#dlti.dl_entry<"num_threads", 56 : i32>,
10+
#dlti.dl_entry<"max_vector_width", 512 : i32>>
11+
>} {
412
func.func @fuse_mlp(%arg0: tensor<128x512xbf16>, %arg1: tensor<32x8x16x32xbf16>, %arg2: tensor<256xbf16>) -> tensor<128x256xbf16> {
513
%c32 = arith.constant 32 : index
614
%c512 = arith.constant 512 : index
@@ -65,7 +73,15 @@ module {
6573
// -----
6674

6775
#map = affine_map<(d0) -> (d0 * 128)>
68-
module {
76+
module attributes {
77+
dlti.target_system_spec = #dlti.target_system_spec<
78+
"CPU": #dlti.target_device_spec<
79+
#dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
80+
#dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
81+
#dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
82+
#dlti.dl_entry<"num_threads", 56 : i32>,
83+
#dlti.dl_entry<"max_vector_width", 512 : i32>>
84+
>} {
6985
func.func @fuse_multiple_consumer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>, %arg3: tensor<256x256xf32>) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
7086
%c0 = arith.constant 0 : index
7187
%c64 = arith.constant 64 : index
@@ -103,7 +119,15 @@ module {
103119
// -----
104120

105121
#map = affine_map<(d0) -> (d0 * 128)>
106-
module {
122+
module attributes {
123+
dlti.target_system_spec = #dlti.target_system_spec<
124+
"CPU": #dlti.target_device_spec<
125+
#dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
126+
#dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
127+
#dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
128+
#dlti.dl_entry<"num_threads", 56 : i32>,
129+
#dlti.dl_entry<"max_vector_width", 512 : i32>>
130+
>} {
107131
func.func @fuse_reduce(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256xf32> {
108132
%c0 = arith.constant 0 : index
109133
%c64 = arith.constant 64 : index
@@ -142,7 +166,15 @@ module {
142166

143167
// -----
144168

145-
module {
169+
module attributes {
170+
dlti.target_system_spec = #dlti.target_system_spec<
171+
"CPU": #dlti.target_device_spec<
172+
#dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
173+
#dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
174+
#dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
175+
#dlti.dl_entry<"num_threads", 56 : i32>,
176+
#dlti.dl_entry<"max_vector_width", 512 : i32>>
177+
>} {
146178
func.func @fuse_with_default_tiling(%arg0: tensor<128x256x256xf32>, %arg1: tensor<128x256x256xf32>) -> tensor<128x256xf32> {
147179
%dest0 = tensor.empty() : tensor<128x256x256xf32>
148180
%0 = linalg.add ins(%arg0, %arg1 : tensor<128x256x256xf32>, tensor<128x256x256xf32>) outs(%dest0 : tensor<128x256x256xf32>) -> tensor<128x256x256xf32>

0 commit comments

Comments
 (0)