Skip to content

Commit cb77f94

Browse files
author
ZhangYan
committed
Merge remote-tracking branch 'origin/main' into zhangyan/fix_perf
2 parents 5f51e49 + cad8a29 commit cb77f94

File tree

3 files changed

+427
-17
lines changed

3 files changed

+427
-17
lines changed

lib/gc/Transforms/MergeAllocTickBased.cpp

Lines changed: 208 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,65 @@ using namespace special_ticks;
3030
/// and default memory space.
3131
static bool isMemRefTypeOk(MemRefType type) { return type.hasStaticShape(); }
3232

33+
static inline int64_t getSizeInBytes(MemRefType &memType) {
34+
// treat bool (i1) as 1 byte. It may not be true for all targets, but we at
35+
// least have a large enough size for i1
36+
int64_t size = memType.getElementTypeBitWidth() / 8;
37+
size = (size > 0) ? size : 1;
38+
for (auto v : memType.getShape()) {
39+
size *= v;
40+
}
41+
return size;
42+
}
43+
44+
static bool needsHoistOutOfParallelLoop(Operation *op) {
45+
Operation *parent =
46+
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
47+
if (isa_and_nonnull<scf::ForallOp>(parent)) {
48+
// check if the current allocation is between the nested pfor, and use
49+
// inside the inner parallel loop
50+
SmallVector<Operation *, 4> parallelOpInCurBlock;
51+
Block *curBlock = op->getBlock();
52+
for (auto &curOp : curBlock->getOperations()) {
53+
if (isa<scf::ForallOp>(curOp)) {
54+
parallelOpInCurBlock.push_back(&curOp);
55+
}
56+
}
57+
58+
if (parallelOpInCurBlock.empty())
59+
return false;
60+
61+
for (auto *use : op->getUsers()) {
62+
for (auto *parallelOp : parallelOpInCurBlock) {
63+
if (parallelOp->isAncestor(use)) {
64+
return true;
65+
}
66+
}
67+
}
68+
}
69+
70+
return false;
71+
}
72+
73+
static bool isForallLoopBoundStatic(Operation *op) {
74+
auto forallOp = dyn_cast<scf::ForallOp>(op);
75+
if (!forallOp)
76+
return false;
77+
78+
auto lbs = forallOp.getMixedLowerBound();
79+
auto ubs = forallOp.getMixedUpperBound();
80+
auto steps = forallOp.getMixedStep();
81+
auto allConstantValue = [](SmallVector<OpFoldResult> vals) -> bool {
82+
return llvm::all_of(vals, [](OpFoldResult val) {
83+
std::optional<int64_t> const_val = getConstantIntValue(val);
84+
return const_val.has_value();
85+
});
86+
};
87+
88+
return allConstantValue(lbs) && allConstantValue(ubs) &&
89+
allConstantValue(steps);
90+
}
91+
3392
void Tick::update(int64_t tick) {
3493
if (tick == UNTRACEABLE_ACCESS) {
3594
firstAccess = UNTRACEABLE_ACCESS;
@@ -180,28 +239,60 @@ bool TickCollecter::isMergeableAlloc(TickCollecterStates *s, Operation *op,
180239
// trait, and is not scf.for
181240
Operation *TickCollecter::getAllocScope(TickCollecterStates *s,
182241
Operation *op) const {
183-
auto parent = op;
242+
Operation *parent = op;
243+
bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop(op);
244+
184245
for (;;) {
185246
parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
186247
if (!parent) {
187248
return nullptr;
188249
}
189-
if (!isa<scf::ForOp>(parent)) {
190-
return parent;
191-
}
250+
251+
if (isa<scf::ForOp>(parent))
252+
continue;
253+
254+
if (isa<scf::ForallOp>(parent) &&
255+
(moveToUpperParellelLoop && isForallLoopBoundStatic(parent)))
256+
continue;
257+
258+
return parent;
192259
}
193260
}
194261

195262
FailureOr<size_t> TickCollecter::getAllocSize(TickCollecterStates *s,
196263
Operation *op) const {
197264
auto refType = cast<MemRefType>(op->getResultTypes().front());
198-
int64_t size = refType.getElementTypeBitWidth() / 8;
199-
// treat bool (i1) as 1 byte. It may not be true for all targets, but we at
200-
// least have a large enough size for i1
201-
size = (size != 0) ? size : 1;
202-
for (auto v : refType.getShape()) {
203-
size *= v;
265+
266+
// Get the total number of threads from the outermost to the current level of
267+
// the parallel loop that the allocation located in.
268+
int64_t numThreads = 1;
269+
if (needsHoistOutOfParallelLoop(op)) {
270+
Operation *parent =
271+
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
272+
while (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
273+
if (!isForallLoopBoundStatic(forallOp))
274+
break;
275+
276+
OpBuilder builder{forallOp->getContext()};
277+
std::optional<int64_t> numIterations;
278+
for (auto [lb, ub, step] : llvm::zip(forallOp.getLowerBound(builder),
279+
forallOp.getUpperBound(builder),
280+
forallOp.getStep(builder))) {
281+
numIterations = constantTripCount(lb, ub, step);
282+
if (numIterations.has_value()) {
283+
numThreads *= numIterations.value();
284+
} else {
285+
return op->emitError("Expecting static loop range!");
286+
}
287+
}
288+
289+
parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
290+
}
204291
}
292+
assert(numThreads > 0);
293+
294+
int64_t size = getSizeInBytes(refType);
295+
size *= numThreads;
205296
if (size > 0) {
206297
return static_cast<size_t>(size);
207298
}
@@ -391,11 +482,113 @@ Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Block *scope,
391482
Value mergedAlloc,
392483
int64_t byteOffset) const {
393484
builder.setInsertionPoint(origAllocOp);
394-
auto byteShift =
395-
builder.create<arith::ConstantIndexOp>(origAllocOp->getLoc(), byteOffset);
396-
return builder.create<memref::ViewOp>(origAllocOp->getLoc(),
397-
origAllocOp->getResultTypes().front(),
398-
mergedAlloc, byteShift, ValueRange{});
485+
auto loc = origAllocOp->getLoc();
486+
auto byteShift = builder.create<arith::ConstantIndexOp>(loc, byteOffset);
487+
488+
bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop(origAllocOp);
489+
Operation *parent =
490+
origAllocOp->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
491+
if (!moveToUpperParellelLoop || !parent || !isa<scf::ForallOp>(parent))
492+
return builder.create<memref::ViewOp>(loc,
493+
origAllocOp->getResultTypes().front(),
494+
mergedAlloc, byteShift, ValueRange{});
495+
496+
// get the aggregated inductorVar
497+
Value inductVar;
498+
bool isOuterMostLoop = true;
499+
int64_t innerLoopUpperBound = 1;
500+
while (parent) {
501+
if (auto forallOp = dyn_cast<scf::ForallOp>(parent)) {
502+
if (isForallLoopBoundStatic(forallOp)) {
503+
SmallVector<Value> ubs = forallOp.getUpperBound(builder);
504+
SmallVector<Value> lbs = forallOp.getLowerBound(builder);
505+
SmallVector<Value> steps = forallOp.getStep(builder);
506+
SmallVector<Value> inductionVars = forallOp.getInductionVars();
507+
508+
auto getCurrentVar = [&loc, &builder](Value var, Value lb,
509+
Value step) -> Value {
510+
if (!isConstantIntValue(lb, 0))
511+
var = builder.create<arith::SubIOp>(loc, var, lb);
512+
513+
if (!isConstantIntValue(step, 1))
514+
var = builder.create<arith::DivSIOp>(loc, var, step);
515+
return var;
516+
};
517+
518+
auto getAggregatedVar =
519+
[&loc, &builder, &getCurrentVar](
520+
const SmallVector<Value> &_lbs, const SmallVector<Value> &_ubs,
521+
const SmallVector<Value> &_steps,
522+
const SmallVector<Value> &_inductVars) -> Value {
523+
Value var;
524+
if (_ubs.size() == 1) {
525+
var = getCurrentVar(_inductVars[0], _lbs[0], _steps[0]);
526+
return var;
527+
} else {
528+
bool isFirstLoop = true;
529+
for (auto [lb, ub, step, inductVar] :
530+
llvm::zip(_lbs, _ubs, _steps, _inductVars)) {
531+
if (isFirstLoop) {
532+
var = getCurrentVar(inductVar, lb, step);
533+
isFirstLoop = false;
534+
} else {
535+
Value cur_var = getCurrentVar(inductVar, lb, step);
536+
std::optional<int64_t> bound = constantTripCount(lb, ub, step);
537+
assert(bound.has_value());
538+
Value boundVal =
539+
builder.create<arith::ConstantIndexOp>(loc, bound.value());
540+
Value tmpVal =
541+
builder.create<arith::MulIOp>(loc, var, boundVal);
542+
var = builder.create<arith::AddIOp>(loc, tmpVal, cur_var);
543+
}
544+
}
545+
return var;
546+
}
547+
};
548+
549+
if (isOuterMostLoop) {
550+
inductVar = getAggregatedVar(lbs, ubs, steps, inductionVars);
551+
isOuterMostLoop = false;
552+
} else {
553+
Value currentVar = getAggregatedVar(lbs, ubs, steps, inductionVars);
554+
555+
Value innerLoopBoundVal =
556+
builder.create<arith::ConstantIndexOp>(loc, innerLoopUpperBound);
557+
Value intermediateVal =
558+
builder.create<arith::MulIOp>(loc, currentVar, innerLoopBoundVal);
559+
inductVar =
560+
builder.create<arith::AddIOp>(loc, inductVar, intermediateVal);
561+
}
562+
// get aggregated loop bound
563+
for (auto [lb, ub, step] : llvm::zip(lbs, ubs, steps)) {
564+
std::optional<int64_t> cur_bound = constantTripCount(lb, ub, step);
565+
assert(cur_bound.has_value());
566+
innerLoopUpperBound *= cur_bound.value();
567+
}
568+
}
569+
}
570+
571+
parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
572+
}
573+
574+
if (!isOuterMostLoop) {
575+
// get original shape size
576+
auto memType = cast<MemRefType>(origAllocOp->getResultTypes().front());
577+
int64_t size = getSizeInBytes(memType);
578+
Value origSize = builder.create<arith::ConstantIndexOp>(loc, size);
579+
Value offsetPerThread =
580+
builder.create<arith::MulIOp>(loc, inductVar, origSize);
581+
Value byteShiftPerThread =
582+
builder.create<arith::AddIOp>(loc, byteShift, offsetPerThread);
583+
584+
return builder.create<memref::ViewOp>(
585+
loc, origAllocOp->getResultTypes().front(), mergedAlloc,
586+
byteShiftPerThread, ValueRange{});
587+
} else {
588+
return builder.create<memref::ViewOp>(loc,
589+
origAllocOp->getResultTypes().front(),
590+
mergedAlloc, byteShift, ValueRange{});
591+
}
399592
}
400593

401594
LogicalResult

test/mlir/test/gc/Transforms/buffer-merge-lifetime.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ func.func @alias_ref(%pred : i1) {
113113
// CHECK-DAG: func.func @escape_from_if() attributes {__mergealloc_scope = [[TOPSCOPE5:[0-9]+]]
114114
func.func @escape_from_if() {
115115
%ctrue = arith.constant 1 : i1
116-
// check that f lives at the whole range of the following scf.if
116+
// check that f lives at the whole range of the following scf.if
117117
// CHECK-DAG: %[[F:.*]] = memref.alloc() {__mergealloc_lifetime = array<i64: [[TOPSCOPE5]], 4, 13>}
118118
%f = memref.alloc() : memref<8x64xf32>
119119
// tick of the scf.if starts from 4 and ends at 14

0 commit comments

Comments
 (0)