@@ -3495,9 +3495,9 @@ checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3495
3495
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU (
3496
3496
const LocationDescription &Loc, InsertPointTy AllocaIP,
3497
3497
InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
3498
- bool IsNoWait, bool IsTeamsReduction, bool HasDistribute ,
3499
- ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
3500
- unsigned ReductionBufNum, Value *SrcLocInfo) {
3498
+ bool IsNoWait, bool IsTeamsReduction, ReductionGenCBKind ReductionGenCBKind ,
3499
+ std::optional<omp::GV> GridValue, unsigned ReductionBufNum ,
3500
+ Value *SrcLocInfo) {
3501
3501
if (!updateToLocation (Loc))
3502
3502
return InsertPointTy ();
3503
3503
Builder.restoreIP (CodeGenIP);
@@ -3514,6 +3514,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
3514
3514
if (ReductionInfos.size () == 0 )
3515
3515
return Builder.saveIP ();
3516
3516
3517
+ BasicBlock *ContinuationBlock = nullptr ;
3518
+ if (ReductionGenCBKind != ReductionGenCBKind::Clang) {
3519
+ // Copied code from createReductions
3520
+ BasicBlock *InsertBlock = Loc.IP .getBlock ();
3521
+ ContinuationBlock =
3522
+ InsertBlock->splitBasicBlock (Loc.IP .getPoint (), " reduce.finalize" );
3523
+ InsertBlock->getTerminator ()->eraseFromParent ();
3524
+ Builder.SetInsertPoint (InsertBlock, InsertBlock->end ());
3525
+ }
3526
+
3517
3527
Function *CurFunc = Builder.GetInsertBlock ()->getParent ();
3518
3528
AttributeList FuncAttrs;
3519
3529
AttrBuilder AttrBldr (Ctx);
@@ -3669,11 +3679,21 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductionsGPU(
3669
3679
ReductionFunc;
3670
3680
});
3671
3681
} else {
3672
- assert (false && " Unhandled ReductionGenCBKind" );
3682
+ Value *LHSValue = Builder.CreateLoad (RI.ElementType , LHS, " final.lhs" );
3683
+ Value *RHSValue = Builder.CreateLoad (RI.ElementType , RHS, " final.rhs" );
3684
+ Value *Reduced;
3685
+ InsertPointOrErrorTy AfterIP =
3686
+ RI.ReductionGen (Builder.saveIP (), RHSValue, LHSValue, Reduced);
3687
+ if (!AfterIP)
3688
+ return AfterIP.takeError ();
3689
+ Builder.CreateStore (Reduced, LHS, false );
3673
3690
}
3674
3691
}
3675
3692
emitBlock (ExitBB, CurFunc);
3676
-
3693
+ if (ContinuationBlock) {
3694
+ Builder.CreateBr (ContinuationBlock);
3695
+ Builder.SetInsertPoint (ContinuationBlock);
3696
+ }
3677
3697
Config.setEmitLLVMUsed ();
3678
3698
3679
3699
return Builder.saveIP ();
@@ -3688,27 +3708,95 @@ static Function *getFreshReductionFunc(Module &M) {
3688
3708
" .omp.reduction.func" , &M);
3689
3709
}
3690
3710
3691
- OpenMPIRBuilder::InsertPointOrErrorTy
3692
- OpenMPIRBuilder::createReductions (const LocationDescription &Loc,
3693
- InsertPointTy AllocaIP,
3694
- ArrayRef<ReductionInfo> ReductionInfos,
3695
- ArrayRef<bool > IsByRef, bool IsNoWait) {
3696
- assert (ReductionInfos.size () == IsByRef.size ());
3697
- for (const ReductionInfo &RI : ReductionInfos) {
3698
- (void )RI;
3699
- assert (RI.Variable && " expected non-null variable" );
3700
- assert (RI.PrivateVariable && " expected non-null private variable" );
3701
- assert (RI.ReductionGen && " expected non-null reduction generator callback" );
3702
- assert (RI.Variable ->getType () == RI.PrivateVariable ->getType () &&
3703
- " expected variables and their private equivalents to have the same "
3704
- " type" );
3705
- assert (RI.Variable ->getType ()->isPointerTy () &&
3706
- " expected variables to be pointers" );
3711
+ static Error populateReductionFunction (
3712
+ Function *ReductionFunc,
3713
+ ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
3714
+ IRBuilder<> &Builder, ArrayRef<bool > IsByRef, bool IsGPU) {
3715
+ Module *Module = ReductionFunc->getParent ();
3716
+ BasicBlock *ReductionFuncBlock =
3717
+ BasicBlock::Create (Module->getContext (), " " , ReductionFunc);
3718
+ Builder.SetInsertPoint (ReductionFuncBlock);
3719
+ Value *LHSArrayPtr = nullptr ;
3720
+ Value *RHSArrayPtr = nullptr ;
3721
+ if (IsGPU) {
3722
+ // Need to alloca memory here and deal with the pointers before getting
3723
+ // LHS/RHS pointers out
3724
+ //
3725
+ Argument *Arg0 = ReductionFunc->getArg (0 );
3726
+ Argument *Arg1 = ReductionFunc->getArg (1 );
3727
+ Type *Arg0Type = Arg0->getType ();
3728
+ Type *Arg1Type = Arg1->getType ();
3729
+
3730
+ Value *LHSAlloca =
3731
+ Builder.CreateAlloca (Arg0Type, nullptr , Arg0->getName () + " .addr" );
3732
+ Value *RHSAlloca =
3733
+ Builder.CreateAlloca (Arg1Type, nullptr , Arg1->getName () + " .addr" );
3734
+ Value *LHSAddrCast =
3735
+ Builder.CreatePointerBitCastOrAddrSpaceCast (LHSAlloca, Arg0Type);
3736
+ Value *RHSAddrCast =
3737
+ Builder.CreatePointerBitCastOrAddrSpaceCast (RHSAlloca, Arg1Type);
3738
+ Builder.CreateStore (Arg0, LHSAddrCast);
3739
+ Builder.CreateStore (Arg1, RHSAddrCast);
3740
+ LHSArrayPtr = Builder.CreateLoad (Arg0Type, LHSAddrCast);
3741
+ RHSArrayPtr = Builder.CreateLoad (Arg1Type, RHSAddrCast);
3742
+ } else {
3743
+ LHSArrayPtr = ReductionFunc->getArg (0 );
3744
+ RHSArrayPtr = ReductionFunc->getArg (1 );
3707
3745
}
3708
3746
3747
+ unsigned NumReductions = ReductionInfos.size ();
3748
+ Type *RedArrayTy = ArrayType::get (Builder.getPtrTy (), NumReductions);
3749
+
3750
+ for (auto En : enumerate(ReductionInfos)) {
3751
+ const OpenMPIRBuilder::ReductionInfo &RI = En.value ();
3752
+ Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
3753
+ RedArrayTy, LHSArrayPtr, 0 , En.index ());
3754
+ Value *LHSI8Ptr = Builder.CreateLoad (Builder.getPtrTy (), LHSI8PtrPtr);
3755
+ Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast (
3756
+ LHSI8Ptr, RI.Variable ->getType ());
3757
+ Value *LHS = Builder.CreateLoad (RI.ElementType , LHSPtr);
3758
+ Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
3759
+ RedArrayTy, RHSArrayPtr, 0 , En.index ());
3760
+ Value *RHSI8Ptr = Builder.CreateLoad (Builder.getPtrTy (), RHSI8PtrPtr);
3761
+ Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast (
3762
+ RHSI8Ptr, RI.PrivateVariable ->getType ());
3763
+ Value *RHS = Builder.CreateLoad (RI.ElementType , RHSPtr);
3764
+ Value *Reduced;
3765
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
3766
+ RI.ReductionGen (Builder.saveIP (), LHS, RHS, Reduced);
3767
+ if (!AfterIP)
3768
+ return AfterIP.takeError ();
3769
+
3770
+ Builder.restoreIP (*AfterIP);
3771
+ // TODO: Consider flagging an error.
3772
+ if (!Builder.GetInsertBlock ())
3773
+ return Error::success ();
3774
+
3775
+ // store is inside of the reduction region when using by-ref
3776
+ if (!IsByRef[En.index ()])
3777
+ Builder.CreateStore (Reduced, LHSPtr);
3778
+ }
3779
+ Builder.CreateRetVoid ();
3780
+ return Error::success ();
3781
+ }
3782
+
3783
+ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createReductions (
3784
+ const LocationDescription &Loc, InsertPointTy AllocaIP,
3785
+ ArrayRef<ReductionInfo> ReductionInfos, ArrayRef<bool > IsByRef,
3786
+ bool IsNoWait, bool IsTeamsReduction) {
3787
+ assert (ReductionInfos.size () == IsByRef.size ());
3788
+ if (Config.isGPU ())
3789
+ return createReductionsGPU (Loc, AllocaIP, Builder.saveIP (), ReductionInfos,
3790
+ IsNoWait, IsTeamsReduction);
3791
+
3792
+ checkReductionInfos (ReductionInfos, /* IsGPU*/ false );
3793
+
3709
3794
if (!updateToLocation (Loc))
3710
3795
return InsertPointTy ();
3711
3796
3797
+ if (ReductionInfos.size () == 0 )
3798
+ return Builder.saveIP ();
3799
+
3712
3800
BasicBlock *InsertBlock = Loc.IP .getBlock ();
3713
3801
BasicBlock *ContinuationBlock =
3714
3802
InsertBlock->splitBasicBlock (Loc.IP .getPoint (), " reduce.finalize" );
@@ -3832,38 +3920,13 @@ OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
3832
3920
// Populate the outlined reduction function using the elementwise reduction
3833
3921
// function. Partial values are extracted from the type-erased array of
3834
3922
// pointers to private variables.
3835
- BasicBlock *ReductionFuncBlock =
3836
- BasicBlock::Create (Module->getContext (), " " , ReductionFunc);
3837
- Builder.SetInsertPoint (ReductionFuncBlock);
3838
- Value *LHSArrayPtr = ReductionFunc->getArg (0 );
3839
- Value *RHSArrayPtr = ReductionFunc->getArg (1 );
3923
+ Error Err = populateReductionFunction (ReductionFunc, ReductionInfos, Builder,
3924
+ IsByRef, /* isGPU=*/ false );
3925
+ if (Err)
3926
+ return Err;
3840
3927
3841
- for (auto En : enumerate(ReductionInfos)) {
3842
- const ReductionInfo &RI = En.value ();
3843
- Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
3844
- RedArrayTy, LHSArrayPtr, 0 , En.index ());
3845
- Value *LHSI8Ptr = Builder.CreateLoad (Builder.getPtrTy (), LHSI8PtrPtr);
3846
- Value *LHSPtr = Builder.CreateBitCast (LHSI8Ptr, RI.Variable ->getType ());
3847
- Value *LHS = Builder.CreateLoad (RI.ElementType , LHSPtr);
3848
- Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64 (
3849
- RedArrayTy, RHSArrayPtr, 0 , En.index ());
3850
- Value *RHSI8Ptr = Builder.CreateLoad (Builder.getPtrTy (), RHSI8PtrPtr);
3851
- Value *RHSPtr =
3852
- Builder.CreateBitCast (RHSI8Ptr, RI.PrivateVariable ->getType ());
3853
- Value *RHS = Builder.CreateLoad (RI.ElementType , RHSPtr);
3854
- Value *Reduced;
3855
- InsertPointOrErrorTy AfterIP =
3856
- RI.ReductionGen (Builder.saveIP (), LHS, RHS, Reduced);
3857
- if (!AfterIP)
3858
- return AfterIP.takeError ();
3859
- Builder.restoreIP (*AfterIP);
3860
- if (!Builder.GetInsertBlock ())
3861
- return InsertPointTy ();
3862
- // store is inside of the reduction region when using by-ref
3863
- if (!IsByRef[En.index ()])
3864
- Builder.CreateStore (Reduced, LHSPtr);
3865
- }
3866
- Builder.CreateRetVoid ();
3928
+ if (!Builder.GetInsertBlock ())
3929
+ return InsertPointTy ();
3867
3930
3868
3931
Builder.SetInsertPoint (ContinuationBlock);
3869
3932
return Builder.saveIP ();
@@ -6239,8 +6302,10 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
6239
6302
Constant *MaxThreads = ConstantInt::getSigned (Int32, MaxThreadsVal);
6240
6303
Constant *MinTeams = ConstantInt::getSigned (Int32, Attrs.MinTeams );
6241
6304
Constant *MaxTeams = ConstantInt::getSigned (Int32, Attrs.MaxTeams .front ());
6242
- Constant *ReductionDataSize = ConstantInt::getSigned (Int32, 0 );
6243
- Constant *ReductionBufferLength = ConstantInt::getSigned (Int32, 0 );
6305
+ Constant *ReductionDataSize =
6306
+ ConstantInt::getSigned (Int32, Attrs.ReductionDataSize );
6307
+ Constant *ReductionBufferLength =
6308
+ ConstantInt::getSigned (Int32, Attrs.ReductionBufferLength );
6244
6309
6245
6310
Function *Fn = getOrCreateRuntimeFunctionPtr (
6246
6311
omp::RuntimeFunction::OMPRTL___kmpc_target_init);
0 commit comments