@@ -161,6 +161,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
161
161
if (op.getDevice ())
162
162
result = todo (" device" );
163
163
};
164
+ auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
165
+ if (op.getDistScheduleChunkSize ())
166
+ result = todo (" dist_schedule with chunk_size" );
167
+ };
164
168
auto checkHasDeviceAddr = [&todo](auto op, LogicalResult &result) {
165
169
if (!op.getHasDeviceAddrVars ().empty ())
166
170
result = todo (" has_device_addr" );
@@ -252,6 +256,16 @@ static LogicalResult checkImplementationStatus(Operation &op) {
252
256
253
257
LogicalResult result = success ();
254
258
llvm::TypeSwitch<Operation &>(op)
259
+ .Case ([&](omp::DistributeOp op) {
260
+ if (op.isComposite () &&
261
+ isa_and_present<omp::WsloopOp>(op.getNestedWrapper ()))
262
+ result = op.emitError () << " not yet implemented: "
263
+ " composite omp.distribute + omp.wsloop" ;
264
+ checkAllocate (op, result);
265
+ checkDistSchedule (op, result);
266
+ checkOrder (op, result);
267
+ checkPrivate (op, result);
268
+ })
255
269
.Case ([&](omp::OrderedRegionOp op) { checkParLevelSimd (op, result); })
256
270
.Case ([&](omp::SectionsOp op) {
257
271
checkAllocate (op, result);
@@ -3854,6 +3868,72 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
3854
3868
return success ();
3855
3869
}
3856
3870
3871
+ static LogicalResult
3872
+ convertOmpDistribute (Operation &opInst, llvm::IRBuilderBase &builder,
3873
+ LLVM::ModuleTranslation &moduleTranslation) {
3874
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3875
+ auto distributeOp = cast<omp::DistributeOp>(opInst);
3876
+ if (failed (checkImplementationStatus (opInst)))
3877
+ return failure ();
3878
+
3879
+ using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
3880
+ auto bodyGenCB = [&](InsertPointTy allocaIP,
3881
+ InsertPointTy codeGenIP) -> llvm::Error {
3882
+ // Save the alloca insertion point on ModuleTranslation stack for use in
3883
+ // nested regions.
3884
+ LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame (
3885
+ moduleTranslation, allocaIP);
3886
+
3887
+ // DistributeOp has only one region associated with it.
3888
+ builder.restoreIP (codeGenIP);
3889
+
3890
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3891
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3892
+ llvm::Expected<llvm::BasicBlock *> regionBlock =
3893
+ convertOmpOpRegions (distributeOp.getRegion (), " omp.distribute.region" ,
3894
+ builder, moduleTranslation);
3895
+ if (!regionBlock)
3896
+ return regionBlock.takeError ();
3897
+ builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
3898
+
3899
+ // TODO: Add support for clauses which are valid for DISTRIBUTE constructs.
3900
+ // Static schedule is the default.
3901
+ auto schedule = omp::ClauseScheduleKind::Static;
3902
+ bool isOrdered = false ;
3903
+ std::optional<omp::ScheduleModifier> scheduleMod;
3904
+ bool isSimd = false ;
3905
+ llvm::omp::WorksharingLoopType workshareLoopType =
3906
+ llvm::omp::WorksharingLoopType::DistributeStaticLoop;
3907
+ bool loopNeedsBarrier = false ;
3908
+ llvm::Value *chunk = nullptr ;
3909
+
3910
+ llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
3911
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
3912
+ ompBuilder->applyWorkshareLoop (
3913
+ ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
3914
+ convertToScheduleKind (schedule), chunk, isSimd,
3915
+ scheduleMod == omp::ScheduleModifier::monotonic,
3916
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
3917
+ workshareLoopType);
3918
+
3919
+ if (!wsloopIP)
3920
+ return wsloopIP.takeError ();
3921
+ return llvm::Error::success ();
3922
+ };
3923
+
3924
+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3925
+ findAllocaInsertPoint (builder, moduleTranslation);
3926
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3927
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3928
+ ompBuilder->createDistribute (ompLoc, allocaIP, bodyGenCB);
3929
+
3930
+ if (failed (handleError (afterIP, opInst)))
3931
+ return failure ();
3932
+
3933
+ builder.restoreIP (*afterIP);
3934
+ return success ();
3935
+ }
3936
+
3857
3937
// / Lowers the FlagsAttr which is applied to the module on the device
3858
3938
// / pass when offloading, this attribute contains OpenMP RTL globals that can
3859
3939
// / be passed as flags to the frontend, otherwise they are set to default
@@ -4813,6 +4893,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
4813
4893
.Case ([&](omp::TargetOp) {
4814
4894
return convertOmpTarget (*op, builder, moduleTranslation);
4815
4895
})
4896
+ .Case ([&](omp::DistributeOp) {
4897
+ return convertOmpDistribute (*op, builder, moduleTranslation);
4898
+ })
4816
4899
.Case ([&](omp::LoopNestOp) {
4817
4900
return convertOmpLoopNest (*op, builder, moduleTranslation);
4818
4901
})
0 commit comments