Skip to content

Commit 9b57b16

Browse files
committed
[OMPIRBuilder] Fix shared clause for task construct
This patch fixes the shared clause for the task construct with multiple shared variables. The shareds field in the kmp_task_t is not an inline array in the struct, rather it is a pointer to an array. With an inline array, the pointer dereference to the outlined function body of the task would segmentation fault when accessed by the runtime. Reviewed By: kiranchandramohan, jdoerfert Differential Revision: https://reviews.llvm.org/D158462
1 parent 6c82430 commit 9b57b16

File tree

4 files changed

+64
-42
lines changed

4 files changed

+64
-42
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPKinds.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ __OMP_STRUCT_TYPE(KernelArgs, __tgt_kernel_arguments, false, Int32, Int32, VoidP
9595
Int64, Int64, Int32Arr3Ty, Int32Arr3Ty, Int32)
9696
__OMP_STRUCT_TYPE(AsyncInfo, __tgt_async_info, false, Int8Ptr)
9797
__OMP_STRUCT_TYPE(DependInfo, kmp_dep_info, false, SizeTy, SizeTy, Int8)
98+
__OMP_STRUCT_TYPE(Task, kmp_task_ompbuilder_t, false, VoidPtr, VoidPtr, Int32, VoidPtr, VoidPtr)
9899
__OMP_STRUCT_TYPE(ConfigurationEnvironment, ConfigurationEnvironmentTy, false,
99100
Int8, Int8, Int8)
100101
__OMP_STRUCT_TYPE(DynamicEnvironment, DynamicEnvironmentTy, false, Int16)

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1555,9 +1555,9 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
15551555
"there must be a single user for the outlined function");
15561556
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
15571557

1558-
// HasTaskData is true if any variables are captured in the outlined region,
1558+
// HasShareds is true if any variables are captured in the outlined region,
15591559
// false otherwise.
1560-
bool HasTaskData = StaleCI->arg_size() > 0;
1560+
bool HasShareds = StaleCI->arg_size() > 0;
15611561
Builder.SetInsertPoint(StaleCI);
15621562

15631563
// Gather the arguments for emitting the runtime call for
@@ -1585,8 +1585,15 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
15851585
// Argument - `sizeof_kmp_task_t` (TaskSize)
15861586
// Tasksize refers to the size in bytes of kmp_task_t data structure
15871587
// including private vars accessed in task.
1588-
Value *TaskSize = Builder.getInt64(0);
1589-
if (HasTaskData) {
1588+
// TODO: add kmp_task_t_with_privates (privates)
1589+
Value *TaskSize = Builder.getInt64(
1590+
divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
1591+
1592+
// Argument - `sizeof_shareds` (SharedsSize)
1593+
// SharedsSize refers to the shareds array size in the kmp_task_t data
1594+
// structure.
1595+
Value *SharedsSize = Builder.getInt64(0);
1596+
if (HasShareds) {
15901597
AllocaInst *ArgStructAlloca =
15911598
dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
15921599
assert(ArgStructAlloca &&
@@ -1596,19 +1603,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
15961603
dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
15971604
assert(ArgStructType && "Unable to find struct type corresponding to "
15981605
"arguments for extracted function");
1599-
TaskSize =
1606+
SharedsSize =
16001607
Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
16011608
}
16021609

1603-
// TODO: Argument - sizeof_shareds
1604-
16051610
// Argument - task_entry (the wrapper function)
1606-
// If the outlined function has some captured variables (i.e. HasTaskData is
1611+
// If the outlined function has some captured variables (i.e. HasShareds is
16071612
// true), then the wrapper function will have an additional argument (the
16081613
// struct containing captured variables). Otherwise, no such argument will
16091614
// be present.
16101615
SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
1611-
if (HasTaskData)
1616+
if (HasShareds)
16121617
WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
16131618
FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
16141619
(Twine(OutlinedFn.getName()) + ".wrapper").str(),
@@ -1617,19 +1622,19 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
16171622

16181623
// Emit the @__kmpc_omp_task_alloc runtime call
16191624
// The runtime call returns a pointer to an area where the task captured
1620-
// variables must be copied before the task is run (NewTaskData)
1621-
CallInst *NewTaskData = Builder.CreateCall(
1622-
TaskAllocFn,
1623-
{/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1624-
/*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0),
1625-
/*task_func=*/WrapperFunc});
1625+
// variables must be copied before the task is run (TaskData)
1626+
CallInst *TaskData = Builder.CreateCall(
1627+
TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1628+
/*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
1629+
/*task_func=*/WrapperFunc});
16261630

16271631
// Copy the arguments for outlined function
1628-
if (HasTaskData) {
1629-
Value *TaskData = StaleCI->getArgOperand(0);
1632+
if (HasShareds) {
1633+
Value *Shareds = StaleCI->getArgOperand(0);
16301634
Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
1631-
Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment,
1632-
TaskSize);
1635+
Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
1636+
Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
1637+
SharedsSize);
16331638
}
16341639

16351640
Value *DepArrayPtr = nullptr;
@@ -1705,12 +1710,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
17051710
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
17061711
Function *TaskCompleteFn =
17071712
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
1708-
Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, NewTaskData});
1709-
if (HasTaskData)
1710-
Builder.CreateCall(WrapperFunc, {ThreadID, NewTaskData});
1713+
Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
1714+
if (HasShareds)
1715+
Builder.CreateCall(WrapperFunc, {ThreadID, TaskData});
17111716
else
17121717
Builder.CreateCall(WrapperFunc, {ThreadID});
1713-
Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, NewTaskData});
1718+
Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
17141719
Builder.SetInsertPoint(ThenTI);
17151720
}
17161721

