Skip to content

Commit 3921034

Browse files
author
git apple-llvm automerger
committed
Merge commit '29a3e3dd7bed' from llvm.org/main into next
2 parents d04cec4 + 29a3e3d commit 3921034

File tree

7 files changed

+1366
-575
lines changed

7 files changed

+1366
-575
lines changed

llvm/lib/Transforms/IPO/OpenMPOpt.cpp

Lines changed: 214 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ struct KernelInfoState : AbstractState {
523523
/// State to track if we are in SPMD-mode, assumed or known, and why we decided
524524
/// we cannot be. If it is assumed, then RequiresFullRuntime should also be
525525
/// false.
526-
BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker;
526+
BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
527527

528528
/// The __kmpc_target_init call in this kernel, if any. If we find more than
529529
/// one we abort as the kernel is malformed.
@@ -2821,6 +2821,12 @@ struct AAKernelInfoFunction : AAKernelInfo {
28212821
AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
28222822
: AAKernelInfo(IRP, A) {}
28232823

2824+
/// Instructions of this function that have already been placed in a guarded
/// region. The guarded-region collection loop consults this set to skip
/// instructions it has handled and inserts every instruction it assigns to a
/// region.
SmallPtrSet<Instruction *, 4> GuardedInstructions;
2825+
2826+
/// Return a mutable reference to the set of already guarded instructions so
/// callers can both query and extend it.
SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2827+
return GuardedInstructions;
2828+
}
2829+
28242830
/// See AbstractAttribute::initialize(...).
28252831
void initialize(Attributor &A) override {
28262832
// This is a high-level transform that might change the constant arguments
@@ -3021,6 +3027,188 @@ struct AAKernelInfoFunction : AAKernelInfo {
30213027
return false;
30223028
}
30233029

3030+
auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3031+
Instruction *RegionEndI) {
3032+
LoopInfo *LI = nullptr;
3033+
DominatorTree *DT = nullptr;
3034+
MemorySSAUpdater *MSU = nullptr;
3035+
using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3036+
3037+
BasicBlock *ParentBB = RegionStartI->getParent();
3038+
Function *Fn = ParentBB->getParent();
3039+
Module &M = *Fn->getParent();
3040+
3041+
// Create all the blocks and logic.
3042+
// ParentBB:
3043+
// goto RegionCheckTidBB
3044+
// RegionCheckTidBB:
3045+
// Tid = __kmpc_get_hardware_thread_id_in_block()
3046+
// if (Tid != 0)
3047+
// goto RegionBarrierBB
3048+
// RegionStartBB:
3049+
// <execute instructions guarded>
3050+
// goto RegionEndBB
3051+
// RegionEndBB:
3052+
// <store escaping values to shared mem>
3053+
// goto RegionBarrierBB
3054+
// RegionBarrierBB:
3055+
// __kmpc_barrier_simple_spmd()
3056+
// // second barrier is omitted if lacking escaping values.
3057+
// <load escaping values from shared mem>
3058+
// __kmpc_barrier_simple_spmd()
3059+
// goto RegionExitBB
3060+
// RegionExitBB:
3061+
// <execute rest of instructions>
3062+
3063+
BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3064+
DT, LI, MSU, "region.guarded.end");
3065+
BasicBlock *RegionBarrierBB =
3066+
SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3067+
MSU, "region.barrier");
3068+
BasicBlock *RegionExitBB =
3069+
SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3070+
DT, LI, MSU, "region.exit");
3071+
BasicBlock *RegionStartBB =
3072+
SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3073+
3074+
assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3075+
"Expected a different CFG");
3076+
3077+
BasicBlock *RegionCheckTidBB = SplitBlock(
3078+
ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3079+
3080+
// Register basic blocks with the Attributor.
3081+
A.registerManifestAddedBasicBlock(*RegionEndBB);
3082+
A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3083+
A.registerManifestAddedBasicBlock(*RegionExitBB);
3084+
A.registerManifestAddedBasicBlock(*RegionStartBB);
3085+
A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3086+
3087+
bool HasBroadcastValues = false;
3088+
// Find escaping outputs from the guarded region to outside users and
3089+
// broadcast their values to them.
3090+
for (Instruction &I : *RegionStartBB) {
3091+
SmallPtrSet<Instruction *, 4> OutsideUsers;
3092+
for (User *Usr : I.users()) {
3093+
Instruction &UsrI = *cast<Instruction>(Usr);
3094+
if (UsrI.getParent() != RegionStartBB)
3095+
OutsideUsers.insert(&UsrI);
3096+
}
3097+
3098+
if (OutsideUsers.empty())
3099+
continue;
3100+
3101+
HasBroadcastValues = true;
3102+
3103+
// Emit a global variable in shared memory to store the broadcasted
3104+
// value.
3105+
auto *SharedMem = new GlobalVariable(
3106+
M, I.getType(), /* IsConstant */ false,
3107+
GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3108+
I.getName() + ".guarded.output.alloc", nullptr,
3109+
GlobalValue::NotThreadLocal,
3110+
static_cast<unsigned>(AddressSpace::Shared));
3111+
3112+
// Emit a store instruction to update the value.
3113+
new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3114+
3115+
LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3116+
I.getName() + ".guarded.output.load",
3117+
RegionBarrierBB->getTerminator());
3118+
3119+
// Replace the escaping uses of the output value with the loaded value.
3120+
for (Instruction *UsrI : OutsideUsers) {
3121+
assert(UsrI->getParent() == RegionExitBB &&
3122+
"Expected escaping users in exit region");
3123+
UsrI->replaceUsesOfWith(&I, LoadI);
3124+
}
3125+
}
3126+
3127+
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3128+
3129+
// Go to tid check BB in ParentBB.
3130+
const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3131+
ParentBB->getTerminator()->eraseFromParent();
3132+
OpenMPIRBuilder::LocationDescription Loc(
3133+
InsertPointTy(ParentBB, ParentBB->end()), DL);
3134+
OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3135+
auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
3136+
Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
3137+
BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3138+
3139+
// Add check for Tid in RegionCheckTidBB
3140+
RegionCheckTidBB->getTerminator()->eraseFromParent();
3141+
OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3142+
InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3143+
OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3144+
FunctionCallee HardwareTidFn =
3145+
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3146+
M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3147+
Value *Tid =
3148+
OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3149+
Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3150+
OMPInfoCache.OMPBuilder.Builder
3151+
.CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3152+
->setDebugLoc(DL);
3153+
3154+
// First barrier for synchronization, ensures main thread has updated
3155+
// values.
3156+
FunctionCallee BarrierFn =
3157+
OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3158+
M, OMPRTL___kmpc_barrier_simple_spmd);
3159+
OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3160+
RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3161+
OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
3162+
->setDebugLoc(DL);
3163+
3164+
// Second barrier ensures workers have read broadcast values.
3165+
if (HasBroadcastValues)
3166+
CallInst::Create(BarrierFn, {Ident, Tid}, "",
3167+
RegionBarrierBB->getTerminator())
3168+
->setDebugLoc(DL);
3169+
};
3170+
3171+
SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3172+
3173+
for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3174+
BasicBlock *BB = GuardedI->getParent();
3175+
auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3176+
IRPosition::function(*GuardedI->getFunction()), nullptr,
3177+
DepClassTy::NONE);
3178+
assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3179+
auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3180+
// Continue if instruction is already guarded.
3181+
if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3182+
continue;
3183+
3184+
Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3185+
for (Instruction &I : *BB) {
3186+
// If instruction I needs to be guarded update the guarded region
3187+
// bounds.
3188+
if (SPMDCompatibilityTracker.contains(&I)) {
3189+
CalleeAAFunction.getGuardedInstructions().insert(&I);
3190+
if (GuardedRegionStart)
3191+
GuardedRegionEnd = &I;
3192+
else
3193+
GuardedRegionStart = GuardedRegionEnd = &I;
3194+
3195+
continue;
3196+
}
3197+
3198+
// Instruction I does not need guarding, store
3199+
// any region found and reset bounds.
3200+
if (GuardedRegionStart) {
3201+
GuardedRegions.push_back(
3202+
std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3203+
GuardedRegionStart = nullptr;
3204+
GuardedRegionEnd = nullptr;
3205+
}
3206+
}
3207+
}
3208+
3209+
for (auto &GR : GuardedRegions)
3210+
CreateGuardedRegion(GR.first, GR.second);
3211+
30243212
// Adjust the global exec mode flag that tells the runtime what mode this
30253213
// kernel is executed in.
30263214
Function *Kernel = getAnchorScope();
@@ -3356,8 +3544,21 @@ struct AAKernelInfoFunction : AAKernelInfo {
33563544
if (llvm::all_of(Objects,
33573545
[](const Value *Obj) { return isa<AllocaInst>(Obj); }))
33583546
return true;
3547+
// Check for AAHeapToStack moved objects which must not be guarded.
3548+
auto &HS = A.getAAFor<AAHeapToStack>(
3549+
*this, IRPosition::function(*I.getFunction()),
3550+
DepClassTy::REQUIRED);
3551+
if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3552+
auto *CB = dyn_cast<CallBase>(Obj);
3553+
if (!CB)
3554+
return false;
3555+
return HS.isAssumedHeapToStack(*CB);
3556+
})) {
3557+
return true;
3558+
}
33593559
}
3360-
// For now we give up on everything but stores.
3560+
3561+
// Insert instruction that needs guarding.
33613562
SPMDCompatibilityTracker.insert(&I);
33623563
return true;
33633564
};
@@ -3371,6 +3572,9 @@ struct AAKernelInfoFunction : AAKernelInfo {
33713572
if (!IsKernelEntry) {
33723573
updateReachingKernelEntries(A);
33733574
updateParallelLevels(A);
3575+
3576+
if (!ParallelLevels.isValidState())
3577+
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
33743578
}
33753579

33763580
// Callback to check a call instruction.
@@ -3521,8 +3725,10 @@ struct AAKernelInfoCallSite : AAKernelInfo {
35213725

35223726
// If SPMDCompatibilityTracker is not fixed, we need to give up on the
35233727
// idea we can run something unknown in SPMD-mode.
3524-
if (!SPMDCompatibilityTracker.isAtFixpoint())
3728+
if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3729+
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
35253730
SPMDCompatibilityTracker.insert(&CB);
3731+
}
35263732

35273733
// We have updated the state for this unknown call properly, there won't
35283734
// be any change so we indicate a fixpoint.
@@ -3565,6 +3771,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
35653771
case OMPScheduleType::DistributeChunked:
35663772
break;
35673773
default:
3774+
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
35683775
SPMDCompatibilityTracker.insert(&CB);
35693776
break;
35703777
};
@@ -3590,7 +3797,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
35903797
// We do not look into tasks right now, just give up.
35913798
SPMDCompatibilityTracker.insert(&CB);
35923799
ReachedUnknownParallelRegions.insert(&CB);
3593-
break;
3800+
indicatePessimisticFixpoint();
3801+
return;
35943802
case OMPRTL___kmpc_alloc_shared:
35953803
case OMPRTL___kmpc_free_shared:
35963804
// Return without setting a fixpoint, to be resolved in updateImpl.
@@ -3599,7 +3807,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
35993807
// Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
36003808
// generally.
36013809
SPMDCompatibilityTracker.insert(&CB);
3602-
break;
3810+
indicatePessimisticFixpoint();
3811+
return;
36033812
}
36043813
// All other OpenMP runtime calls will not reach parallel regions so they
36053814
// can be safely ignored for now. Since it is a known OpenMP runtime call we

0 commit comments

Comments
 (0)