Skip to content

Commit 6e81e16

Browse files
committed
[Offload][OMPX] Add the runtime support for multi-dim grid and block
1 parent d648eed commit 6e81e16

File tree

8 files changed

+152
-94
lines changed

8 files changed

+152
-94
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -559,15 +559,15 @@ struct AMDGPUKernelTy : public GenericKernelTy {
559559
}
560560

561561
/// Launch the AMDGPU kernel function.
562-
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
563-
uint64_t NumBlocks, KernelArgsTy &KernelArgs,
562+
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
563+
uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
564564
KernelLaunchParamsTy LaunchParams,
565565
AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
566566

567567
/// Print more elaborate kernel launch info for AMDGPU
568568
Error printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
569-
KernelArgsTy &KernelArgs, uint32_t NumThreads,
570-
uint64_t NumBlocks) const override;
569+
KernelArgsTy &KernelArgs, uint32_t NumThreads[3],
570+
uint32_t NumBlocks[3]) const override;
571571

572572
/// Get group and private segment kernel size.
573573
uint32_t getGroupSize() const { return GroupSize; }
@@ -719,7 +719,7 @@ struct AMDGPUQueueTy {
719719
/// Push a kernel launch to the queue. The kernel launch requires an output
720720
/// signal and can define an optional input signal (nullptr if none).
721721
Error pushKernelLaunch(const AMDGPUKernelTy &Kernel, void *KernelArgs,
722-
uint32_t NumThreads, uint64_t NumBlocks,
722+
uint32_t NumThreads[3], uint32_t NumBlocks[3],
723723
uint32_t GroupSize, uint64_t StackSize,
724724
AMDGPUSignalTy *OutputSignal,
725725
AMDGPUSignalTy *InputSignal) {
@@ -746,14 +746,18 @@ struct AMDGPUQueueTy {
746746
assert(Packet && "Invalid packet");
747747

748748
// The first 32 bits of the packet are written after the other fields
749-
uint16_t Setup = UINT16_C(1) << HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
750-
Packet->workgroup_size_x = NumThreads;
751-
Packet->workgroup_size_y = 1;
752-
Packet->workgroup_size_z = 1;
749+
uint16_t Dims = NumBlocks[2] * NumThreads[2] > 1
750+
? 3
751+
: 1 + (NumBlocks[1] * NumThreads[1] != 1);
752+
uint16_t Setup = UINT16_C(Dims)
753+
<< HSA_KERNEL_DISPATCH_PACKET_SETUP_DIMENSIONS;
754+
Packet->workgroup_size_x = NumThreads[0];
755+
Packet->workgroup_size_y = NumThreads[1];
756+
Packet->workgroup_size_z = NumThreads[2];
753757
Packet->reserved0 = 0;
754-
Packet->grid_size_x = NumBlocks * NumThreads;
755-
Packet->grid_size_y = 1;
756-
Packet->grid_size_z = 1;
758+
Packet->grid_size_x = NumBlocks[0] * NumThreads[0];
759+
Packet->grid_size_y = NumBlocks[1] * NumThreads[1];
760+
Packet->grid_size_z = NumBlocks[2] * NumThreads[2];
757761
Packet->private_segment_size =
758762
Kernel.usesDynamicStack() ? StackSize : Kernel.getPrivateSize();
759763
Packet->group_segment_size = GroupSize;
@@ -1240,7 +1244,7 @@ struct AMDGPUStreamTy {
12401244
/// the kernel finalizes. Once the kernel is finished, the stream will release
12411245
/// the kernel args buffer to the specified memory manager.
12421246
Error pushKernelLaunch(const AMDGPUKernelTy &Kernel, void *KernelArgs,
1243-
uint32_t NumThreads, uint64_t NumBlocks,
1247+
uint32_t NumThreads[3], uint32_t NumBlocks[3],
12441248
uint32_t GroupSize, uint64_t StackSize,
12451249
AMDGPUMemoryManagerTy &MemoryManager) {
12461250
if (Queue == nullptr)
@@ -2829,10 +2833,10 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
28292833
AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);
28302834

28312835
KernelArgsTy KernelArgs = {};
2832-
if (auto Err =
2833-
AMDGPUKernel.launchImpl(*this, /*NumThread=*/1u,
2834-
/*NumBlocks=*/1ul, KernelArgs,
2835-
KernelLaunchParamsTy{}, AsyncInfoWrapper))
2836+
uint32_t NumBlocksAndThreads[3] = {1u, 1u, 1u};
2837+
if (auto Err = AMDGPUKernel.launchImpl(
2838+
*this, NumBlocksAndThreads, NumBlocksAndThreads, KernelArgs,
2839+
KernelLaunchParamsTy{}, AsyncInfoWrapper))
28362840
return Err;
28372841

28382842
Error Err = Plugin::success();
@@ -3330,7 +3334,7 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
33303334
};
33313335

33323336
Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
3333-
uint32_t NumThreads, uint64_t NumBlocks,
3337+
uint32_t NumThreads[3], uint32_t NumBlocks[3],
33343338
KernelArgsTy &KernelArgs,
33353339
KernelLaunchParamsTy LaunchParams,
33363340
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
@@ -3387,13 +3391,15 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
33873391
// Only COV5 implicitargs needs to be set. COV4 implicitargs are not used.
33883392
if (ImplArgs &&
33893393
getImplicitArgsSize() == sizeof(hsa_utils::AMDGPUImplicitArgsTy)) {
3390-
ImplArgs->BlockCountX = NumBlocks;
3391-
ImplArgs->BlockCountY = 1;
3392-
ImplArgs->BlockCountZ = 1;
3393-
ImplArgs->GroupSizeX = NumThreads;
3394-
ImplArgs->GroupSizeY = 1;
3395-
ImplArgs->GroupSizeZ = 1;
3396-
ImplArgs->GridDims = 1;
3394+
ImplArgs->BlockCountX = NumBlocks[0];
3395+
ImplArgs->BlockCountY = NumBlocks[1];
3396+
ImplArgs->BlockCountZ = NumBlocks[2];
3397+
ImplArgs->GroupSizeX = NumThreads[0];
3398+
ImplArgs->GroupSizeY = NumThreads[1];
3399+
ImplArgs->GroupSizeZ = NumThreads[2];
3400+
ImplArgs->GridDims = NumBlocks[2] * NumThreads[2] > 1
3401+
? 3
3402+
: 1 + (NumBlocks[1] * NumThreads[1] != 1);
33973403
ImplArgs->DynamicLdsSize = KernelArgs.DynCGroupMem;
33983404
}
33993405

@@ -3404,8 +3410,8 @@ Error AMDGPUKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
34043410

34053411
Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
34063412
KernelArgsTy &KernelArgs,
3407-
uint32_t NumThreads,
3408-
uint64_t NumBlocks) const {
3413+
uint32_t NumThreads[3],
3414+
uint32_t NumBlocks[3]) const {
34093415
// Only do all this when the output is requested
34103416
if (!(getInfoLevel() & OMP_INFOTYPE_PLUGIN_KERNEL))
34113417
return Plugin::success();
@@ -3442,12 +3448,13 @@ Error AMDGPUKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
34423448
// S/VGPR Spill Count: how many S/VGPRs are spilled by the kernel
34433449
// Tripcount: loop tripcount for the kernel
34443450
INFO(OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId(),
3445-
"#Args: %d Teams x Thrds: %4lux%4u (MaxFlatWorkGroupSize: %u) LDS "
3451+
"#Args: %d Teams x Thrds: %4ux%4u (MaxFlatWorkGroupSize: %u) LDS "
34463452
"Usage: %uB #SGPRs/VGPRs: %u/%u #SGPR/VGPR Spills: %u/%u Tripcount: "
34473453
"%lu\n",
3448-
ArgNum, NumGroups, ThreadsPerGroup, MaxFlatWorkgroupSize,
3449-
GroupSegmentSize, SGPRCount, VGPRCount, SGPRSpillCount, VGPRSpillCount,
3450-
LoopTripCount);
3454+
ArgNum, NumGroups[0] * NumGroups[1] * NumGroups[2],
3455+
ThreadsPerGroup[0] * ThreadsPerGroup[1] * ThreadsPerGroup[2],
3456+
MaxFlatWorkgroupSize, GroupSegmentSize, SGPRCount, VGPRCount,
3457+
SGPRSpillCount, VGPRSpillCount, LoopTripCount);
34513458

34523459
return Plugin::success();
34533460
}

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,9 @@ struct GenericKernelTy {
265265
Error launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
266266
ptrdiff_t *ArgOffsets, KernelArgsTy &KernelArgs,
267267
AsyncInfoWrapperTy &AsyncInfoWrapper) const;
268-
virtual Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
269-
uint64_t NumBlocks, KernelArgsTy &KernelArgs,
268+
virtual Error launchImpl(GenericDeviceTy &GenericDevice,
269+
uint32_t NumThreads[3], uint32_t NumBlocks[3],
270+
KernelArgsTy &KernelArgs,
270271
KernelLaunchParamsTy LaunchParams,
271272
AsyncInfoWrapperTy &AsyncInfoWrapper) const = 0;
272273

@@ -316,15 +317,15 @@ struct GenericKernelTy {
316317

317318
/// Prints generic kernel launch information.
318319
Error printLaunchInfo(GenericDeviceTy &GenericDevice,
319-
KernelArgsTy &KernelArgs, uint32_t NumThreads,
320-
uint64_t NumBlocks) const;
320+
KernelArgsTy &KernelArgs, uint32_t NumThreads[3],
321+
uint32_t NumBlocks[3]) const;
321322

322323
/// Prints plugin-specific kernel launch information after generic kernel
323324
/// launch information
324325
virtual Error printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
325326
KernelArgsTy &KernelArgs,
326-
uint32_t NumThreads,
327-
uint64_t NumBlocks) const;
327+
uint32_t NumThreads[3],
328+
uint32_t NumBlocks[3]) const;
328329

