Skip to content

Commit a56f802

Browse files
authored
[OpenMP][Offload] Resubmit - fix envar for setting teams per CU (llvm#1399)
2 parents f04d8d9 + f70ef61 commit a56f802

File tree

1 file changed

+8
-13
lines changed
  • offload/plugins-nextgen/amdgpu/src

1 file changed

+8
-13
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1100,12 +1100,7 @@ struct AMDGPUKernelTy : public GenericKernelTy {
11001100
if (GenericDevice.isFastReductionEnabled()) {
11011101
// When fast reduction is enabled, the number of teams is capped by
11021102
// the MaxCUMultiplier constant.
1103-
// When envar is enabled, use it for computing MaxNumGroup.
1104-
if (EnvarCUMultiplier > 0)
1105-
MaxNumGroups = DeviceNumCUs * EnvarCUMultiplier;
1106-
else
1107-
MaxNumGroups = DeviceNumCUs * llvm::omp::xteam_red::MaxCUMultiplier;
1108-
1103+
MaxNumGroups = DeviceNumCUs * llvm::omp::xteam_red::MaxCUMultiplier;
11091104
} else {
11101105
// When fast reduction is not enabled, the number of teams is capped
11111106
// by the metadata that clang CodeGen created. The number of teams
@@ -1116,13 +1111,7 @@ struct AMDGPUKernelTy : public GenericKernelTy {
11161111
// ConstWGSize is the block size that CodeGen used.
11171112
uint32_t CUMultiplier =
11181113
llvm::omp::xteam_red::getXteamRedCUMultiplier(ConstWGSize);
1119-
1120-
if (EnvarCUMultiplier > 0) {
1121-
MaxNumGroups =
1122-
DeviceNumCUs * std::min(CUMultiplier, EnvarCUMultiplier);
1123-
} else {
1124-
MaxNumGroups = DeviceNumCUs * CUMultiplier;
1125-
}
1114+
MaxNumGroups = DeviceNumCUs * CUMultiplier;
11261115
}
11271116

11281117
// If envar OMPX_XTEAMREDUCTION_OCCUPANCY_BASED_OPT is set and no
@@ -1177,6 +1166,12 @@ struct AMDGPUKernelTy : public GenericKernelTy {
11771166
}
11781167
NumGroups = DesiredNumGroups;
11791168
}
1169+
1170+
// Prefer OMPX_AdjustNumTeamsForXteamRedSmallBlockSize over
1171+
// OMPX_XTeamRedTeamsPerCU.
1172+
if (AdjustFactor == 0 && EnvarCUMultiplier > 0)
1173+
NumGroups = DeviceNumCUs * EnvarCUMultiplier;
1174+
11801175
NumGroups = std::min(NumGroups, MaxNumGroups);
11811176
NumGroups = std::min(NumGroups, NumGroupsFromTripCount);
11821177

0 commit comments

Comments
 (0)