Skip to content

Commit 29a3e3d

Browse files
committed
[OpenMPOpt] Expand SPMDization with guarding for target parallel regions
This patch expands SPMDization (converting generic execution mode to SPMD for target regions) by guarding code regions that should be executed only by the main thread. Specifically, it generates guarded regions, which only the main thread executes, and the synchronization with worker threads using simple barriers. For correctness, the patch aborts SPMDization for target regions if the same code executes in a parallel region, thus must be not be guarded. This check is implemented using the ParallelLevels AA. Reviewed By: jhuber6 Differential Revision: https://reviews.llvm.org/D106892
1 parent c234051 commit 29a3e3d

File tree

7 files changed

+1366
-575
lines changed

7 files changed

+1366
-575
lines changed

llvm/lib/Transforms/IPO/OpenMPOpt.cpp

Lines changed: 214 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ struct KernelInfoState : AbstractState {
523523
/// State to track if we are in SPMD-mode, assumed or know, and why we decided
524524
/// we cannot be. If it is assumed, then RequiresFullRuntime should also be
525525
/// false.
526-
BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker;
526+
BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
527527

528528
/// The __kmpc_target_init call in this kernel, if any. If we find more than
529529
/// one we abort as the kernel is malformed.
@@ -2821,6 +2821,12 @@ struct AAKernelInfoFunction : AAKernelInfo {
28212821
AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
28222822
: AAKernelInfo(IRP, A) {}
28232823

2824+
SmallPtrSet<Instruction *, 4> GuardedInstructions;
2825+
2826+
SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2827+
return GuardedInstructions;
2828+
}
2829+
28242830
/// See AbstractAttribute::initialize(...).
28252831
void initialize(Attributor &A) override {
28262832
// This is a high-level transform that might change the constant arguments
@@ -3021,6 +3027,188 @@ struct AAKernelInfoFunction : AAKernelInfo {
30213027
return false;
30223028
}
30233029

3030+
auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3031+
Instruction *RegionEndI) {
3032+
LoopInfo *LI = nullptr;
3033+
DominatorTree *DT = nullptr;
3034+
MemorySSAUpdater *MSU = nullptr;
3035+
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3036+
3037+
BasicBlock *ParentBB = RegionStartI->getParent();
3038+
Function *Fn = ParentBB->getParent();
3039+
Module &M = *Fn->getParent();
3040+
3041+
// Create all the blocks and logic.
3042+
// ParentBB:
3043+
// goto RegionCheckTidBB
3044+
// RegionCheckTidBB:
3045+
// Tid = __kmpc_hardware_thread_id()
3046+
// if (Tid != 0)
3047+
// goto RegionBarrierBB
3048+
// RegionStartBB:
3049+
// <execute instructions guarded>
3050+
// goto RegionEndBB
3051+
// RegionEndBB:
3052+
// <store escaping values to shared mem>
3053+
// goto RegionBarrierBB
3054+
// RegionBarrierBB:
3055+
// __kmpc_simple_barrier_spmd()
3056+
// // second barrier is omitted if lacking escaping values.
3057+
// <load escaping values from shared mem>
3058+
// __kmpc_simple_barrier_spmd()
3059+
// goto RegionExitBB
3060+
// RegionExitBB:
3061+
// <execute rest of instructions>
3062+
3063+
BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3064+
DT, LI, MSU, "region.guarded.end");
3065+
BasicBlock *RegionBarrierBB =
3066+
SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3067+
MSU, "region.barrier");
3068+
BasicBlock *RegionExitBB =
3069+
SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3070+
DT, LI, MSU, "region.exit");
3071+
BasicBlock *RegionStartBB =
3072+
SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3073+
3074+
assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3075+
"Expected a different CFG");
3076+
3077+
BasicBlock *RegionCheckTidBB = SplitBlock(
3078+
ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3079+
3080+
// Register basic blocks with the Attributor.
3081+
A.registerManifestAddedBasicBlock(*RegionEndBB);
3082+
A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3083+
A.registerManifestAddedBasicBlock(*RegionExitBB);
3084+
A.registerManifestAddedBasicBlock(*RegionStartBB);
3085+
A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3086+
3087+
bool HasBroadcastValues = false;
3088+
// Find escaping outputs from the guarded region to outside users and
3089+
// broadcast their values to them.
3090+
for (Instruction &I : *RegionStartBB) {
3091+
SmallPtrSet<Instruction *, 4> OutsideUsers;
3092+
for (User *Usr : I.users()) {
3093+
Instruction &UsrI = *cast<Instruction>(Usr);
3094+
if (UsrI.getParent() != RegionStartBB)
3095+
OutsideUsers.insert(&UsrI);
3096+
}
3097+
3098+
if (OutsideUsers.empty())
3099+
continue;
3100+
3101+
HasBroadcastValues = true;
3102+
3103+
// Emit a global variable in shared memory to store the broadcasted
3104+
// value.
3105+
auto *SharedMem = new GlobalVariable(
3106+
M, I.getType(), /* IsConstant */ false,
3107+
GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3108+
I.getName() + ".guarded.output.alloc", nullptr,
3109+
GlobalValue::NotThreadLocal,
3110+
static_cast<unsigned>(AddressSpace::Shared));
3111+
3112+
// Emit a store instruction to update the value.
3113+
new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3114+
3115+
LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3116+
I.getName() + ".guarded.output.load",
3117+
RegionBarrierBB->getTerminator());
3118+
3119+
// Emit a load instruction and replace uses of the output value.
3120+
for (Instruction *UsrI : OutsideUsers) {
3121+
assert(UsrI->getParent() == RegionExitBB &&
3122+
"Expected escaping users in exit region");
3123+
UsrI->replaceUsesOfWith(&I, LoadI);
3124+
}
3125+
}
3126+
3127+
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3128+
3129+
// Go to tid check BB in ParentBB.
3130+
const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3131+
ParentBB->getTerminator()->eraseFromParent();
3132+
OpenMPIRBuilder::LocationDescription Loc(
3133+
InsertPointTy(ParentBB, ParentBB->end()), DL);
3134+
OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3135+
auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
3136+
Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
3137+
BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3138+
3139+
// Add check for Tid in RegionCheckTidBB
3140+
RegionCheckTidBB->getTerminator()->eraseFromParent();
3141+
OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3142+
InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3143+
OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3144+
FunctionCallee HardwareTidFn =
3145+
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3146+
M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3147+
Value *Tid =
3148+
OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3149+
Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3150+
OMPInfoCache.OMPBuilder.Builder
3151+
.CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3152+
->setDebugLoc(DL);
3153+
3154+
// First barrier for synchronization, ensures main thread has updated
3155+
// values.
3156+
FunctionCallee BarrierFn =
3157+
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3158+
M, OMPRTL___kmpc_barrier_simple_spmd);
3159+
OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3160+
RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3161+
OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
3162+
->setDebugLoc(DL);
3163+
3164+
// Second barrier ensures workers have read broadcast values.
3165+
if (HasBroadcastValues)
3166+
CallInst::Create(BarrierFn, {Ident, Tid}, "",
3167+
RegionBarrierBB->getTerminator())
3168+
->setDebugLoc(DL);
3169+
};
3170+
3171+
SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3172+
3173+
for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3174+
BasicBlock *BB = GuardedI->getParent();
3175+
auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3176+
IRPosition::function(*GuardedI->getFunction()), nullptr,
3177+
DepClassTy::NONE);
3178+
assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3179+
auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3180+
// Continue if instruction is already guarded.
3181+
if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3182+
continue;
3183+
3184+
Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3185+
for (Instruction &I : *BB) {
3186+
// If instruction I needs to be guarded update the guarded region
3187+
// bounds.
3188+
if (SPMDCompatibilityTracker.contains(&I)) {
3189+
CalleeAAFunction.getGuardedInstructions().insert(&I);
3190+
if (GuardedRegionStart)
3191+
GuardedRegionEnd = &I;
3192+
else
3193+
GuardedRegionStart = GuardedRegionEnd = &I;
3194+
3195+
continue;
3196+
}
3197+
3198+
// Instruction I does not need guarding, store
3199+
// any region found and reset bounds.
3200+
if (GuardedRegionStart) {
3201+
GuardedRegions.push_back(
3202+
std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3203+
GuardedRegionStart = nullptr;
3204+
GuardedRegionEnd = nullptr;
3205+
}
3206+
}
3207+
}
3208+
3209+
for (auto &GR : GuardedRegions)
3210+
CreateGuardedRegion(GR.first, GR.second);
3211+
30243212
// Adjust the global exec mode flag that tells the runtime what mode this
30253213
// kernel is executed in.
30263214
Function *Kernel = getAnchorScope();
@@ -3356,8 +3544,21 @@ struct AAKernelInfoFunction : AAKernelInfo {
33563544
if (llvm::all_of(Objects,
33573545
[](const Value *Obj) { return isa<AllocaInst>(Obj); }))
33583546
return true;
3547+
// Check for AAHeapToStack moved objects which must not be guarded.
3548+
auto &HS = A.getAAFor<AAHeapToStack>(
3549+
*this, IRPosition::function(*I.getFunction()),
3550+
DepClassTy::REQUIRED);
3551+
if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3552+
auto *CB = dyn_cast<CallBase>(Obj);
3553+
if (!CB)
3554+
return false;
3555+
return HS.isAssumedHeapToStack(*CB);
3556+
})) {
3557+
return true;
3558+
}
33593559
}
3360-
// For now we give up on everything but stores.
3560+
3561+
// Insert instruction that needs guarding.
33613562
SPMDCompatibilityTracker.insert(&I);
33623563
return true;
33633564
};
@@ -3371,6 +3572,9 @@ struct AAKernelInfoFunction : AAKernelInfo {
33713572
if (!IsKernelEntry) {
33723573
updateReachingKernelEntries(A);
33733574
updateParallelLevels(A);
3575+
3576+
if (!ParallelLevels.isValidState())
3577+
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
33743578
}
33753579

33763580
// Callback to check a call instruction.
@@ -3521,8 +3725,10 @@ struct AAKernelInfoCallSite : AAKernelInfo {
35213725

35223726
// If SPMDCompatibilityTracker is not fixed, we need to give up on the
35233727
// idea we can run something unknown in SPMD-mode.
3524-
if (!SPMDCompatibilityTracker.isAtFixpoint())
3728+
if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3729+
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
35253730
SPMDCompatibilityTracker.insert(&CB);
3731+
}
35263732

35273733
// We have updated the state for this unknown call properly, there won't
35283734
// be any change so we indicate a fixpoint.
@@ -3565,6 +3771,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
35653771
case OMPScheduleType::DistributeChunked:
35663772
break;
35673773
default:
3774+
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
35683775
SPMDCompatibilityTracker.insert(&CB);
35693776
break;
35703777
};
@@ -3590,7 +3797,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
35903797
// We do not look into tasks right now, just give up.
35913798
SPMDCompatibilityTracker.insert(&CB);
35923799
ReachedUnknownParallelRegions.insert(&CB);
3593-
break;
3800+
indicatePessimisticFixpoint();
3801+
return;
35943802
case OMPRTL___kmpc_alloc_shared:
35953803
case OMPRTL___kmpc_free_shared:
35963804
// Return without setting a fixpoint, to be resolved in updateImpl.
@@ -3599,7 +3807,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
35993807
// Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
36003808
// generally.
36013809
SPMDCompatibilityTracker.insert(&CB);
3602-
break;
3810+
indicatePessimisticFixpoint();
3811+
return;
36033812
}
36043813
// All other OpenMP runtime calls will not reach parallel regions so they
36053814
// can be safely ignored for now. Since it is a known OpenMP runtime call we

0 commit comments

Comments
 (0)