Skip to content

Commit 41566fb

Browse files
committed
[OpenMP][FIX] Ensure recording works properly w/ late allocations
1 parent 6663df3 commit 41566fb

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

openmp/libomptarget/plugins-nextgen/common/PluginInterface/PluginInterface.cpp

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,9 @@ struct RecordReplayTy {
231231
OS.close();
232232
}
233233

234-
void saveKernelDescr(const char *Name, void **ArgPtrs, ptrdiff_t *ArgOffsets,
235-
int32_t NumArgs, uint64_t NumTeamsClause,
236-
uint32_t ThreadLimitClause, uint64_t LoopTripCount) {
234+
void saveKernelDescr(const char *Name, void **ArgPtrs, int32_t NumArgs,
235+
uint64_t NumTeamsClause, uint32_t ThreadLimitClause,
236+
uint64_t LoopTripCount) {
237237
json::Object JsonKernelInfo;
238238
JsonKernelInfo["Name"] = Name;
239239
JsonKernelInfo["NumArgs"] = NumArgs;
@@ -251,7 +251,7 @@ struct RecordReplayTy {
251251

252252
json::Array JsonArgOffsets;
253253
for (int I = 0; I < NumArgs; ++I)
254-
JsonArgOffsets.push_back(ArgOffsets[I]);
254+
JsonArgOffsets.push_back(0);
255255
JsonKernelInfo["ArgOffsets"] = json::Value(std::move(JsonArgOffsets));
256256

257257
SmallString<128> JsonFilename = {Name, ".json"};
@@ -427,6 +427,11 @@ Expected<KernelLaunchEnvironmentTy *>
427427
GenericKernelTy::getKernelLaunchEnvironment(
428428
GenericDeviceTy &GenericDevice,
429429
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
430+
// Ctor/Dtor have no arguments, replaying uses the original kernel launch
431+
// environment.
432+
if (isCtorOrDtor() || RecordReplay.isReplaying())
433+
return nullptr;
434+
430435
// TODO: Check if the kernel needs a launch environment.
431436
auto AllocOrErr = GenericDevice.dataAlloc(sizeof(KernelLaunchEnvironmentTy),
432437
/*HostPtr=*/nullptr,
@@ -501,6 +506,15 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
501506
getNumBlocks(GenericDevice, KernelArgs.NumTeams, KernelArgs.Tripcount,
502507
NumThreads, KernelArgs.ThreadLimit[0] > 0);
503508

509+
// Record the kernel description after we modified the argument count and num
510+
// blocks/threads.
511+
if (RecordReplay.isRecording()) {
512+
RecordReplay.saveImage(getName(), getImage());
513+
RecordReplay.saveKernelInput(getName(), getImage());
514+
RecordReplay.saveKernelDescr(getName(), Ptrs.data(), KernelArgs.NumArgs,
515+
NumBlocks, NumThreads, KernelArgs.Tripcount);
516+
}
517+
504518
if (auto Err =
505519
printLaunchInfo(GenericDevice, KernelArgs, NumThreads, NumBlocks))
506520
return Err;
@@ -517,16 +531,20 @@ void *GenericKernelTy::prepareArgs(
517531
if (isCtorOrDtor())
518532
return nullptr;
519533

520-
NumArgs += 1;
534+
uint32_t KLEOffset = !!KernelLaunchEnvironment;
535+
NumArgs += KLEOffset;
521536

522537
Args.resize(NumArgs);
523538
Ptrs.resize(NumArgs);
524539

525-
Ptrs[0] = KernelLaunchEnvironment;
526-
Args[0] = &Ptrs[0];
540+
if (KernelLaunchEnvironment) {
541+
Ptrs[0] = KernelLaunchEnvironment;
542+
Args[0] = &Ptrs[0];
543+
}
527544

528-
for (int I = 1; I < NumArgs; ++I) {
529-
Ptrs[I] = (void *)((intptr_t)ArgPtrs[I - 1] + ArgOffsets[I - 1]);
545+
for (int I = KLEOffset; I < NumArgs; ++I) {
546+
Ptrs[I] =
547+
(void *)((intptr_t)ArgPtrs[I - KLEOffset] + ArgOffsets[I - KLEOffset]);
530548
Args[I] = &Ptrs[I];
531549
}
532550
return &Args[0];
@@ -808,7 +826,7 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
808826
return std::move(Err);
809827

810828
// Setup the global device memory pool if needed.
811-
if (shouldSetupDeviceMemoryPool()) {
829+
if (!RecordReplay.isReplaying() && shouldSetupDeviceMemoryPool()) {
812830
uint64_t HeapSize;
813831
auto SizeOrErr = getDeviceHeapSize(HeapSize);
814832
if (SizeOrErr) {
@@ -1413,21 +1431,9 @@ Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs,
14131431
GenericKernelTy &GenericKernel =
14141432
*reinterpret_cast<GenericKernelTy *>(EntryPtr);
14151433

1416-
if (RecordReplay.isRecording()) {
1417-
RecordReplay.saveImage(GenericKernel.getName(), GenericKernel.getImage());
1418-
RecordReplay.saveKernelInput(GenericKernel.getName(),
1419-
GenericKernel.getImage());
1420-
}
1421-
14221434
auto Err = GenericKernel.launch(*this, ArgPtrs, ArgOffsets, KernelArgs,
14231435
AsyncInfoWrapper);
14241436

1425-
if (RecordReplay.isRecording())
1426-
RecordReplay.saveKernelDescr(GenericKernel.getName(), ArgPtrs, ArgOffsets,
1427-
KernelArgs.NumArgs, KernelArgs.NumTeams[0],
1428-
KernelArgs.ThreadLimit[0],
1429-
KernelArgs.Tripcount);
1430-
14311437
// 'finalize' here to guarantee next record-replay actions are in-sync
14321438
AsyncInfoWrapper.finalize(Err);
14331439

0 commit comments

Comments
 (0)