@@ -1719,14 +1724,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
17191724
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
17201725
Builder.CreateCall(
17211726
TaskFn,
1722-
{Ident, ThreadID, NewTaskData, Builder.getInt32(Dependencies.size()),
1727+
{Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
17231728
DepArrayPtr, ConstantInt::get(Builder.getInt32Ty(), 0),
17241729
ConstantPointerNull::get(Type::getInt8PtrTy(M.getContext()))});
17251730

17261731
} else {
17271732
// Emit the @__kmpc_omp_task runtime call to spawn the task
17281733
Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
1729-
Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
1734+
Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
17301735
}
17311736

17321737
StaleCI->eraseFromParent();
@@ -1735,10 +1740,13 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
17351740
BasicBlock *WrapperEntryBB =
17361741
BasicBlock::Create(M.getContext(), "", WrapperFunc);
17371742
Builder.SetInsertPoint(WrapperEntryBB);
1738-
if (HasTaskData)
1739-
Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)});
1740-
else
1743+
if (HasShareds) {
1744+
llvm::Value *Shareds =
1745+
Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1));
1746+
Builder.CreateCall(&OutlinedFn, {Shareds});
1747+
} else {
17411748
Builder.CreateCall(&OutlinedFn);
1749+
}
17421750
Builder.CreateRet(Builder.getInt32(0));
17431751
};
17441752

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5397,19 +5397,29 @@ TEST_F(OpenMPIRBuilderTest, CreateTask) {
53975397
ConstantInt *DataSize =
53985398
dyn_cast<ConstantInt>(TaskAllocCall->getArgOperand(3));
53995399
ASSERT_NE(DataSize, nullptr);
5400-
EXPECT_EQ(DataSize->getSExtValue(), 24); // 64-bit pointer + 128-bit integer
5400+
EXPECT_EQ(DataSize->getSExtValue(), 40);
54015401

5402-
// TODO: Verify size of shared clause variables
5402+
ConstantInt *SharedsSize =
5403+
dyn_cast<ConstantInt>(TaskAllocCall->getOperand(4));
5404+
EXPECT_EQ(SharedsSize->getSExtValue(),
5405+
24); // 64-bit pointer + 128-bit integer
54035406

54045407
// Verify Wrapper function
54055408
Function *WrapperFunc =
54065409
dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
54075410
ASSERT_NE(WrapperFunc, nullptr);
5411+
5412+
LoadInst *SharedsLoad = dyn_cast<LoadInst>(WrapperFunc->begin()->begin());
5413+
ASSERT_NE(SharedsLoad, nullptr);
5414+
EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1));
5415+
54085416
EXPECT_FALSE(WrapperFunc->isDeclaration());
5409-
CallInst *OutlinedFnCall = dyn_cast<CallInst>(WrapperFunc->begin()->begin());
5417+
CallInst *OutlinedFnCall =
5418+
dyn_cast<CallInst>(++WrapperFunc->begin()->begin());
54105419
ASSERT_NE(OutlinedFnCall, nullptr);
54115420
EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
5412-
EXPECT_EQ(OutlinedFnCall->getArgOperand(0), WrapperFunc->getArg(1));
5421+
EXPECT_EQ(OutlinedFnCall->getArgOperand(0),
5422+
WrapperFunc->getArg(1)->uses().begin()->getUser());
54135423

