Skip to content

Commit 8bcd336

Browse files
committed
Revert "[OpenMP][Offload] Fix envar for setting teams per cu"
This reverts commit 7ef5baf.
1 parent 67b946b commit 8bcd336

File tree

1 file changed

+13
-8
lines changed
  • offload/plugins-nextgen/amdgpu/src

1 file changed

+13
-8
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,7 +1100,12 @@ struct AMDGPUKernelTy : public GenericKernelTy {
11001100
if (GenericDevice.isFastReductionEnabled()) {
11011101
// When fast reduction is enabled, the number of teams is capped by
11021102
// the MaxCUMultiplier constant.
1103-
MaxNumGroups = DeviceNumCUs * llvm::omp::xteam_red::MaxCUMultiplier;
1103+
// When envar is enabled, use it for computing MaxNumGroup.
1104+
if (EnvarCUMultiplier > 0)
1105+
MaxNumGroups = DeviceNumCUs * EnvarCUMultiplier;
1106+
else
1107+
MaxNumGroups = DeviceNumCUs * llvm::omp::xteam_red::MaxCUMultiplier;
1108+
11041109
} else {
11051110
// When fast reduction is not enabled, the number of teams is capped
11061111
// by the metadata that clang CodeGen created. The number of teams
@@ -1111,7 +1116,13 @@ struct AMDGPUKernelTy : public GenericKernelTy {
11111116
// ConstWGSize is the block size that CodeGen used.
11121117
uint32_t CUMultiplier =
11131118
llvm::omp::xteam_red::getXteamRedCUMultiplier(ConstWGSize);
1114-
MaxNumGroups = DeviceNumCUs * CUMultiplier;
1119+
1120+
if (EnvarCUMultiplier > 0) {
1121+
MaxNumGroups =
1122+
DeviceNumCUs * std::min(CUMultiplier, EnvarCUMultiplier);
1123+
} else {
1124+
MaxNumGroups = DeviceNumCUs * CUMultiplier;
1125+
}
11151126
}
11161127

11171128
// If envar OMPX_XTEAMREDUCTION_OCCUPANCY_BASED_OPT is set and no
@@ -1166,12 +1177,6 @@ struct AMDGPUKernelTy : public GenericKernelTy {
11661177
}
11671178
NumGroups = DesiredNumGroups;
11681179
}
1169-
1170-
// Prefer OMPX_AdjustNumTeamsForXteamRedSmallBlockSize over
1171-
// OMPX_XTeamRedTeamsPerCU.
1172-
if (AdjustFactor == 0 && EnvarCUMultiplier > 0)
1173-
NumGroups = DeviceNumCUs * EnvarCUMultiplier;
1174-
11751180
NumGroups = std::min(NumGroups, MaxNumGroups);
11761181
NumGroups = std::min(NumGroups, NumGroupsFromTripCount);
11771182

0 commit comments

Comments
 (0)