@@ -2041,7 +2041,8 @@ bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2041
2041
UndefValue::get (Int8Ty), F->getName () + " .ID" );
2042
2042
2043
2043
for (Use *U : ToBeReplacedStateMachineUses)
2044
- U->set (ConstantExpr::getBitCast (ID, U->get ()->getType ()));
2044
+ U->set (ConstantExpr::getPointerBitCastOrAddrSpaceCast (
2045
+ ID, U->get ()->getType ()));
2045
2046
2046
2047
++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2047
2048
@@ -3455,10 +3456,14 @@ struct AAKernelInfoFunction : AAKernelInfo {
3455
3456
IsWorker->setDebugLoc (DLoc);
3456
3457
BranchInst::Create (StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3457
3458
3459
+ Module &M = *Kernel->getParent ();
3460
+
3458
3461
// Create local storage for the work function pointer.
3462
+ const DataLayout &DL = M.getDataLayout ();
3459
3463
Type *VoidPtrTy = Type::getInt8PtrTy (Ctx);
3460
- AllocaInst *WorkFnAI = new AllocaInst (VoidPtrTy, 0 , " worker.work_fn.addr" ,
3461
- &Kernel->getEntryBlock ().front ());
3464
+ Instruction *WorkFnAI =
3465
+ new AllocaInst (VoidPtrTy, DL.getAllocaAddrSpace (), nullptr ,
3466
+ " worker.work_fn.addr" , &Kernel->getEntryBlock ().front ());
3462
3467
WorkFnAI->setDebugLoc (DLoc);
3463
3468
3464
3469
auto &OMPInfoCache = static_cast <OMPInformationCache &>(A.getInfoCache ());
@@ -3471,13 +3476,23 @@ struct AAKernelInfoFunction : AAKernelInfo {
3471
3476
Value *Ident = KernelInitCB->getArgOperand (0 );
3472
3477
Value *GTid = KernelInitCB;
3473
3478
3474
- Module &M = *Kernel->getParent ();
3475
3479
FunctionCallee BarrierFn =
3476
3480
OMPInfoCache.OMPBuilder .getOrCreateRuntimeFunction (
3477
3481
M, OMPRTL___kmpc_barrier_simple_spmd);
3478
3482
CallInst::Create (BarrierFn, {Ident, GTid}, " " , StateMachineBeginBB)
3479
3483
->setDebugLoc (DLoc);
3480
3484
3485
+ if (WorkFnAI->getType ()->getPointerAddressSpace () !=
3486
+ (unsigned int )AddressSpace::Generic) {
3487
+ WorkFnAI = new AddrSpaceCastInst (
3488
+ WorkFnAI,
3489
+ PointerType::getWithSamePointeeType (
3490
+ cast<PointerType>(WorkFnAI->getType ()),
3491
+ (unsigned int )AddressSpace::Generic),
3492
+ WorkFnAI->getName () + " .generic" , StateMachineBeginBB);
3493
+ WorkFnAI->setDebugLoc (DLoc);
3494
+ }
3495
+
3481
3496
FunctionCallee KernelParallelFn =
3482
3497
OMPInfoCache.OMPBuilder .getOrCreateRuntimeFunction (
3483
3498
M, OMPRTL___kmpc_kernel_parallel);
0 commit comments