@@ -164,6 +164,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
164
164
if (op.getDevice ())
165
165
result = todo (" device" );
166
166
};
167
+ auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
168
+ if (op.getDistScheduleChunkSize ())
169
+ result = todo (" dist_schedule with chunk_size" );
170
+ };
167
171
auto checkHasDeviceAddr = [&todo](auto op, LogicalResult &result) {
168
172
if (!op.getHasDeviceAddrVars ().empty ())
169
173
result = todo (" has_device_addr" );
@@ -255,6 +259,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
255
259
256
260
LogicalResult result = success ();
257
261
llvm::TypeSwitch<Operation &>(op)
262
+ .Case ([&](omp::DistributeOp op) {
263
+ if (op.isComposite () &&
264
+ isa_and_present<omp::WsloopOp>(op.getNestedWrapper ()))
265
+ result = op.emitError () << " not yet implemented: "
266
+ " composite omp.distribute + omp.wsloop" ;
267
+ checkAllocate (op, result);
268
+ checkDistSchedule (op, result);
269
+ checkOrder (op, result);
270
+ checkPrivate (op, result);
271
+ })
258
272
.Case ([&](omp::OrderedRegionOp op) { checkParLevelSimd (op, result); })
259
273
.Case ([&](omp::SectionsOp op) {
260
274
checkAllocate (op, result);
@@ -3755,6 +3769,67 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
3755
3769
return success ();
3756
3770
}
3757
3771
3772
+ static LogicalResult
3773
+ convertOmpDistribute (Operation &opInst, llvm::IRBuilderBase &builder,
3774
+ LLVM::ModuleTranslation &moduleTranslation) {
3775
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3776
+ auto distributeOp = cast<omp::DistributeOp>(opInst);
3777
+ if (failed (checkImplementationStatus (opInst)))
3778
+ return failure ();
3779
+
3780
+ using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3781
+ auto bodyGenCB = [&](InsertPointTy allocaIP,
3782
+ InsertPointTy codeGenIP) -> llvm::Error {
3783
+ // DistributeOp has only one region associated with it.
3784
+ builder.restoreIP (codeGenIP);
3785
+
3786
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3787
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3788
+ llvm::Expected<llvm::BasicBlock *> regionBlock =
3789
+ convertOmpOpRegions (distributeOp.getRegion (), " omp.distribute.region" ,
3790
+ builder, moduleTranslation);
3791
+ if (!regionBlock)
3792
+ return regionBlock.takeError ();
3793
+ builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
3794
+
3795
+ // TODO: Add support for clauses which are valid for DISTRIBUTE constructs.
3796
+ // Static schedule is the default.
3797
+ auto schedule = omp::ClauseScheduleKind::Static;
3798
+ bool isOrdered = false ;
3799
+ std::optional<omp::ScheduleModifier> scheduleMod;
3800
+ bool isSimd = false ;
3801
+ llvm::omp::WorksharingLoopType workshareLoopType =
3802
+ llvm::omp::WorksharingLoopType::DistributeStaticLoop;
3803
+ bool loopNeedsBarrier = false ;
3804
+ llvm::Value *chunk = nullptr ;
3805
+
3806
+ llvm::CanonicalLoopInfo *loopInfo = *findCurrentLoopInfo (moduleTranslation);
3807
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3808
+ ompBuilder->applyWorkshareLoop (
3809
+ ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
3810
+ convertToScheduleKind (schedule), chunk, isSimd,
3811
+ scheduleMod == omp::ScheduleModifier::monotonic,
3812
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3813
+ workshareLoopType);
3814
+
3815
+ if (!wsloopIP)
3816
+ return wsloopIP.takeError ();
3817
+ return llvm::Error::success ();
3818
+ };
3819
+
3820
+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3821
+ findAllocaInsertPoint (builder, moduleTranslation);
3822
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3823
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3824
+ ompBuilder->createDistribute (ompLoc, allocaIP, bodyGenCB);
3825
+
3826
+ if (failed (handleError (afterIP, opInst)))
3827
+ return failure ();
3828
+
3829
+ builder.restoreIP (*afterIP);
3830
+ return success ();
3831
+ }
3832
+
3758
3833
// / Lowers the FlagsAttr which is applied to the module on the device
3759
3834
// / pass when offloading, this attribute contains OpenMP RTL globals that can
3760
3835
// / be passed as flags to the frontend, otherwise they are set to default
@@ -4685,6 +4760,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
4685
4760
.Case ([&](omp::TargetOp) {
4686
4761
return convertOmpTarget (*op, builder, moduleTranslation);
4687
4762
})
4763
+ .Case ([&](omp::DistributeOp) {
4764
+ return convertOmpDistribute (*op, builder, moduleTranslation);
4765
+ })
4688
4766
.Case ([&](omp::LoopNestOp) {
4689
4767
return convertOmpLoopNest (*op, builder, moduleTranslation);
4690
4768
})
0 commit comments