54145424
// Verify the presence of `trunc` and `icmp` instructions in Outlined function
54155425
Function *OutlinedFn = OutlinedFnCall->getCalledFunction();

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,7 +2208,7 @@ llvm.mlir.global internal @_QFsubEx() : i32
22082208
llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
22092209
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
22102210
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
2211-
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0,
2211+
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
22122212
// CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]])
22132213
// CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
22142214
omp.task {
@@ -2258,7 +2258,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
22582258
llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
22592259
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
22602260
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
2261-
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0,
2261+
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
22622262
// CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]])
22632263
// CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}})
22642264
omp.task depend(taskdependin -> %zaddr : !llvm.ptr<i32>) {
@@ -2303,9 +2303,10 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
23032303
llvm.store %diff, %zaddr : !llvm.ptr<i32>
23042304
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
23052305
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
2306-
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 16, i64 0,
2306+
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16,
23072307
// CHECK-SAME: ptr @[[wrapper_fn:.+]])
2308-
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[task_data]], ptr {{.+}}, i64 16, i1 false)
2308+
// CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]]
2309+
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false)
23092310
// CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
23102311
omp.task {
23112312
%z = llvm.add %x, %y : i32
@@ -2334,7 +2335,8 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
23342335

23352336

23362337
// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) {
2337-
// CHECK: call void @[[outlined_fn]](ptr %[[task_data]])
2338+
// CHECK: %[[shareds:.+]] = load ptr, ptr %1, align 8
2339+
// CHECK: call void @[[outlined_fn]](ptr %[[shareds]])
23382340
// CHECK: ret i32 0
23392341
// CHECK: }
23402342

@@ -2430,7 +2432,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
24302432
// CHECK: br label %[[codeRepl:[^,]+]]
24312433
// CHECK: [[codeRepl]]:
24322434
// CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
2433-
// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 0, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
2435+
// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
24342436
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]])
24352437
// CHECK: br label %[[task_exit:[^,]+]]
24362438
// CHECK: [[task_exit]]:
@@ -2443,8 +2445,9 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
24432445
// CHECK: %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2
24442446
// CHECK: store ptr %[[zaddr]], ptr %[[gep3]], align 8
24452447
// CHECK: %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
2446-
// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 16, i64 0, ptr @omp_taskgroup_task..omp_par.1.wrapper)
2447-
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 8 %[[t2_alloc]], ptr align 8 %[[structArg]], i64 16, i1 false)
2448+
// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper)
2449+
// CHECK: %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]]
2450+
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align 1 %[[shareds]], ptr align 1 %[[structArg]], i64 16, i1 false)
24482451
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], ptr %[[t2_alloc]])
24492452
// CHECK: br label %[[task_exit3:[^,]+]]
24502453
// CHECK: [[task_exit3]]:
@@ -2614,7 +2617,7 @@ llvm.func @omp_task_final(%boolexpr: i1) {
26142617
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
26152618
// CHECK: %[[final_flag:.+]] = select i1 %[[boolexpr]], i32 2, i32 0
26162619
// CHECK: %[[task_flags:.+]] = or i32 %[[final_flag]], 1
2617-
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 0, i64 0, ptr @omp_task_final..omp_par.wrapper)
2620+
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 %[[task_flags]], i64 40, i64 0, ptr @omp_task_final..omp_par.wrapper)
26182621
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
26192622
// CHECK: br label %[[task_exit:[^,]+]]
26202623
// CHECK: [[task_exit]]:
@@ -2645,7 +2648,7 @@ llvm.func @omp_task_if(%boolexpr: i1) {
26452648
// CHECK: br label %[[codeRepl:[^,]+]]
26462649
// CHECK: [[codeRepl]]:
26472650
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
2648-
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0, i64 0, ptr @omp_task_if..omp_par.wrapper)
2651+
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 0, ptr @omp_task_if..omp_par.wrapper)
26492652
// CHECK: br i1 %[[boolexpr]], label %[[true_label:[^,]+]], label %[[false_label:[^,]+]]
26502653
// CHECK: [[true_label]]:
26512654
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])

0 commit comments

Comments
 (0)