@@ -523,7 +523,7 @@ struct KernelInfoState : AbstractState {
523
523
/// State to track if we are in SPMD-mode, assumed or known, and why we decided
524
524
/// we cannot be. If it is assumed, then RequiresFullRuntime should also be
525
525
/// false.
526
- BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker;
526
+ BooleanStateWithPtrSetVector<Instruction, false > SPMDCompatibilityTracker;
527
527
528
528
/// The __kmpc_target_init call in this kernel, if any. If we find more than
529
529
/// one we abort as the kernel is malformed.
@@ -2821,6 +2821,12 @@ struct AAKernelInfoFunction : AAKernelInfo {
2821
2821
AAKernelInfoFunction (const IRPosition &IRP, Attributor &A)
2822
2822
: AAKernelInfo(IRP, A) {}
2823
2823
2824
+ SmallPtrSet<Instruction *, 4 > GuardedInstructions;
2825
+
2826
+ SmallPtrSetImpl<Instruction *> &getGuardedInstructions () {
2827
+ return GuardedInstructions;
2828
+ }
2829
+
2824
2830
/// See AbstractAttribute::initialize(...).
2825
2831
void initialize (Attributor &A) override {
2826
2832
// This is a high-level transform that might change the constant arguments
@@ -3021,6 +3027,188 @@ struct AAKernelInfoFunction : AAKernelInfo {
3021
3027
return false ;
3022
3028
}
3023
3029
3030
+ auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3031
+ Instruction *RegionEndI) {
3032
+ LoopInfo *LI = nullptr ;
3033
+ DominatorTree *DT = nullptr ;
3034
+ MemorySSAUpdater *MSU = nullptr ;
3035
+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3036
+
3037
+ BasicBlock *ParentBB = RegionStartI->getParent ();
3038
+ Function *Fn = ParentBB->getParent ();
3039
+ Module &M = *Fn->getParent ();
3040
+
3041
+ // Create all the blocks and logic.
3042
+ // ParentBB:
3043
+ // goto RegionCheckTidBB
3044
+ // RegionCheckTidBB:
3045
+ // Tid = __kmpc_get_hardware_thread_id_in_block()
3046
+ // if (Tid != 0)
3047
+ // goto RegionBarrierBB
3048
+ // RegionStartBB:
3049
+ // <execute instructions guarded>
3050
+ // goto RegionEndBB
3051
+ // RegionEndBB:
3052
+ // <store escaping values to shared mem>
3053
+ // goto RegionBarrierBB
3054
+ // RegionBarrierBB:
3055
+ // __kmpc_barrier_simple_spmd()
3055
+ // // second barrier is omitted if lacking escaping values.
3056
+ // <load escaping values from shared mem>
3057
+ // __kmpc_barrier_simple_spmd()
3059
+ // goto RegionExitBB
3060
+ // RegionExitBB:
3061
+ // <execute rest of instructions>
3062
+
3063
+ BasicBlock *RegionEndBB = SplitBlock (ParentBB, RegionEndI->getNextNode (),
3064
+ DT, LI, MSU, " region.guarded.end" );
3065
+ BasicBlock *RegionBarrierBB =
3066
+ SplitBlock (RegionEndBB, &*RegionEndBB->getFirstInsertionPt (), DT, LI,
3067
+ MSU, " region.barrier" );
3068
+ BasicBlock *RegionExitBB =
3069
+ SplitBlock (RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt (),
3070
+ DT, LI, MSU, " region.exit" );
3071
+ BasicBlock *RegionStartBB =
3072
+ SplitBlock (ParentBB, RegionStartI, DT, LI, MSU, " region.guarded" );
3073
+
3074
+ assert (ParentBB->getUniqueSuccessor () == RegionStartBB &&
3075
+ " Expected a different CFG" );
3076
+
3077
+ BasicBlock *RegionCheckTidBB = SplitBlock (
3078
+ ParentBB, ParentBB->getTerminator (), DT, LI, MSU, " region.check.tid" );
3079
+
3080
+ // Register basic blocks with the Attributor.
3081
+ A.registerManifestAddedBasicBlock (*RegionEndBB);
3082
+ A.registerManifestAddedBasicBlock (*RegionBarrierBB);
3083
+ A.registerManifestAddedBasicBlock (*RegionExitBB);
3084
+ A.registerManifestAddedBasicBlock (*RegionStartBB);
3085
+ A.registerManifestAddedBasicBlock (*RegionCheckTidBB);
3086
+
3087
+ bool HasBroadcastValues = false ;
3088
+ // Find escaping outputs from the guarded region to outside users and
3089
+ // broadcast their values to them.
3090
+ for (Instruction &I : *RegionStartBB) {
3091
+ SmallPtrSet<Instruction *, 4 > OutsideUsers;
3092
+ for (User *Usr : I.users ()) {
3093
+ Instruction &UsrI = *cast<Instruction>(Usr);
3094
+ if (UsrI.getParent () != RegionStartBB)
3095
+ OutsideUsers.insert (&UsrI);
3096
+ }
3097
+
3098
+ if (OutsideUsers.empty ())
3099
+ continue ;
3100
+
3101
+ HasBroadcastValues = true ;
3102
+
3103
+ // Emit a global variable in shared memory to store the broadcasted
3104
+ // value.
3105
+ auto *SharedMem = new GlobalVariable (
3106
+ M, I.getType (), /* IsConstant */ false ,
3107
+ GlobalValue::InternalLinkage, UndefValue::get (I.getType ()),
3108
+ I.getName () + " .guarded.output.alloc" , nullptr ,
3109
+ GlobalValue::NotThreadLocal,
3110
+ static_cast <unsigned >(AddressSpace::Shared));
3111
+
3112
+ // Emit a store instruction to update the value.
3113
+ new StoreInst (&I, SharedMem, RegionEndBB->getTerminator ());
3114
+
3115
+ LoadInst *LoadI = new LoadInst (I.getType (), SharedMem,
3116
+ I.getName () + " .guarded.output.load" ,
3117
+ RegionBarrierBB->getTerminator ());
3118
+
3119
+ // Replace outside uses of the output value with the load emitted above.
3120
+ for (Instruction *UsrI : OutsideUsers) {
3121
+ assert (UsrI->getParent () == RegionExitBB &&
3122
+ " Expected escaping users in exit region" );
3123
+ UsrI->replaceUsesOfWith (&I, LoadI);
3124
+ }
3125
+ }
3126
+
3127
+ auto &OMPInfoCache = static_cast <OMPInformationCache &>(A.getInfoCache ());
3128
+
3129
+ // Go to tid check BB in ParentBB.
3130
+ const DebugLoc DL = ParentBB->getTerminator ()->getDebugLoc ();
3131
+ ParentBB->getTerminator ()->eraseFromParent ();
3132
+ OpenMPIRBuilder::LocationDescription Loc (
3133
+ InsertPointTy (ParentBB, ParentBB->end ()), DL);
3134
+ OMPInfoCache.OMPBuilder .updateToLocation (Loc);
3135
+ auto *SrcLocStr = OMPInfoCache.OMPBuilder .getOrCreateSrcLocStr (Loc);
3136
+ Value *Ident = OMPInfoCache.OMPBuilder .getOrCreateIdent (SrcLocStr);
3137
+ BranchInst::Create (RegionCheckTidBB, ParentBB)->setDebugLoc (DL);
3138
+
3139
+ // Add check for Tid in RegionCheckTidBB
3140
+ RegionCheckTidBB->getTerminator ()->eraseFromParent ();
3141
+ OpenMPIRBuilder::LocationDescription LocRegionCheckTid (
3142
+ InsertPointTy (RegionCheckTidBB, RegionCheckTidBB->end ()), DL);
3143
+ OMPInfoCache.OMPBuilder .updateToLocation (LocRegionCheckTid);
3144
+ FunctionCallee HardwareTidFn =
3145
+ OMPInfoCache.OMPBuilder .getOrCreateRuntimeFunction (
3146
+ M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3147
+ Value *Tid =
3148
+ OMPInfoCache.OMPBuilder .Builder .CreateCall (HardwareTidFn, {});
3149
+ Value *TidCheck = OMPInfoCache.OMPBuilder .Builder .CreateIsNull (Tid);
3150
+ OMPInfoCache.OMPBuilder .Builder
3151
+ .CreateCondBr (TidCheck, RegionStartBB, RegionBarrierBB)
3152
+ ->setDebugLoc (DL);
3153
+
3154
+ // First barrier for synchronization, ensures main thread has updated
3155
+ // values.
3156
+ FunctionCallee BarrierFn =
3157
+ OMPInfoCache.OMPBuilder .getOrCreateRuntimeFunction (
3158
+ M, OMPRTL___kmpc_barrier_simple_spmd);
3159
+ OMPInfoCache.OMPBuilder .updateToLocation (InsertPointTy (
3160
+ RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt ()));
3161
+ OMPInfoCache.OMPBuilder .Builder .CreateCall (BarrierFn, {Ident, Tid})
3162
+ ->setDebugLoc (DL);
3163
+
3164
+ // Second barrier ensures workers have read broadcast values.
3165
+ if (HasBroadcastValues)
3166
+ CallInst::Create (BarrierFn, {Ident, Tid}, " " ,
3167
+ RegionBarrierBB->getTerminator ())
3168
+ ->setDebugLoc (DL);
3169
+ };
3170
+
3171
+ SmallVector<std::pair<Instruction *, Instruction *>, 4 > GuardedRegions;
3172
+
3173
+ for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3174
+ BasicBlock *BB = GuardedI->getParent ();
3175
+ auto *CalleeAA = A.lookupAAFor <AAKernelInfo>(
3176
+ IRPosition::function (*GuardedI->getFunction ()), nullptr ,
3177
+ DepClassTy::NONE);
3178
+ assert (CalleeAA != nullptr && " Expected Callee AAKernelInfo" );
3179
+ auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3180
+ // Continue if instruction is already guarded.
3181
+ if (CalleeAAFunction.getGuardedInstructions ().contains (GuardedI))
3182
+ continue ;
3183
+
3184
+ Instruction *GuardedRegionStart = nullptr , *GuardedRegionEnd = nullptr ;
3185
+ for (Instruction &I : *BB) {
3186
+ // If instruction I needs to be guarded update the guarded region
3187
+ // bounds.
3188
+ if (SPMDCompatibilityTracker.contains (&I)) {
3189
+ CalleeAAFunction.getGuardedInstructions ().insert (&I);
3190
+ if (GuardedRegionStart)
3191
+ GuardedRegionEnd = &I;
3192
+ else
3193
+ GuardedRegionStart = GuardedRegionEnd = &I;
3194
+
3195
+ continue ;
3196
+ }
3197
+
3198
+ // Instruction I does not need guarding, store
3199
+ // any region found and reset bounds.
3200
+ if (GuardedRegionStart) {
3201
+ GuardedRegions.push_back (
3202
+ std::make_pair (GuardedRegionStart, GuardedRegionEnd));
3203
+ GuardedRegionStart = nullptr ;
3204
+ GuardedRegionEnd = nullptr ;
3205
+ }
3206
+ }
3207
+ }
3208
+
3209
+ for (auto &GR : GuardedRegions)
3210
+ CreateGuardedRegion (GR.first , GR.second );
3211
+
3024
3212
// Adjust the global exec mode flag that tells the runtime what mode this
3025
3213
// kernel is executed in.
3026
3214
Function *Kernel = getAnchorScope ();
@@ -3356,8 +3544,21 @@ struct AAKernelInfoFunction : AAKernelInfo {
3356
3544
if (llvm::all_of (Objects,
3357
3545
[](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3358
3546
return true ;
3547
+ // Check for AAHeapToStack moved objects which must not be guarded.
3548
+ auto &HS = A.getAAFor <AAHeapToStack>(
3549
+ *this , IRPosition::function (*I.getFunction ()),
3550
+ DepClassTy::REQUIRED);
3551
+ if (llvm::all_of (Objects, [&HS](const Value *Obj) {
3552
+ auto *CB = dyn_cast<CallBase>(Obj);
3553
+ if (!CB)
3554
+ return false ;
3555
+ return HS.isAssumedHeapToStack (*CB);
3556
+ })) {
3557
+ return true ;
3558
+ }
3359
3559
}
3360
- // For now we give up on everything but stores.
3560
+
3561
+ // Insert instruction that needs guarding.
3361
3562
SPMDCompatibilityTracker.insert (&I);
3362
3563
return true ;
3363
3564
};
@@ -3371,6 +3572,9 @@ struct AAKernelInfoFunction : AAKernelInfo {
3371
3572
if (!IsKernelEntry) {
3372
3573
updateReachingKernelEntries (A);
3373
3574
updateParallelLevels (A);
3575
+
3576
+ if (!ParallelLevels.isValidState ())
3577
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint ();
3374
3578
}
3375
3579
3376
3580
// Callback to check a call instruction.
@@ -3521,8 +3725,10 @@ struct AAKernelInfoCallSite : AAKernelInfo {
3521
3725
3522
3726
// If SPMDCompatibilityTracker is not fixed, we need to give up on the
3523
3727
// idea we can run something unknown in SPMD-mode.
3524
- if (!SPMDCompatibilityTracker.isAtFixpoint ())
3728
+ if (!SPMDCompatibilityTracker.isAtFixpoint ()) {
3729
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint ();
3525
3730
SPMDCompatibilityTracker.insert (&CB);
3731
+ }
3526
3732
3527
3733
// We have updated the state for this unknown call properly, there won't
3528
3734
// be any change so we indicate a fixpoint.
@@ -3565,6 +3771,7 @@ struct AAKernelInfoCallSite : AAKernelInfo {
3565
3771
case OMPScheduleType::DistributeChunked:
3566
3772
break ;
3567
3773
default :
3774
+ SPMDCompatibilityTracker.indicatePessimisticFixpoint ();
3568
3775
SPMDCompatibilityTracker.insert (&CB);
3569
3776
break ;
3570
3777
};
@@ -3590,7 +3797,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
3590
3797
// We do not look into tasks right now, just give up.
3591
3798
SPMDCompatibilityTracker.insert (&CB);
3592
3799
ReachedUnknownParallelRegions.insert (&CB);
3593
- break ;
3800
+ indicatePessimisticFixpoint ();
3801
+ return ;
3594
3802
case OMPRTL___kmpc_alloc_shared:
3595
3803
case OMPRTL___kmpc_free_shared:
3596
3804
// Return without setting a fixpoint, to be resolved in updateImpl.
@@ -3599,7 +3807,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
3599
3807
// Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3600
3808
// generally.
3601
3809
SPMDCompatibilityTracker.insert (&CB);
3602
- break ;
3810
+ indicatePessimisticFixpoint ();
3811
+ return ;
3603
3812
}
3604
3813
// All other OpenMP runtime calls will not reach parallel regions so they
3605
3814
// can be safely ignored for now. Since it is a known OpenMP runtime call we
0 commit comments