@@ -161,6 +161,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
161
161
if (op.getDevice ())
162
162
result = todo (" device" );
163
163
};
164
+ auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
165
+ if (op.getDistScheduleChunkSize ())
166
+ result = todo (" dist_schedule with chunk_size" );
167
+ };
164
168
auto checkHasDeviceAddr = [&todo](auto op, LogicalResult &result) {
165
169
if (!op.getHasDeviceAddrVars ().empty ())
166
170
result = todo (" has_device_addr" );
@@ -252,6 +256,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
252
256
253
257
LogicalResult result = success ();
254
258
llvm::TypeSwitch<Operation &>(op)
259
+ .Case ([&](omp::DistributeOp op) {
260
+ if (op.isComposite () &&
261
+ isa_and_present<omp::WsloopOp>(op.getNestedWrapper ()))
262
+ result = op.emitError () << " not yet implemented: "
263
+ " composite omp.distribute + omp.wsloop" ;
264
+ checkAllocate (op, result);
265
+ checkDistSchedule (op, result);
266
+ checkOrder (op, result);
267
+ checkPrivate (op, result);
268
+ })
255
269
.Case ([&](omp::OrderedRegionOp op) { checkParLevelSimd (op, result); })
256
270
.Case ([&](omp::SectionsOp op) {
257
271
checkAllocate (op, result);
@@ -3754,6 +3768,72 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
3754
3768
return success ();
3755
3769
}
3756
3770
3771
+ static LogicalResult
3772
+ convertOmpDistribute (Operation &opInst, llvm::IRBuilderBase &builder,
3773
+ LLVM::ModuleTranslation &moduleTranslation) {
3774
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3775
+ auto distributeOp = cast<omp::DistributeOp>(opInst);
3776
+ if (failed (checkImplementationStatus (opInst)))
3777
+ return failure ();
3778
+
3779
+ using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3780
+ auto bodyGenCB = [&](InsertPointTy allocaIP,
3781
+ InsertPointTy codeGenIP) -> llvm::Error {
3782
+ // Save the alloca insertion point on ModuleTranslation stack for use in
3783
+ // nested regions.
3784
+ LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame (
3785
+ moduleTranslation, allocaIP);
3786
+
3787
+ // DistributeOp has only one region associated with it.
3788
+ builder.restoreIP (codeGenIP);
3789
+
3790
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3791
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3792
+ llvm::Expected<llvm::BasicBlock *> regionBlock =
3793
+ convertOmpOpRegions (distributeOp.getRegion (), " omp.distribute.region" ,
3794
+ builder, moduleTranslation);
3795
+ if (!regionBlock)
3796
+ return regionBlock.takeError ();
3797
+ builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
3798
+
3799
+ // TODO: Add support for clauses which are valid for DISTRIBUTE constructs.
3800
+ // Static schedule is the default.
3801
+ auto schedule = omp::ClauseScheduleKind::Static;
3802
+ bool isOrdered = false ;
3803
+ std::optional<omp::ScheduleModifier> scheduleMod;
3804
+ bool isSimd = false ;
3805
+ llvm::omp::WorksharingLoopType workshareLoopType =
3806
+ llvm::omp::WorksharingLoopType::DistributeStaticLoop;
3807
+ bool loopNeedsBarrier = false ;
3808
+ llvm::Value *chunk = nullptr ;
3809
+
3810
+ llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
3811
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3812
+ ompBuilder->applyWorkshareLoop (
3813
+ ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
3814
+ convertToScheduleKind (schedule), chunk, isSimd,
3815
+ scheduleMod == omp::ScheduleModifier::monotonic,
3816
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3817
+ workshareLoopType);
3818
+
3819
+ if (!wsloopIP)
3820
+ return wsloopIP.takeError ();
3821
+ return llvm::Error::success ();
3822
+ };
3823
+
3824
+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3825
+ findAllocaInsertPoint (builder, moduleTranslation);
3826
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3827
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3828
+ ompBuilder->createDistribute (ompLoc, allocaIP, bodyGenCB);
3829
+
3830
+ if (failed (handleError (afterIP, opInst)))
3831
+ return failure ();
3832
+
3833
+ builder.restoreIP (*afterIP);
3834
+ return success ();
3835
+ }
3836
+
3757
3837
// / Lowers the FlagsAttr which is applied to the module on the device
3758
3838
// / pass when offloading, this attribute contains OpenMP RTL globals that can
3759
3839
// / be passed as flags to the frontend, otherwise they are set to default
@@ -4697,6 +4777,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
4697
4777
.Case ([&](omp::TargetOp) {
4698
4778
return convertOmpTarget (*op, builder, moduleTranslation);
4699
4779
})
4780
+ .Case ([&](omp::DistributeOp) {
4781
+ return convertOmpDistribute (*op, builder, moduleTranslation);
4782
+ })
4700
4783
.Case ([&](omp::LoopNestOp) {
4701
4784
return convertOmpLoopNest (*op, builder, moduleTranslation);
4702
4785
})
0 commit comments