Skip to content

[MLIR][OpenMP][OMPIRBuilder] Add lowering support for omp.target_triples #100156

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/test/Integration/OpenMP/map-types-and-sizes.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
! added to this directory and sub-directories.
!===----------------------------------------------------------------------===!

!RUN: %flang_fc1 -emit-llvm -fopenmp -flang-deprecated-no-hlfir %s -o - | FileCheck %s
!RUN: %flang_fc1 -emit-llvm -fopenmp -fopenmp-targets=amdgcn-amd-amdhsa -flang-deprecated-no-hlfir %s -o - | FileCheck %s

!===============================================================================
! Check MapTypes for target implicit captures
Expand Down
21 changes: 14 additions & 7 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ class OpenMPIRBuilderConfig {
// Grid Value for the GPU target
std::optional<omp::GV> GridValue;

/// When compilation is being done for the OpenMP host (i.e. `IsTargetDevice =
/// false`), this contains the list of offloading triples associated, if any.
SmallVector<Triple> TargetTriples;

OpenMPIRBuilderConfig();
OpenMPIRBuilderConfig(bool IsTargetDevice, bool IsGPU,
bool OpenMPOffloadMandatory,
Expand Down Expand Up @@ -2183,21 +2187,22 @@ class OpenMPIRBuilder {
/// kernel args vector.
struct TargetKernelArgs {
/// Number of arguments passed to the runtime library.
unsigned NumTargetItems;
unsigned NumTargetItems = 0;
/// Arguments passed to the runtime library
TargetDataRTArgs RTArgs;
/// The number of iterations
Value *NumIterations;
Value *NumIterations = nullptr;
/// The number of teams.
Value *NumTeams;
Value *NumTeams = nullptr;
/// The number of threads.
Value *NumThreads;
Value *NumThreads = nullptr;
/// The size of the dynamic shared memory.
Value *DynCGGroupMem;
Value *DynCGGroupMem = nullptr;
/// True if the kernel has 'no wait' clause.
bool HasNoWait;
bool HasNoWait = false;

/// Constructor for TargetKernelArgs
// Constructors for TargetKernelArgs.
TargetKernelArgs() {}
TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
Value *NumIterations, Value *NumTeams, Value *NumThreads,
Value *DynCGGroupMem, bool HasNoWait)
Expand Down Expand Up @@ -2834,6 +2839,7 @@ class OpenMPIRBuilder {
/// Generator for '#omp target'
///
/// \param Loc where the target data construct was encountered.
/// \param IsOffloadEntry whether it is an offload entry.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
Expand All @@ -2847,6 +2853,7 @@ class OpenMPIRBuilder {
/// \param Dependencies A vector of DependData objects that carry
// dependency information as passed in the depend clause
InsertPointTy createTarget(const LocationDescription &Loc,
bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
Expand Down
91 changes: 58 additions & 33 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6768,7 +6768,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
return ProxyFn;
}
static void emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
Expand All @@ -6781,8 +6781,8 @@ static void emitTargetOutlinedFunction(
CBFunc, ArgAccessorFuncCB);
};

OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, true,
OutlinedFn, OutlinedFnID);
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction,
IsOffloadEntry, OutlinedFn, OutlinedFnID);
}
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
Function *OutlinedFn, Value *OutlinedFnID,
Expand Down Expand Up @@ -6898,15 +6898,22 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(

Builder.restoreIP(TargetTaskBodyIP);

// emitKernelLaunch makes the necessary runtime call to offload the kernel.
// We then outline all that code into a separate function
// ('kernel_launch_function' in the pseudo code above). This function is then
// called by the target task proxy function (see
// '@.omp_target_task_proxy_func' in the pseudo code above)
// "@.omp_target_task_proxy_func' is generated by emitTargetTaskProxyFunction
Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
EmitTargetCallFallbackCB, Args, DeviceID,
RTLoc, TargetTaskAllocaIP));
if (OutlinedFnID) {
// emitKernelLaunch makes the necessary runtime call to offload the kernel.
// We then outline all that code into a separate function
// ('kernel_launch_function' in the pseudo code above). This function is
// then called by the target task proxy function (see
// '@.omp_target_task_proxy_func' in the pseudo code above)
// "@.omp_target_task_proxy_func' is generated by
// emitTargetTaskProxyFunction.
Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
EmitTargetCallFallbackCB, Args, DeviceID,
RTLoc, TargetTaskAllocaIP));
} else {
// When OutlinedFnID is set to nullptr, then it's not an offloading call. In
// this case, we execute the host implementation directly.
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
}

OI.ExitBB = Builder.saveIP().getBlock();
OI.PostOutlineCB = [this, ToBeDeleted, Dependencies,
Expand Down Expand Up @@ -7015,11 +7022,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
Function *TaskCompleteFn =
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
CallInst *CI = nullptr;
if (HasShareds)
CI = Builder.CreateCall(ProxyFn, {ThreadID, TaskData});
else
CI = Builder.CreateCall(ProxyFn, {ThreadID});
CallInst *CI = Builder.CreateCall(ProxyFn, {ThreadID, TaskData});
CI->setDebugLoc(StaleCI->getDebugLoc());
Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
} else if (DepArray) {
Expand Down Expand Up @@ -7052,6 +7055,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
<< "\n");
return Builder.saveIP();
}

