Skip to content

Commit 8816159

Browse files
committed
add merge forall pass
1 parent cf580cd commit 8816159

File tree

7 files changed

+238
-350
lines changed

7 files changed

+238
-350
lines changed

include/gc/Transforms/Passes.td

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,19 @@ def DeepTileContractionNamedOp
6060
];
6161
}
6262

63+
def SinkOpIntoInnerLoop : Pass<"sink-op-into-inner-loop"> {
64+
let summary = "Sink operations into inner loops";
65+
let description = [{The pass tries to sink operations into inner loops as deep as possible to maximize the chance for outer loop optimization.
66+
}];
67+
let dependentDialects = [];
68+
}
69+
70+
def MergeNestedForall : Pass<"merge-nested-forall"> {
71+
let summary = "Merge nested scf.forall operations";
72+
let description = [{The pass tries to merge nested forall operations.}];
73+
let dependentDialects = ["scf::SCFDialect"];
74+
}
75+
6376
def GCCPUPipeline : Pass<"gc-cpu-pipeline"> {
6477
let summary = "All-in-one pipeline for GC for CPU";
6578
let dependentDialects = [

lib/gc/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ add_mlir_library(GCPasses
1414
Pipeline.cpp
1515
DeepTileContractionNamedOp.cpp
1616
Tiling.cpp
17+
SinkOpIntoInnerLoop.cpp
18+
MergeNestedForall.cpp
1719

1820
ADDITIONAL_HEADER_DIRS
1921
${PROJECT_SOURCE_DIR}/include

lib/gc/Transforms/DeepTileContractionNamedOp.cpp

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -464,30 +464,36 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
464464
currentOp.getNumLoops(), getAsIndexOpFoldResult(b.getContext(), 0));
465465
SmallVector<unsigned> reductionDims;
466466
currentOp.getReductionDims(reductionDims);
467+
bool tileOnReduction = false;
467468
for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) {
469+
if (llvm::find(reductionDims, d) != reductionDims.end()) {
470+
tileOnReduction = true;
471+
}
468472
if (llvm::find(reductionDims, d) != reductionDims.end() &&
469-
!dyn_cast<PartialReductionOpInterface>(currentOp.getOperation()))
473+
!dyn_cast<PartialReductionOpInterface>(currentOp.getOperation())) {
470474
tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), 0);
471-
else
475+
tileOnReduction = false;
476+
} else
472477
tileSizes[d] = getAsIndexOpFoldResult(b.getContext(), tile);
473478
}
474479
SmallVector<Range> loopRanges =
475480
cast<TilingInterface>(currentOp.getOperation()).getIterationDomain(b);
476481
OpBuilder::InsertionGuard guard(b);
477482
b.setInsertionPoint(currentOp);
478-
if (auto partialInterface =
479-
dyn_cast<PartialReductionOpInterface>(currentOp.getOperation())) {
483+
if (tileOnReduction) {
484+
auto partialInterface =
485+
dyn_cast<PartialReductionOpInterface>(currentOp.getOperation());
480486
for (auto [idx, tile] : llvm::enumerate(tileSizes)) {
481-
if (isConstantIntValue(tile, 0)) {
487+
if (isConstantIntValue(tile, 0) &&
488+
llvm::find(reductionDims, d) != reductionDims.end()) {
482489
tileSizes[idx] = loopRanges[idx].size;
483490
}
484491
}
485-
486492
SmallVector<OpFoldResult> newParallelDims;
487493
for (auto i = 0UL; i < reductionDims.size(); i++) {
488494
newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i));
489495
}
490-
auto tilingResult = linalgX::tileAllUsingForall(
496+
auto tilingResult = linalgX::tileReductionUsingForall(
491497
b, cast<PartialReductionOpInterface>(currentOp.getOperation()), {},
492498
tileSizes, newParallelDims, std::nullopt);
493499
if (failed(tilingResult) &&
@@ -503,8 +509,8 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
503509
}
504510
}
505511
}
506-
} else if (auto tilingInterface =
507-
cast<TilingInterface>(currentOp.getOperation())) {
512+
} else {
513+
auto tilingInterface = cast<TilingInterface>(currentOp.getOperation());
508514
auto tilingResult = linalg::tileToForallOpUsingTileSizes(
509515
b, tilingInterface, tileSizes, std::nullopt);
510516
if (failed(tilingResult))
@@ -597,11 +603,15 @@ struct deepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
597603
? (cfg.NBlock - 1) / cfg.innerMostNBlock + 1
598604
: cfg.NBlock;
599605
// Outer
600-
option.nestedTileSizes.emplace_back(SmallVector<size_t>{
601-
MParallelBlockSize, NParallelBlockSize, KParallelBlockSize});
602-
option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForallOp);
603-
option.loopDim.emplace_back(
604-
SmallVector<size_t>{MDimPos[0], NDimPos[0], KDimPos[0]});
606+
for (auto [tile, dim] :
607+
llvm::zip(SmallVector<size_t>{KParallelBlockSize, MParallelBlockSize,
608+
NParallelBlockSize},
609+
SmallVector<size_t>{KDimPos[0], MDimPos[0], NDimPos[0]})) {
610+
option.nestedTileSizes.emplace_back(SmallVector<size_t>{tile});
611+
option.loopType.emplace_back(
612+
OuterLoopGenerationOption::LoopType::ForallOp);
613+
option.loopDim.emplace_back(SmallVector<size_t>{dim});
614+
}
605615
// Middle
606616
for (auto [tile, dim] :
607617
llvm::zip(SmallVector<size_t>{MOuterBlockSize, NOuterBlockSize,
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
//===-- MergeNestedForall.cpp - DESC -------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Transforms/Passes.h"
10+
11+
#include "mlir/Dialect/SCF/IR/SCF.h"
12+
#include "mlir/IR/Dominance.h"
13+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
14+
#include "mlir/Interfaces/LoopLikeInterface.h"
15+
#include "mlir/Interfaces/SideEffectInterfaces.h"
16+
#include "mlir/Transforms/ControlFlowSinkUtils.h"
17+
18+
namespace mlir {
19+
namespace gc {
20+
#define GEN_PASS_DEF_MERGENESTEDFORALL
21+
#include "gc/Transforms/Passes.h.inc"
22+
23+
namespace {
24+
25+
struct MergeNestedForallLoops : public OpRewritePattern<scf::ForallOp> {
26+
using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
27+
28+
LogicalResult matchAndRewrite(scf::ForallOp op,
29+
PatternRewriter &rewriter) const override {
30+
Block &outerBody = *op.getBody();
31+
if (!llvm::hasSingleElement(outerBody.without_terminator()))
32+
return failure();
33+
34+
auto innerOp = dyn_cast<scf::ForallOp>(outerBody.front());
35+
if (!innerOp)
36+
return failure();
37+
38+
for (auto val : outerBody.getArguments())
39+
if (llvm::is_contained(innerOp.getDynamicLowerBound(), val) ||
40+
llvm::is_contained(innerOp.getDynamicUpperBound(), val) ||
41+
llvm::is_contained(innerOp.getDynamicStep(), val))
42+
return failure();
43+
44+
// Reductions are not supported yet.
45+
if (!op.getInits().empty() || !innerOp.getInits().empty())
46+
return failure();
47+
48+
auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
49+
ValueRange iterVals) {
50+
Block &innerBody = *innerOp.getBody();
51+
assert(iterVals.size() ==
52+
(outerBody.getNumArguments() + innerBody.getNumArguments()));
53+
IRMapping mapping;
54+
mapping.map(outerBody.getArguments(),
55+
iterVals.take_front(outerBody.getNumArguments()));
56+
mapping.map(innerBody.getArguments(),
57+
iterVals.take_back(innerBody.getNumArguments()));
58+
for (Operation &op : innerBody)
59+
builder.clone(op, mapping);
60+
};
61+
62+
auto concatValues = [](const auto &first, const auto &second) {
63+
SmallVector<OpFoldResult> ret;
64+
ret.reserve(first.size() + second.size());
65+
ret.assign(first.begin(), first.end());
66+
ret.append(second.begin(), second.end());
67+
return ret;
68+
};
69+
70+
auto newLowerBounds =
71+
concatValues(op.getMixedLowerBound(), innerOp.getMixedLowerBound());
72+
auto newUpperBounds =
73+
concatValues(op.getMixedUpperBound(), innerOp.getMixedUpperBound());
74+
auto newSteps = concatValues(op.getMixedStep(), innerOp.getMixedStep());
75+
rewriter.replaceOpWithNewOp<scf::ForallOp>(
76+
op, newLowerBounds, newUpperBounds, newSteps, ValueRange{},
77+
std::nullopt, bodyBuilder);
78+
return success();
79+
}
80+
};
81+
82+
struct MergeNestedForall
83+
: public impl::MergeNestedForallBase<MergeNestedForall> {
84+
public:
85+
void runOnOperation() final {
86+
auto &ctx = getContext();
87+
RewritePatternSet patterns(&ctx);
88+
89+
patterns.add<MergeNestedForallLoops>(patterns.getContext());
90+
91+
if (failed(applyPatternsAndFoldGreedily(getOperation(),
92+
std::move(patterns)))) {
93+
return signalPassFailure();
94+
}
95+
}
96+
};
97+
98+
} // namespace
99+
} // namespace gc
100+
} // namespace mlir

lib/gc/Transforms/Pipeline.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@
3232

3333
namespace mlir::gc {
3434

35+
void populateCleanUpPasses(mlir::PassManager &pm) {
36+
pm.addPass(createCanonicalizerPass());
37+
pm.addPass(createCSEPass());
38+
pm.addPass(createLoopInvariantCodeMotionPass());
39+
pm.addPass(createControlFlowSinkPass());
40+
pm.addPass(createCSEPass());
41+
pm.addPass(createSCCPPass());
42+
pm.addPass(createMem2Reg());
43+
pm.addPass(createTopologicalSortPass());
44+
}
45+
3546
// linalg + linalgX + tensor
3647
void populateFrontendPasses(mlir::OpPassManager &pm) {
3748
pm.addPass(createConvertOneDNNGraphToLinalg());
@@ -42,13 +53,14 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
4253
// todo: padding propagation pass
4354
// todo: layout propagation pass
4455
// todo: tensor constant propagation pass
45-
// todo: linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
56+
pm.addNestedPass<func::FuncOp>(createDeepTileContractionNamedOp());
4657
// todo: fine-grain fusion pass
4758
// todo: lower linalg to arith/math on virtual vector pass
4859

4960
// REMOVE this pass after the above passes are added. Currently we add this
5061
// pass to make the pipeline work properly
5162
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
63+
populateCleanUpPasses(pm);
5264
}
5365

5466
// scf + arith + math + vector + tensor + linalg.brgemm
@@ -67,6 +79,7 @@ void populateVectorPasses(mlir::OpPassManager &pm) {
6779
// oneDNN graph spec
6880
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
6981
// todo: lower to physical vector pass, device dependent pass
82+
populateCleanUpPasses(pm);
7083
}
7184

7285
// scf + arith + math + vector + memref + linalg.brgemm
@@ -86,6 +99,7 @@ void populateBufferizationPasses(mlir::OpPassManager &pm) {
8699
pm.addNestedPass<func::FuncOp>(bufferization::createBufferLoopHoistingPass());
87100
pm.addNestedPass<func::FuncOp>(bufferization::createBufferDeallocationPass());
88101
pm.addPass(createBufferizationToMemRefPass());
102+
populateCleanUpPasses(pm);
89103
}
90104

91105
// scf + arith + math + vector + memref + func/microkernel
@@ -102,6 +116,12 @@ void populateMicroKernelPasses(mlir::OpPassManager &pm) {
102116
void populateCPURuntimePasses(mlir::OpPassManager &pm) {
103117
// todo: flatten nested parallel pass to support coarse-grain usion
104118
// remove this pass after we add FlattenNestedParallel
119+
pm.addPass(createSinkOpIntoInnerLoop());
120+
pm.addPass(createMergeNestedForall());
121+
populateCleanUpPasses(pm);
122+
pm.addPass(createForallToParallelLoopPass());
123+
pm.addPass(createParallelLoopFusionPass());
124+
pm.addPass(createLoopInvariantCodeMotionPass());
105125
pm.addPass(createConvertSCFToOpenMPPass());
106126
}
107127

@@ -141,7 +161,7 @@ void populateCPUPipeline(mlir::OpPassManager &pm) {
141161
pm.addNestedPass<func::FuncOp>(createConvertLinalgToParallelLoopsPass());
142162
populateMicroKernelPasses(pm);
143163
populateCPURuntimePasses(pm);
144-
// // back-end, llvm dialect
164+
// back-end, llvm dialect
145165
populateLLVMPasses(pm);
146166
}
147167

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//===-- SinkOpIntoInnerLoop.cpp - DESC -------------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Transforms/Passes.h"
10+
11+
#include "mlir/IR/Dominance.h"
12+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
13+
#include "mlir/Interfaces/LoopLikeInterface.h"
14+
#include "mlir/Interfaces/SideEffectInterfaces.h"
15+
#include "mlir/Transforms/ControlFlowSinkUtils.h"
16+
17+
namespace mlir {
18+
namespace gc {
19+
#define GEN_PASS_DEF_SINKOPINTOINNERLOOP
20+
#include "gc/Transforms/Passes.h.inc"
21+
22+
namespace {
23+
24+
struct SinkOpIntoInnerLoop
25+
: public impl::SinkOpIntoInnerLoopBase<SinkOpIntoInnerLoop> {
26+
public:
27+
void runOnOperation() final {
28+
auto &domInfo = getAnalysis<DominanceInfo>();
29+
getOperation()->walk([&](LoopLikeOpInterface loop) {
30+
SmallVector<Region *> regionsToSink;
31+
// Get the regions are that known to be executed at most once.
32+
for (auto &it : loop->getRegions()) {
33+
regionsToSink.push_back(&it);
34+
}
35+
// Sink side-effect free operations.
36+
controlFlowSink(
37+
regionsToSink, domInfo,
38+
[](Operation *op, Region *) { return isMemoryEffectFree(op); },
39+
[](Operation *op, Region *region) {
40+
// Move the operation to the beginning of the region's entry block.
41+
// This guarantees the preservation of SSA dominance of all of the
42+
// operation's uses are in the region.
43+
op->moveBefore(&region->front(), region->front().begin());
44+
});
45+
});
46+
}
47+
};
48+
49+
} // namespace
50+
} // namespace gc
51+
} // namespace mlir

0 commit comments

Comments
 (0)