-
Notifications
You must be signed in to change notification settings - Fork 17
[Transform] Hoist thread-local allocator within the nested parallel loops #283
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 35 commits
45a02a4
1615878
d1225c0
d107580
eaf2667
f868629
504c785
ef3d150
21fe5fa
5db4be4
84a17ac
87a7fb6
42d612b
115fd66
26adb18
f93f1d2
36354ea
82ab370
1d3b887
e285e99
5990627
eecc19f
b3541f0
57ddeab
e5a2f83
66309fc
5953b13
ddf69b0
8401ca0
43f27a7
75a88aa
a81249d
8f5325d
5c8fdf4
d2c7cdc
fdede5b
4708126
cc25380
514664e
ac8452f
99a2c34
1e80344
8daedf9
991e1fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,69 @@ using namespace special_ticks; | |
/// and default memory space. | ||
static bool isMemRefTypeOk(MemRefType type) { return type.hasStaticShape(); } | ||
|
||
/// Returns the number of bytes needed to hold a memref of type `memType`.
///
/// The per-element size is the element bit width rounded UP to whole bytes,
/// so widths that are not a multiple of 8 are never under-counted. Sub-byte
/// types such as bool (i1) are treated as 1 byte. It may not be true for all
/// targets, but we at least have a large enough size.
///
/// The element size is then multiplied by every extent in getShape(); no
/// check for dynamic dimensions is performed here.
static inline int64_t getSizeInBytes(MemRefType &memType) {
  const int64_t bitWidth = memType.getElementTypeBitWidth();
  // Ceil-divide so that e.g. a 12-bit element reserves 2 bytes, not 1.
  int64_t size = (bitWidth + 7) / 8;
  // Keep the historical lower bound of one byte per element.
  size = (size > 0) ? size : 1;
  for (int64_t dim : memType.getShape())
    size *= dim;
  return size;
}
|
||
/// Returns true when `op` should be hoisted out of its parallel loop.
///
/// That is the case when both of the following hold:
///  - the closest parent of `op` carrying the AutomaticAllocationScope trait
///    is an scf.forall, and
///  - at least one user of `op` is nested inside an scf.forall that lives in
///    the same block as `op` (i.e. the allocation sits between two nested
///    parallel loops and is consumed by the inner one).
static bool needsHoistOutOfParallelLoop(Operation *op) {
  Operation *scopeOp =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  if (!scopeOp || !isa<scf::ForallOp>(scopeOp))
    return false;

  // Look at every scf.forall that is a sibling of `op` and check whether it
  // encloses any user of `op`.
  for (Operation &sibling : op->getBlock()->getOperations()) {
    if (!isa<scf::ForallOp>(sibling))
      continue;
    for (Operation *user : op->getUsers()) {
      if (sibling.isAncestor(user))
        return true;
    }
  }

  return false;
}
|
||
/// Returns true when `op` is an scf.forall whose lower bounds, upper bounds
/// and steps are all compile-time constants; returns false for any other
/// operation or as soon as one non-constant value is found.
///
/// The mixed (OpFoldResult) accessors are queried directly instead of
/// materializing the bounds as Values through an OpBuilder: a builder without
/// an insertion point would create constant ops that are never inserted into
/// the IR just to answer this question.
static bool isForallLoopBoundStatic(Operation *op) {
  auto forallOp = dyn_cast<scf::ForallOp>(op);
  if (!forallOp)
    return false;

  // A bound is static if it is an attribute or a Value produced by a
  // constant integer/index operation.
  auto isConstant = [](OpFoldResult bound) {
    return getConstantIntValue(bound).has_value();
  };

  // `&&` short-circuits, so the check stops at the first dynamic group.
  return llvm::all_of(forallOp.getMixedLowerBound(), isConstant) &&
         llvm::all_of(forallOp.getMixedUpperBound(), isConstant) &&
         llvm::all_of(forallOp.getMixedStep(), isConstant);
}
|
||
void Tick::update(int64_t tick) { | ||
if (tick == UNTRACEABLE_ACCESS) { | ||
firstAccess = UNTRACEABLE_ACCESS; | ||
|
@@ -180,28 +243,56 @@ bool TickCollecter::isMergeableAlloc(TickCollecterStates *s, Operation *op, | |
// trait, and is not scf.for | ||
Operation *TickCollecter::getAllocScope(TickCollecterStates *s,
                                        Operation *op) const {
  // Decided once for the allocation itself: whether it may be moved above the
  // parallel loop(s) enclosing it.
  const bool moveToUpperParallelLoop = needsHoistOutOfParallelLoop(op);

  Operation *parent = op;
  for (;;) {
    parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
    // Ran out of enclosing allocation scopes.
    if (!parent)
      return nullptr;

    // scf.for is never used as the allocation scope; keep walking up.
    if (isa<scf::ForOp>(parent))
      continue;

    // An scf.forall with static bounds is skipped too when the allocation
    // needs to be hoisted out of the parallel loop.
    if (moveToUpperParallelLoop && isa<scf::ForallOp>(parent) &&
        isForallLoopBoundStatic(parent))
      continue;

    return parent;
  }
}
|
||
FailureOr<size_t> TickCollecter::getAllocSize(TickCollecterStates *s, | ||
Operation *op) const { | ||
auto refType = cast<MemRefType>(op->getResultTypes().front()); | ||
int64_t size = refType.getElementTypeBitWidth() / 8; | ||
// treat bool (i1) as 1 byte. It may not be true for all targets, but we at | ||
// least have a large enough size for i1 | ||
size = (size != 0) ? size : 1; | ||
for (auto v : refType.getShape()) { | ||
size *= v; | ||
|
||
  // Get the total number of threads from the outermost to the current level of | ||
  // the parallel loop that the allocation is located in. | ||
int64_t numThreads = 1; | ||
zhczhong marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (needsHoistOutOfParallelLoop(op)) { | ||
Operation *parent = | ||
op->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); | ||
while (auto forallOp = dyn_cast<scf::ForallOp>(parent)) { | ||
if (!isForallLoopBoundStatic(forallOp)) | ||
break; | ||
|
||
OpBuilder builder{forallOp->getContext()}; | ||
SmallVector<Value> ubs = forallOp.getUpperBound(builder); | ||
if (std::optional<int64_t> ubs0_int = getConstantIntValue(ubs[0])) { | ||
int64_t innerLoopUpperBound = ubs0_int.value(); | ||
numThreads *= innerLoopUpperBound; | ||
Yun-Fly marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} else { | ||
op->emitError("Expecting static loop range!"); | ||
} | ||
|
||
parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); | ||
} | ||
} | ||
assert(numThreads > 0); | ||
|
||
int64_t size = getSizeInBytes(refType); | ||
size *= numThreads; | ||
if (size > 0) { | ||
return static_cast<size_t>(size); | ||
} | ||
|
@@ -391,11 +482,63 @@ Value MergeAllocDefaultMutator::buildView(OpBuilder &builder, Block *scope, | |
Value mergedAlloc, | ||
int64_t byteOffset) const { | ||
builder.setInsertionPoint(origAllocOp); | ||
auto byteShift = | ||
builder.create<arith::ConstantIndexOp>(origAllocOp->getLoc(), byteOffset); | ||
return builder.create<memref::ViewOp>(origAllocOp->getLoc(), | ||
origAllocOp->getResultTypes().front(), | ||
mergedAlloc, byteShift, ValueRange{}); | ||
auto loc = origAllocOp->getLoc(); | ||
auto byteShift = builder.create<arith::ConstantIndexOp>(loc, byteOffset); | ||
|
||
bool moveToUpperParellelLoop = needsHoistOutOfParallelLoop(origAllocOp); | ||
Operation *parent = | ||
origAllocOp->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); | ||
if (!moveToUpperParellelLoop || !parent || !isa<scf::ForallOp>(parent)) | ||
return builder.create<memref::ViewOp>(loc, | ||
origAllocOp->getResultTypes().front(), | ||
mergedAlloc, byteShift, ValueRange{}); | ||
|
||
  // get the aggregated induction variable | ||
Value inductVar; | ||
bool isOuterMostLoop = true; | ||
int64_t innerLoopUpperBound = 1; | ||
while (parent) { | ||
if (auto forallOp = dyn_cast<scf::ForallOp>(parent)) { | ||
if (isForallLoopBoundStatic(forallOp)) { | ||
SmallVector<Value> upperBounds = forallOp.getUpperBound(builder); | ||
ciyongch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (std::optional<int64_t> ubs0_int = | ||
getConstantIntValue(upperBounds[0])) { | ||
if (isOuterMostLoop) { | ||
inductVar = forallOp.getInductionVar(0); | ||
isOuterMostLoop = false; | ||
} else { | ||
Value innerLoopBoundVal = builder.create<arith::ConstantIndexOp>( | ||
loc, innerLoopUpperBound); | ||
Value intermediateVal = builder.create<arith::MulIOp>( | ||
loc, forallOp.getInductionVar(0), innerLoopBoundVal); | ||
inductVar = | ||
builder.create<arith::AddIOp>(loc, inductVar, intermediateVal); | ||
} | ||
innerLoopUpperBound = ubs0_int.value(); | ||
} | ||
} | ||
} | ||
|
||
parent = parent->getParentWithTrait<OpTrait::AutomaticAllocationScope>(); | ||
} | ||
|
||
if (!isOuterMostLoop) { | ||
// get original shape size | ||
auto memType = cast<MemRefType>(origAllocOp->getResultTypes().front()); | ||
int64_t size = getSizeInBytes(memType); | ||
Value origSize = builder.create<arith::ConstantIndexOp>(loc, size); | ||
Value offsetPerThread = | ||
builder.create<arith::MulIOp>(loc, inductVar, origSize); | ||
Value byteShiftPerThread = | ||
builder.create<arith::AddIOp>(loc, byteShift, offsetPerThread); | ||
|
||
return builder.create<memref::ViewOp>( | ||
loc, origAllocOp->getResultTypes().front(), mergedAlloc, | ||
byteShiftPerThread, ValueRange{}); | ||
} else | ||
return builder.create<memref::ViewOp>(loc, | ||
ciyongch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
origAllocOp->getResultTypes().front(), | ||
mergedAlloc, byteShift, ValueRange{}); | ||
} | ||
|
||
LogicalResult | ||
|
Uh oh!
There was an error while loading. Please reload this page.