329330
private:
330331
/// Prepare the arguments before launching the kernel.
@@ -337,15 +338,15 @@ struct GenericKernelTy {
337338

338339
/// Get the number of threads and blocks for the kernel based on the
339340
/// user-defined threads and block clauses.
340-
uint32_t getNumThreads(GenericDeviceTy &GenericDevice,
341-
uint32_t ThreadLimitClause[3]) const;
341+
void getNumThreads(GenericDeviceTy &GenericDevice,
342+
uint32_t ThreadLimitClause[3]) const;
342343

343344
/// The number of threads \p NumThreads can be adjusted by this method.
344345
/// \p IsNumThreadsFromUser is true is \p NumThreads is defined by user via
345346
/// thread_limit clause.
346-
uint64_t getNumBlocks(GenericDeviceTy &GenericDevice,
347-
uint32_t BlockLimitClause[3], uint64_t LoopTripCount,
348-
uint32_t &NumThreads, bool IsNumThreadsFromUser) const;
347+
void getNumBlocks(GenericDeviceTy &GenericDevice,
348+
uint32_t BlockLimitClause[3], uint64_t LoopTripCount,
349+
uint32_t &NumThreads, bool IsNumThreadsFromUser) const;
349350

350351
/// Indicate if the kernel works in Generic SPMD, Generic or SPMD mode.
351352
bool isGenericSPMDMode() const {

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -526,20 +526,21 @@ GenericKernelTy::getKernelLaunchEnvironment(
526526

527527
Error GenericKernelTy::printLaunchInfo(GenericDeviceTy &GenericDevice,
528528
KernelArgsTy &KernelArgs,
529-
uint32_t NumThreads,
530-
uint64_t NumBlocks) const {
531-
INFO(OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId(),
532-
"Launching kernel %s with %" PRIu64
533-
" blocks and %d threads in %s mode\n",
534-
getName(), NumBlocks, NumThreads, getExecutionModeName());
529+
uint32_t NumThreads[3],
530+
uint32_t NumBlocks[3]) const {
531+
if (!IsBareKernel) {
532+
INFO(OMP_INFOTYPE_PLUGIN_KERNEL, GenericDevice.getDeviceId(),
533+
"Launching kernel %s with %u blocks and %u threads in %s mode\n",
534+
getName(), NumBlocks[0], NumThreads[0], getExecutionModeName());
535+
}
535536
return printLaunchInfoDetails(GenericDevice, KernelArgs, NumThreads,
536537
NumBlocks);
537538
}
538539

539540
Error GenericKernelTy::printLaunchInfoDetails(GenericDeviceTy &GenericDevice,
540541
KernelArgsTy &KernelArgs,
541-
uint32_t NumThreads,
542-
uint64_t NumBlocks) const {
542+
uint32_t NumThreads[3],
543+
uint32_t NumBlocks[3]) const {
543544
return Plugin::success();
544545
}
545546

@@ -566,10 +567,14 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
566567
Args, Ptrs, *KernelLaunchEnvOrErr);
567568
}
568569

569-
uint32_t NumThreads = getNumThreads(GenericDevice, KernelArgs.ThreadLimit);
570-
uint64_t NumBlocks =
571-
getNumBlocks(GenericDevice, KernelArgs.NumTeams, KernelArgs.Tripcount,
572-
NumThreads, KernelArgs.ThreadLimit[0] > 0);
570+
uint32_t NumThreads[3] = {KernelArgs.ThreadLimit[0],
571+
KernelArgs.ThreadLimit[1],
572+
KernelArgs.ThreadLimit[2]};
573+
uint32_t NumBlocks[3] = {KernelArgs.NumTeams[0], KernelArgs.NumTeams[1],
574+
KernelArgs.NumTeams[2]};
575+
getNumThreads(GenericDevice, NumThreads);
576+
getNumBlocks(GenericDevice, NumBlocks, KernelArgs.Tripcount, NumThreads[0],
577+
NumThreads[0] > 0);
573578

574579
// Record the kernel description after we modified the argument count and num
575580
// blocks/threads.
@@ -578,7 +583,7 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
578583
RecordReplay.saveImage(getName(), getImage());
579584
RecordReplay.saveKernelInput(getName(), getImage());
580585
RecordReplay.saveKernelDescr(getName(), LaunchParams, KernelArgs.NumArgs,
581-
NumBlocks, NumThreads, KernelArgs.Tripcount);
586+
NumBlocks[0], NumThreads[0], KernelArgs.Tripcount);
582587
}
583588