void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
InsertPointTy AllocaIP, InsertPointTy CodeGenIP, TargetDataInfo &Info,
TargetDataRTArgs &RTArgs, MapInfosTy &CombinedInfo, bool IsNonContiguous,
Expand All @@ -7069,6 +7073,37 @@ static void emitTargetCall(
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
// Generate a function call to the host fallback implementation of the target
// region. This is called by the host when no offload entry was generated for
// the target region and when the offloading call fails at runtime.
auto &&EmitTargetCallFallbackCB =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note here, that if (!OutlinedFnID), then we cannot simply inline a call to OutlinedFn. We need to the following check

if (!OutlinedFnID) {
   if(RequiresOuterTargetTask)
       Builder.restoreIP(emitTargetTask(...));
   else
       Builder.restoreIP(EmitTargetCallFallBackCB(Builder.saveIP()));
}

But dont change this PR because I'll fold that into my upcoming changes and more importantly, emitTargetTask ->emitKernelLaunch assert on OutlinedFnID here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for pointing out this issue. I just pushed some changes to hopefully address it, though I think these should get a second review from you since you're more familiar with the handling of target depend.

[&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
Builder.restoreIP(IP);
Builder.CreateCall(OutlinedFn, Args);
return Builder.saveIP();
};

bool HasNoWait = false;
bool HasDependencies = Dependencies.size() > 0;
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

// If we don't have an ID for the target region, it means an offload entry
// wasn't created. In this case we just run the host fallback directly.
if (!OutlinedFnID) {
if (RequiresOuterTargetTask) {
// Arguments that are intended to be directly forwarded to an
// emitKernelLaunch call are pased as nullptr, since OutlinedFnID=nullptr
// results in that call not being done.
OpenMPIRBuilder::TargetKernelArgs KArgs;
Builder.restoreIP(OMPBuilder.emitTargetTask(
OutlinedFn, /*OutlinedFnID=*/nullptr, EmitTargetCallFallbackCB, KArgs,
/*DeviceID=*/nullptr, /*RTLoc=*/nullptr, AllocaIP, Dependencies,
HasNoWait));
} else {
Builder.restoreIP(EmitTargetCallFallbackCB(Builder.saveIP()));
}
return;
}

OpenMPIRBuilder::TargetDataInfo Info(
/*RequiresDevicePointerInfo=*/false,
Expand All @@ -7081,14 +7116,6 @@ static void emitTargetCall(
/*IsNonContiguous=*/true,
/*ForEndCall=*/false);

// emitKernelLaunch
auto &&EmitTargetCallFallbackCB =
[&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
Builder.restoreIP(IP);
Builder.CreateCall(OutlinedFn, Args);
return Builder.saveIP();
};

unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
Expand All @@ -7103,10 +7130,6 @@ static void emitTargetCall(
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);

bool HasNoWait = false;
bool HasDependencies = Dependencies.size() > 0;
bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
NumTeamsVal, NumThreadsVal,
DynCGGroupMem, HasNoWait);
Expand All @@ -7123,8 +7146,9 @@ static void emitTargetCall(
DeviceID, RTLoc, AllocaIP));
}
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, InsertPointTy AllocaIP,
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
int32_t NumThreads, SmallVectorImpl<Value *> &Args,
GenMapInfoCallbackTy GenMapInfoCB,
Expand All @@ -7138,12 +7162,13 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
Builder.restoreIP(CodeGenIP);

Function *OutlinedFn;
Constant *OutlinedFnID;
Constant *OutlinedFnID = nullptr;
// The target region is outlined into its own function. The LLVM IR for
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB);
emitTargetOutlinedFunction(*this, Builder, IsOffloadEntry, EntryInfo,
OutlinedFn, OutlinedFnID, Args, CBFunc,
ArgAccessorFuncCB);

// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
Expand Down
10 changes: 6 additions & 4 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5983,8 +5983,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
Builder.restoreIP(OMPBuilder.createTarget(
OmpLoc, Builder.saveIP(), Builder.saveIP(), EntryInfo, -1, 0, Inputs,
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
OMPBuilder.finalize();
Builder.CreateRetVoid();

Expand Down Expand Up @@ -6087,7 +6087,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
/*Line=*/3, /*Count=*/0);

Builder.restoreIP(
OMPBuilder.createTarget(Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, /*NumTeams=*/-1,
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));

Expand Down Expand Up @@ -6235,7 +6236,8 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
/*Line=*/3, /*Count=*/0);

Builder.restoreIP(
OMPBuilder.createTarget(Loc, EntryIP, EntryIP, EntryInfo, /*NumTeams=*/-1,
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
EntryInfo, /*NumTeams=*/-1,
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
BodyGenCB, SimpleArgAccessorCB));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3233,13 +3233,20 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
if (!targetOpSupported(opInst))
return failure();

llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
bool isTargetDevice = ompBuilder->Config.isTargetDevice();
auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
auto targetOp = cast<omp::TargetOp>(opInst);
auto &targetRegion = targetOp.getRegion();
DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
SmallVector<Value> mapVars = targetOp.getMapVars();
llvm::Function *llvmOutlinedFn = nullptr;

// TODO: It can also be false if a compile-time constant `false` IF clause is
// specified.
bool isOffloadEntry =
isTargetDevice || !ompBuilder->Config.TargetTriples.empty();

LogicalResult bodyGenStatus = success();
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
auto bodyCB = [&](InsertPointTy allocaIP,
Expand Down Expand Up @@ -3306,14 +3313,12 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
llvm::Value *&retVal, InsertPointTy allocaIP,
InsertPointTy codeGenIP) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

// We just return the unaltered argument for the host function
// for now, some alterations may be required in the future to
// keep host fallback functions working identically to the device
// version (e.g. pass ByCopy values should be treated as such on
// host and device, currently not always the case)
if (!ompBuilder->Config.isTargetDevice()) {
if (!isTargetDevice) {
retVal = cast<llvm::Value>(&arg);
return codeGenIP;
}
Expand All @@ -3339,9 +3344,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
moduleTranslation, dds);

builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, allocaIP, builder.saveIP(), entryInfo, defaultValTeams,
defaultValThreads, kernelInput, genMapInfoCB, bodyCB, argAccessorCB,
dds));
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
argAccessorCB, dds));

// Remap access operations to declare target reference pointers for the
// device, essentially generating extra loadop's as necessary
Expand Down Expand Up @@ -3714,6 +3719,23 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
}
return failure();
})
.Case("omp.target_triples",
[&](Attribute attr) {
if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
llvm::OpenMPIRBuilderConfig &config =
moduleTranslation.getOpenMPBuilder()->Config;
config.TargetTriples.clear();
config.TargetTriples.reserve(triplesAttr.size());
for (Attribute tripleAttr : triplesAttr) {
if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
config.TargetTriples.emplace_back(tripleStrAttr.getValue());
else
return failure();
}
return success();
}
return failure();
})
.Default([](Attribute) {
// Fall through for omp attributes that do not require lowering.
return success();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
// array bounds to lower to the full size of the array and the sectioned
// array to be the size of 3*3*1*element-byte-size (36 bytes in this case).

module attributes {omp.is_target_device = false} {
module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
llvm.func @_3d_target_array_section() {
%0 = llvm.mlir.addressof @_QFEinarray : !llvm.ptr
%1 = llvm.mlir.addressof @_QFEoutarray : !llvm.ptr
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

module attributes {omp.is_target_device = false} {
module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
llvm.func @_QQmain() attributes {fir.bindc_name = "main"} {
%0 = llvm.mlir.addressof @_QFEi : !llvm.ptr
%1 = llvm.mlir.addressof @_QFEsp : !llvm.ptr
Expand Down
33 changes: 33 additions & 0 deletions mlir/test/Target/LLVMIR/omptarget-depend-host-only.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

module attributes {omp.is_target_device = false} {
llvm.func @omp_target_depend_() {
%0 = llvm.mlir.constant(39 : index) : i64
%1 = llvm.mlir.constant(1 : index) : i64
%2 = llvm.mlir.constant(40 : index) : i64
%3 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) extent(%2 : i64) stride(%1 : i64) start_idx(%1 : i64)
%4 = llvm.mlir.addressof @_QFEa : !llvm.ptr
%5 = omp.map.info var_ptr(%4 : !llvm.ptr, !llvm.array<40 x i32>) map_clauses(from) capture(ByRef) bounds(%3) -> !llvm.ptr {name = "a"}
omp.target map_entries(%5 -> %arg0 : !llvm.ptr) depend(taskdependin -> %4 : !llvm.ptr) {
^bb0(%arg0: !llvm.ptr):
%6 = llvm.mlir.constant(100 : index) : i32
llvm.store %6, %arg0 : i32, !llvm.ptr
omp.terminator
}
llvm.return
}

llvm.mlir.global internal @_QFEa() {addr_space = 0 : i32} : !llvm.array<40 x i32> {
%0 = llvm.mlir.zero : !llvm.array<40 x i32>
llvm.return %0 : !llvm.array<40 x i32>
}
}

// CHECK: define void @omp_target_depend_()
// CHECK-NOT: define {{.*}} @
// CHECK-NOT: call i32 @__tgt_target_kernel({{.*}})
// CHECK: call void @__omp_offloading_[[DEV:.*]]_[[FIL:.*]]_omp_target_depend__l[[LINE:.*]](ptr {{.*}})
// CHECK-NEXT: ret void

// CHECK: define internal void @__omp_offloading_[[DEV]]_[[FIL]]_omp_target_depend__l[[LINE]](ptr %[[ADDR_A:.*]])
// CHECK: store i32 100, ptr %[[ADDR_A]], align 4
Loading
Loading