@@ -231,9 +231,9 @@ struct RecordReplayTy {
231
231
OS.close ();
232
232
}
233
233
234
- void saveKernelDescr (const char *Name, void **ArgPtrs, ptrdiff_t *ArgOffsets ,
235
- int32_t NumArgs, uint64_t NumTeamsClause ,
236
- uint32_t ThreadLimitClause, uint64_t LoopTripCount) {
234
+ void saveKernelDescr (const char *Name, void **ArgPtrs, int32_t NumArgs ,
235
+ uint64_t NumTeamsClause, uint32_t ThreadLimitClause ,
236
+ uint64_t LoopTripCount) {
237
237
json::Object JsonKernelInfo;
238
238
JsonKernelInfo[" Name" ] = Name;
239
239
JsonKernelInfo[" NumArgs" ] = NumArgs;
@@ -251,7 +251,7 @@ struct RecordReplayTy {
251
251
252
252
json::Array JsonArgOffsets;
253
253
for (int I = 0 ; I < NumArgs; ++I)
254
- JsonArgOffsets.push_back (ArgOffsets[I] );
254
+ JsonArgOffsets.push_back (0 );
255
255
JsonKernelInfo[" ArgOffsets" ] = json::Value (std::move (JsonArgOffsets));
256
256
257
257
SmallString<128 > JsonFilename = {Name, " .json" };
@@ -427,6 +427,11 @@ Expected<KernelLaunchEnvironmentTy *>
427
427
GenericKernelTy::getKernelLaunchEnvironment (
428
428
GenericDeviceTy &GenericDevice,
429
429
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
430
+ // Ctor/Dtor have no arguments, replaying uses the original kernel launch
431
+ // environment.
432
+ if (isCtorOrDtor () || RecordReplay.isReplaying ())
433
+ return nullptr ;
434
+
430
435
// TODO: Check if the kernel needs a launch environment.
431
436
auto AllocOrErr = GenericDevice.dataAlloc (sizeof (KernelLaunchEnvironmentTy),
432
437
/* HostPtr=*/ nullptr ,
@@ -501,6 +506,15 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
501
506
getNumBlocks (GenericDevice, KernelArgs.NumTeams , KernelArgs.Tripcount ,
502
507
NumThreads, KernelArgs.ThreadLimit [0 ] > 0 );
503
508
509
+ // Record the kernel description after we modified the argument count and num
510
+ // blocks/threads.
511
+ if (RecordReplay.isRecording ()) {
512
+ RecordReplay.saveImage (getName (), getImage ());
513
+ RecordReplay.saveKernelInput (getName (), getImage ());
514
+ RecordReplay.saveKernelDescr (getName (), Ptrs.data (), KernelArgs.NumArgs ,
515
+ NumBlocks, NumThreads, KernelArgs.Tripcount );
516
+ }
517
+
504
518
if (auto Err =
505
519
printLaunchInfo (GenericDevice, KernelArgs, NumThreads, NumBlocks))
506
520
return Err;
@@ -517,16 +531,20 @@ void *GenericKernelTy::prepareArgs(
517
531
if (isCtorOrDtor ())
518
532
return nullptr ;
519
533
520
- NumArgs += 1 ;
534
+ uint32_t KLEOffset = !!KernelLaunchEnvironment;
535
+ NumArgs += KLEOffset;
521
536
522
537
Args.resize (NumArgs);
523
538
Ptrs.resize (NumArgs);
524
539
525
- Ptrs[0 ] = KernelLaunchEnvironment;
526
- Args[0 ] = &Ptrs[0 ];
540
+ if (KernelLaunchEnvironment) {
541
+ Ptrs[0 ] = KernelLaunchEnvironment;
542
+ Args[0 ] = &Ptrs[0 ];
543
+ }
527
544
528
- for (int I = 1 ; I < NumArgs; ++I) {
529
- Ptrs[I] = (void *)((intptr_t )ArgPtrs[I - 1 ] + ArgOffsets[I - 1 ]);
545
+ for (int I = KLEOffset; I < NumArgs; ++I) {
546
+ Ptrs[I] =
547
+ (void *)((intptr_t )ArgPtrs[I - KLEOffset] + ArgOffsets[I - KLEOffset]);
530
548
Args[I] = &Ptrs[I];
531
549
}
532
550
return &Args[0 ];
@@ -808,7 +826,7 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
808
826
return std::move (Err);
809
827
810
828
// Setup the global device memory pool if needed.
811
- if (shouldSetupDeviceMemoryPool ()) {
829
+ if (!RecordReplay. isReplaying () && shouldSetupDeviceMemoryPool ()) {
812
830
uint64_t HeapSize;
813
831
auto SizeOrErr = getDeviceHeapSize (HeapSize);
814
832
if (SizeOrErr) {
@@ -1413,21 +1431,9 @@ Error GenericDeviceTy::launchKernel(void *EntryPtr, void **ArgPtrs,
1413
1431
GenericKernelTy &GenericKernel =
1414
1432
*reinterpret_cast <GenericKernelTy *>(EntryPtr);
1415
1433
1416
- if (RecordReplay.isRecording ()) {
1417
- RecordReplay.saveImage (GenericKernel.getName (), GenericKernel.getImage ());
1418
- RecordReplay.saveKernelInput (GenericKernel.getName (),
1419
- GenericKernel.getImage ());
1420
- }
1421
-
1422
1434
auto Err = GenericKernel.launch (*this , ArgPtrs, ArgOffsets, KernelArgs,
1423
1435
AsyncInfoWrapper);
1424
1436
1425
- if (RecordReplay.isRecording ())
1426
- RecordReplay.saveKernelDescr (GenericKernel.getName (), ArgPtrs, ArgOffsets,
1427
- KernelArgs.NumArgs , KernelArgs.NumTeams [0 ],
1428
- KernelArgs.ThreadLimit [0 ],
1429
- KernelArgs.Tripcount );
1430
-
1431
1437
// 'finalize' here to guarantee next record-replay actions are in-sync
1432
1438
AsyncInfoWrapper.finalize (Err);
1433
1439
0 commit comments