584589
if (auto Err =
@@ -616,38 +621,37 @@ KernelLaunchParamsTy GenericKernelTy::prepareArgs(
616621
return KernelLaunchParamsTy{sizeof(void *) * NumArgs, &Args[0], &Ptrs[0]};
617622
}
618623

619-
uint32_t GenericKernelTy::getNumThreads(GenericDeviceTy &GenericDevice,
620-
uint32_t ThreadLimitClause[3]) const {
621-
assert(ThreadLimitClause[1] == 0 && ThreadLimitClause[2] == 0 &&
622-
"Multi dimensional launch not supported yet.");
623-
624-
if (IsBareKernel && ThreadLimitClause[0] > 0)
625-
return ThreadLimitClause[0];
624+
void GenericKernelTy::getNumThreads(GenericDeviceTy &GenericDevice,
625+
uint32_t ThreadLimitClause[3]) const {
626+
if (IsBareKernel)
627+
return;
626628

627629
if (ThreadLimitClause[0] > 0 && isGenericMode())
628630
ThreadLimitClause[0] += GenericDevice.getWarpSize();
629631

630-
return std::min(MaxNumThreads, (ThreadLimitClause[0] > 0)
631-
? ThreadLimitClause[0]
632-
: PreferredNumThreads);
632+
ThreadLimitClause[0] =
633+
std::min(MaxNumThreads, (ThreadLimitClause[0] > 0) ? ThreadLimitClause[0]
634+
: PreferredNumThreads);
633635
}
634636

635-
uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
637+
void GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
636638
uint32_t NumTeamsClause[3],
637639
uint64_t LoopTripCount,
638640
uint32_t &NumThreads,
639641
bool IsNumThreadsFromUser) const {
640642
assert(NumTeamsClause[1] == 0 && NumTeamsClause[2] == 0 &&
641643
"Multi dimensional launch not supported yet.");
642644

643-
if (IsBareKernel && NumTeamsClause[0] > 0)
644-
return NumTeamsClause[0];
645+
if (IsBareKernel)
646+
return;
645647

646648
if (NumTeamsClause[0] > 0) {
647649
// TODO: We need to honor any value and consequently allow more than the
648650
// block limit. For this we might need to start multiple kernels or let the
649651
// blocks start again until the requested number has been started.
650-
return std::min(NumTeamsClause[0], GenericDevice.getBlockLimit());
652+
NumTeamsClause[0] =
653+
std::min(NumTeamsClause[0], GenericDevice.getBlockLimit());
654+
return;
651655
}
652656

653657
uint64_t DefaultNumBlocks = GenericDevice.getDefaultNumBlocks();
@@ -719,7 +723,8 @@ uint64_t GenericKernelTy::getNumBlocks(GenericDeviceTy &GenericDevice,
719723
// If the loops are long running we rather reuse blocks than spawn too many.
720724
if (GenericDevice.getReuseBlocksForHighTripCount())
721725
PreferredNumBlocks = std::min(TripCountNumBlocks, DefaultNumBlocks);
722-
return std::min(PreferredNumBlocks, GenericDevice.getBlockLimit());
726+
NumTeamsClause[0] =
727+
std::min(PreferredNumBlocks, GenericDevice.getBlockLimit());
723728
}
724729

725730
GenericDeviceTy::GenericDeviceTy(GenericPluginTy &Plugin, int32_t DeviceId,

offload/plugins-nextgen/cuda/src/rtl.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ struct CUDAKernelTy : public GenericKernelTy {
149149
}
150150

151151
/// Launch the CUDA kernel function.
152-
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
153-
uint64_t NumBlocks, KernelArgsTy &KernelArgs,
152+
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
153+
uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
154154
KernelLaunchParamsTy LaunchParams,
155155
AsyncInfoWrapperTy &AsyncInfoWrapper) const override;
156156

@@ -1230,10 +1230,10 @@ struct CUDADeviceTy : public GenericDeviceTy {
12301230
AsyncInfoWrapperTy AsyncInfoWrapper(*this, nullptr);
12311231

12321232
KernelArgsTy KernelArgs = {};
1233-
if (auto Err =
1234-
CUDAKernel.launchImpl(*this, /*NumThread=*/1u,
1235-
/*NumBlocks=*/1ul, KernelArgs,
1236-
KernelLaunchParamsTy{}, AsyncInfoWrapper))
1233+
uint32_t NumBlocksAndThreads[3] = {1u, 1u, 1u};
1234+
if (auto Err = CUDAKernel.launchImpl(
1235+
*this, NumBlocksAndThreads, NumBlocksAndThreads, KernelArgs,
1236+
KernelLaunchParamsTy{}, AsyncInfoWrapper))
12371237
return Err;
12381238

12391239
Error Err = Plugin::success();
@@ -1276,7 +1276,7 @@ struct CUDADeviceTy : public GenericDeviceTy {
12761276
};
12771277

12781278
Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
1279-
uint32_t NumThreads, uint64_t NumBlocks,
1279+
uint32_t NumThreads[3], uint32_t NumBlocks[3],
12801280
KernelArgsTy &KernelArgs,
12811281
KernelLaunchParamsTy LaunchParams,
12821282
AsyncInfoWrapperTy &AsyncInfoWrapper) const {
@@ -1294,9 +1294,8 @@ Error CUDAKernelTy::launchImpl(GenericDeviceTy &GenericDevice,
12941294
reinterpret_cast<void *>(&LaunchParams.Size),
12951295
CU_LAUNCH_PARAM_END};
12961296

1297-
CUresult Res = cuLaunchKernel(Func, NumBlocks, /*gridDimY=*/1,
1298-
/*gridDimZ=*/1, NumThreads,
1299-
/*blockDimY=*/1, /*blockDimZ=*/1,
1297+
CUresult Res = cuLaunchKernel(Func, NumBlocks[0], NumBlocks[1], NumBlocks[2],
1298+
NumThreads[0], NumThreads[1], NumThreads[2],
13001299
MaxDynCGroupMem, Stream, nullptr, Config);
13011300
return Plugin::check(Res, "Error in cuLaunchKernel for '%s': %s", getName());
13021301
}

offload/plugins-nextgen/host/src/rtl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ struct GenELF64KernelTy : public GenericKernelTy {
8989
}
9090

9191
/// Launch the kernel using the libffi.
92-
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads,
93-
uint64_t NumBlocks, KernelArgsTy &KernelArgs,
92+
Error launchImpl(GenericDeviceTy &GenericDevice, uint32_t NumThreads[3],
93+
uint32_t NumBlocks[3], KernelArgsTy &KernelArgs,
9494
KernelLaunchParamsTy LaunchParams,
9595
AsyncInfoWrapperTy &AsyncInfoWrapper) const override {
9696
// Create a vector of ffi_types, one per argument.

offload/src/interface.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,11 @@ static KernelArgsTy *upgradeKernelArgs(KernelArgsTy *KernelArgs,
284284
LocalKernelArgs.Flags = KernelArgs->Flags;
285285
LocalKernelArgs.DynCGroupMem = 0;
286286
LocalKernelArgs.NumTeams[0] = NumTeams;
287-
LocalKernelArgs.NumTeams[1] = 0;
288-
LocalKernelArgs.NumTeams[2] = 0;
287+
LocalKernelArgs.NumTeams[1] = 1;
288+
LocalKernelArgs.NumTeams[2] = 1;
289289
LocalKernelArgs.ThreadLimit[0] = ThreadLimit;
290-
LocalKernelArgs.ThreadLimit[1] = 0;
291-
LocalKernelArgs.ThreadLimit[2] = 0;
290+
LocalKernelArgs.ThreadLimit[1] = 1;
291+
LocalKernelArgs.ThreadLimit[2] = 1;
292292
return &LocalKernelArgs;
293293
}
294294

@@ -320,12 +320,6 @@ static inline int targetKernel(ident_t *Loc, int64_t DeviceId, int32_t NumTeams,
320320
KernelArgs =
321321
upgradeKernelArgs(KernelArgs, LocalKernelArgs, NumTeams, ThreadLimit);
322322

323-
assert(KernelArgs->NumTeams[0] == static_cast<uint32_t>(NumTeams) &&
324-
!KernelArgs->NumTeams[1] && !KernelArgs->NumTeams[2] &&
325-
"OpenMP interface should not use multiple dimensions");
326-
assert(KernelArgs->ThreadLimit[0] == static_cast<uint32_t>(ThreadLimit) &&
327-
!KernelArgs->ThreadLimit[1] && !KernelArgs->ThreadLimit[2] &&
328-
"OpenMP interface should not use multiple dimensions");
329323
TIMESCOPE_WITH_DETAILS_AND_IDENT(
330324
"Runtime: target exe",
331325
"NumTeams=" + std::to_string(NumTeams) +

0 commit comments

Comments
 (0)