Skip to content

Commit 51f90b7

Browse files
dhruvachak authored and ronlieb committed
[clang] [offload] [XteamReduction] Do not use blocksize any more to choose the DeviceRTL entrypoint.
Till now, for XteamReduction, the compiler would use the blocksize to choose the runtime entrypoint to call. For example, for a blocksize of 256, an entrypoint with the prefix 4x64 would be chosen for a wavefront size of 64. 4x64 indicates a maximum of 4 wavefronts, each of size 64. Unfortunately, this approach leads to a large number of entrypoints, and maintaining them is hard. This patch reduces the number of entrypoints manyfold by keeping support for only 1 prefix per wavefront size. Specifically, only the 16x64 and 32x32 prefixes will be supported; the rest are removed. This implies a maximum of 16 wavefronts for wavefront size 64 and a maximum of 32 wavefronts for wavefront size 32. Hence, these entrypoints are now independent of the blocksize known at compile-time. Note that users can still change the blocksize by using OpenMP clauses, compile-time options, or runtime environment variables --- none of that support is changing. The only downside of the new approach is that LDS will now be allocated by DeviceRTL based on the maximum number of wavefronts, either 16 or 32. When the default XteamReduction blocksize of 1024 is used, the LDS allocation is still optimal. But if the blocksize is reduced by one of the ways mentioned above, there will be LDS over-allocation. However, given that the absolute size of this LDS usage is small, the over-allocation should not be a problem. The total number of entrypoints is reduced from 264 to 48.
Change-Id: I06b0232a3103180c0102191a47739ffd042f4508
1 parent 76c0736 commit 51f90b7

File tree

7 files changed

+166
-2710
lines changed

7 files changed

+166
-2710
lines changed

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 38 additions & 300 deletions
Original file line number | Diff line number | Diff line change
@@ -3904,325 +3904,63 @@ llvm::Value *CGOpenMPRuntimeGPU::getXteamRedSum(
39043904
if (SumType->isIntegerTy()) {
39053905
if (SumType->getPrimitiveSizeInBits() == 32) {
39063906
if (WarpSize == 32) {
3907-
switch (BlockSize) {
3908-
default:
3909-
return CGF.EmitRuntimeCall(
3910-
OMPBuilder.getOrCreateRuntimeFunction(
3911-
CGM.getModule(), IsFast
3912-
? OMPRTL___kmpc_xteamr_ui_1x32_fast_sum
3913-
: OMPRTL___kmpc_xteamr_ui_1x32),
3914-
Args);
3915-
case 64:
3916-
return CGF.EmitRuntimeCall(
3917-
OMPBuilder.getOrCreateRuntimeFunction(
3918-
CGM.getModule(), IsFast
3919-
? OMPRTL___kmpc_xteamr_ui_2x32_fast_sum
3920-
: OMPRTL___kmpc_xteamr_ui_2x32),
3921-
Args);
3922-
case 128:
3923-
return CGF.EmitRuntimeCall(
3924-
OMPBuilder.getOrCreateRuntimeFunction(
3925-
CGM.getModule(), IsFast
3926-
? OMPRTL___kmpc_xteamr_ui_4x32_fast_sum
3927-
: OMPRTL___kmpc_xteamr_ui_4x32),
3928-
Args);
3929-
case 256:
3930-
return CGF.EmitRuntimeCall(
3931-
OMPBuilder.getOrCreateRuntimeFunction(
3932-
CGM.getModule(), IsFast
3933-
? OMPRTL___kmpc_xteamr_ui_8x32_fast_sum
3934-
: OMPRTL___kmpc_xteamr_ui_8x32),
3935-
Args);
3936-
case 512:
3937-
return CGF.EmitRuntimeCall(
3938-
OMPBuilder.getOrCreateRuntimeFunction(
3939-
CGM.getModule(), IsFast
3940-
? OMPRTL___kmpc_xteamr_ui_16x32_fast_sum
3941-
: OMPRTL___kmpc_xteamr_ui_16x32),
3942-
Args);
3943-
case 1024:
3944-
return CGF.EmitRuntimeCall(
3945-
OMPBuilder.getOrCreateRuntimeFunction(
3946-
CGM.getModule(), IsFast
3947-
? OMPRTL___kmpc_xteamr_ui_32x32_fast_sum
3948-
: OMPRTL___kmpc_xteamr_ui_32x32),
3949-
Args);
3950-
}
3951-
} else {
3952-
switch (BlockSize) {
3953-
default:
3954-
return CGF.EmitRuntimeCall(
3955-
OMPBuilder.getOrCreateRuntimeFunction(
3956-
CGM.getModule(), IsFast
3957-
? OMPRTL___kmpc_xteamr_ui_1x64_fast_sum
3958-
: OMPRTL___kmpc_xteamr_ui_1x64),
3959-
Args);
3960-
case 128:
3961-
return CGF.EmitRuntimeCall(
3962-
OMPBuilder.getOrCreateRuntimeFunction(
3963-
CGM.getModule(), IsFast
3964-
? OMPRTL___kmpc_xteamr_ui_2x64_fast_sum
3965-
: OMPRTL___kmpc_xteamr_ui_2x64),
3966-
Args);
3967-
case 256:
3968-
return CGF.EmitRuntimeCall(
3969-
OMPBuilder.getOrCreateRuntimeFunction(
3970-
CGM.getModule(), IsFast
3971-
? OMPRTL___kmpc_xteamr_ui_4x64_fast_sum
3972-
: OMPRTL___kmpc_xteamr_ui_4x64),
3973-
Args);
3974-
case 512:
3975-
return CGF.EmitRuntimeCall(
3976-
OMPBuilder.getOrCreateRuntimeFunction(
3977-
CGM.getModule(), IsFast
3978-
? OMPRTL___kmpc_xteamr_ui_8x64_fast_sum
3979-
: OMPRTL___kmpc_xteamr_ui_8x64),
3980-
Args);
3981-
case 1024:
3982-
return CGF.EmitRuntimeCall(
3983-
OMPBuilder.getOrCreateRuntimeFunction(
3984-
CGM.getModule(), IsFast
3985-
? OMPRTL___kmpc_xteamr_ui_16x64_fast_sum
3986-
: OMPRTL___kmpc_xteamr_ui_16x64),
3987-
Args);
3988-
}
3989-
}
3990-
}
3991-
if (SumType->getPrimitiveSizeInBits() == 64) {
3992-
if (WarpSize == 32) {
3993-
switch (BlockSize) {
3994-
default:
3995-
return CGF.EmitRuntimeCall(
3996-
OMPBuilder.getOrCreateRuntimeFunction(
3997-
CGM.getModule(), IsFast
3998-
? OMPRTL___kmpc_xteamr_ul_1x32_fast_sum
3999-
: OMPRTL___kmpc_xteamr_ul_1x32),
4000-
Args);
4001-
case 64:
4002-
return CGF.EmitRuntimeCall(
4003-
OMPBuilder.getOrCreateRuntimeFunction(
4004-
CGM.getModule(), IsFast
4005-
? OMPRTL___kmpc_xteamr_ul_2x32_fast_sum
4006-
: OMPRTL___kmpc_xteamr_ul_2x32),
4007-
Args);
4008-
case 128:
4009-
return CGF.EmitRuntimeCall(
4010-
OMPBuilder.getOrCreateRuntimeFunction(
4011-
CGM.getModule(), IsFast
4012-
? OMPRTL___kmpc_xteamr_ul_4x32_fast_sum
4013-
: OMPRTL___kmpc_xteamr_ul_4x32),
4014-
Args);
4015-
case 256:
4016-
return CGF.EmitRuntimeCall(
4017-
OMPBuilder.getOrCreateRuntimeFunction(
4018-
CGM.getModule(), IsFast
4019-
? OMPRTL___kmpc_xteamr_ul_8x32_fast_sum
4020-
: OMPRTL___kmpc_xteamr_ul_8x32),
4021-
Args);
4022-
case 512:
4023-
return CGF.EmitRuntimeCall(
4024-
OMPBuilder.getOrCreateRuntimeFunction(
4025-
CGM.getModule(), IsFast
4026-
? OMPRTL___kmpc_xteamr_ul_16x32_fast_sum
4027-
: OMPRTL___kmpc_xteamr_ul_16x32),
4028-
Args);
4029-
case 1024:
4030-
return CGF.EmitRuntimeCall(
4031-
OMPBuilder.getOrCreateRuntimeFunction(
4032-
CGM.getModule(), IsFast
4033-
? OMPRTL___kmpc_xteamr_ul_32x32_fast_sum
4034-
: OMPRTL___kmpc_xteamr_ul_32x32),
4035-
Args);
4036-
}
4037-
} else {
4038-
switch (BlockSize) {
4039-
default:
4040-
return CGF.EmitRuntimeCall(
4041-
OMPBuilder.getOrCreateRuntimeFunction(
4042-
CGM.getModule(), IsFast
4043-
? OMPRTL___kmpc_xteamr_ul_1x64_fast_sum
4044-
: OMPRTL___kmpc_xteamr_ul_1x64),
4045-
Args);
4046-
case 128:
4047-
return CGF.EmitRuntimeCall(
4048-
OMPBuilder.getOrCreateRuntimeFunction(
4049-
CGM.getModule(), IsFast
4050-
? OMPRTL___kmpc_xteamr_ul_2x64_fast_sum
4051-
: OMPRTL___kmpc_xteamr_ul_2x64),
4052-
Args);
4053-
case 256:
4054-
return CGF.EmitRuntimeCall(
4055-
OMPBuilder.getOrCreateRuntimeFunction(
4056-
CGM.getModule(), IsFast
4057-
? OMPRTL___kmpc_xteamr_ul_4x64_fast_sum
4058-
: OMPRTL___kmpc_xteamr_ul_4x64),
4059-
Args);
4060-
case 512:
4061-
return CGF.EmitRuntimeCall(
4062-
OMPBuilder.getOrCreateRuntimeFunction(
4063-
CGM.getModule(), IsFast
4064-
? OMPRTL___kmpc_xteamr_ul_8x64_fast_sum
4065-
: OMPRTL___kmpc_xteamr_ul_8x64),
4066-
Args);
4067-
case 1024:
4068-
return CGF.EmitRuntimeCall(
4069-
OMPBuilder.getOrCreateRuntimeFunction(
4070-
CGM.getModule(), IsFast
4071-
? OMPRTL___kmpc_xteamr_ul_16x64_fast_sum
4072-
: OMPRTL___kmpc_xteamr_ul_16x64),
4073-
Args);
4074-
}
4075-
}
4076-
}
4077-
}
4078-
if (SumType->isFloatTy()) {
4079-
if (WarpSize == 32) {
4080-
switch (BlockSize) {
4081-
default:
4082-
return CGF.EmitRuntimeCall(
4083-
OMPBuilder.getOrCreateRuntimeFunction(
4084-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_1x32_fast_sum
4085-
: OMPRTL___kmpc_xteamr_f_1x32),
4086-
Args);
4087-
case 64:
4088-
return CGF.EmitRuntimeCall(
4089-
OMPBuilder.getOrCreateRuntimeFunction(
4090-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_2x32_fast_sum
4091-
: OMPRTL___kmpc_xteamr_f_2x32),
4092-
Args);
4093-
case 128:
4094-
return CGF.EmitRuntimeCall(
4095-
OMPBuilder.getOrCreateRuntimeFunction(
4096-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_4x32_fast_sum
4097-
: OMPRTL___kmpc_xteamr_f_4x32),
4098-
Args);
4099-
case 256:
4100-
return CGF.EmitRuntimeCall(
4101-
OMPBuilder.getOrCreateRuntimeFunction(
4102-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_8x32_fast_sum
4103-
: OMPRTL___kmpc_xteamr_f_8x32),
4104-
Args);
4105-
case 512:
41063907
return CGF.EmitRuntimeCall(
41073908
OMPBuilder.getOrCreateRuntimeFunction(
4108-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_16x32_fast_sum
4109-
: OMPRTL___kmpc_xteamr_f_16x32),
3909+
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_ui_32x32_fast_sum
3910+
: OMPRTL___kmpc_xteamr_ui_32x32),
41103911
Args);
4111-
case 1024:
3912+
} else {
41123913
return CGF.EmitRuntimeCall(
41133914
OMPBuilder.getOrCreateRuntimeFunction(
4114-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_32x32_fast_sum
4115-
: OMPRTL___kmpc_xteamr_f_32x32),
3915+
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_ui_16x64_fast_sum
3916+
: OMPRTL___kmpc_xteamr_ui_16x64),
41163917
Args);
41173918
}
4118-
} else {
4119-
switch (BlockSize) {
4120-
default:
4121-
return CGF.EmitRuntimeCall(
4122-
OMPBuilder.getOrCreateRuntimeFunction(
4123-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_1x64_fast_sum
4124-
: OMPRTL___kmpc_xteamr_f_1x64),
4125-
Args);
4126-
case 128:
4127-
return CGF.EmitRuntimeCall(
4128-
OMPBuilder.getOrCreateRuntimeFunction(
4129-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_2x64_fast_sum
4130-
: OMPRTL___kmpc_xteamr_f_2x64),
4131-
Args);
4132-
case 256:
4133-
return CGF.EmitRuntimeCall(
4134-
OMPBuilder.getOrCreateRuntimeFunction(
4135-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_4x64_fast_sum
4136-
: OMPRTL___kmpc_xteamr_f_4x64),
4137-
Args);
4138-
case 512:
3919+
}
3920+
if (SumType->getPrimitiveSizeInBits() == 64) {
3921+
if (WarpSize == 32) {
41393922
return CGF.EmitRuntimeCall(
41403923
OMPBuilder.getOrCreateRuntimeFunction(
4141-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_8x64_fast_sum
4142-
: OMPRTL___kmpc_xteamr_f_8x64),
3924+
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_ul_32x32_fast_sum
3925+
: OMPRTL___kmpc_xteamr_ul_32x32),
41433926
Args);
4144-
case 1024:
3927+
} else {
41453928
return CGF.EmitRuntimeCall(
41463929
OMPBuilder.getOrCreateRuntimeFunction(
4147-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_16x64_fast_sum
4148-
: OMPRTL___kmpc_xteamr_f_16x64),
3930+
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_ul_16x64_fast_sum
3931+
: OMPRTL___kmpc_xteamr_ul_16x64),
41493932
Args);
41503933
}
41513934
}
41523935
}
3936+
if (SumType->isFloatTy()) {
3937+
if (WarpSize == 32) {
3938+
return CGF.EmitRuntimeCall(
3939+
OMPBuilder.getOrCreateRuntimeFunction(
3940+
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_32x32_fast_sum
3941+
: OMPRTL___kmpc_xteamr_f_32x32),
3942+
Args);
3943+
} else {
3944+
return CGF.EmitRuntimeCall(
3945+
OMPBuilder.getOrCreateRuntimeFunction(
3946+
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_f_16x64_fast_sum
3947+
: OMPRTL___kmpc_xteamr_f_16x64),
3948+
Args);
3949+
}
3950+
}
41533951
if (SumType->isDoubleTy()) {
41543952
if (WarpSize == 32) {
4155-
switch (BlockSize) {
4156-
default:
4157-
return CGF.EmitRuntimeCall(
4158-
OMPBuilder.getOrCreateRuntimeFunction(
4159-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_1x32_fast_sum
4160-
: OMPRTL___kmpc_xteamr_d_1x32),
4161-
Args);
4162-
case 64:
4163-
return CGF.EmitRuntimeCall(
4164-
OMPBuilder.getOrCreateRuntimeFunction(
4165-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_2x32_fast_sum
4166-
: OMPRTL___kmpc_xteamr_d_2x32),
4167-
Args);
4168-
case 128:
4169-
return CGF.EmitRuntimeCall(
4170-
OMPBuilder.getOrCreateRuntimeFunction(
4171-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_4x32_fast_sum
4172-
: OMPRTL___kmpc_xteamr_d_4x32),
4173-
Args);
4174-
case 256:
4175-
return CGF.EmitRuntimeCall(
4176-
OMPBuilder.getOrCreateRuntimeFunction(
4177-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_8x32_fast_sum
4178-
: OMPRTL___kmpc_xteamr_d_8x32),
4179-
Args);
4180-
case 512:
4181-
return CGF.EmitRuntimeCall(
4182-
OMPBuilder.getOrCreateRuntimeFunction(
4183-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_16x32_fast_sum
4184-
: OMPRTL___kmpc_xteamr_d_16x32),
4185-
Args);
4186-
case 1024:
4187-
return CGF.EmitRuntimeCall(
4188-
OMPBuilder.getOrCreateRuntimeFunction(
4189-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_32x32_fast_sum
4190-
: OMPRTL___kmpc_xteamr_d_32x32),
4191-
Args);
4192-
}
3953+
return CGF.EmitRuntimeCall(
3954+
OMPBuilder.getOrCreateRuntimeFunction(
3955+
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_32x32_fast_sum
3956+
: OMPRTL___kmpc_xteamr_d_32x32),
3957+
Args);
41933958
} else {
4194-
switch (BlockSize) {
4195-
default:
4196-
return CGF.EmitRuntimeCall(
4197-
OMPBuilder.getOrCreateRuntimeFunction(
4198-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_1x64_fast_sum
4199-
: OMPRTL___kmpc_xteamr_d_1x64),
4200-
Args);
4201-
case 128:
4202-
return CGF.EmitRuntimeCall(
4203-
OMPBuilder.getOrCreateRuntimeFunction(
4204-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_2x64_fast_sum
4205-
: OMPRTL___kmpc_xteamr_d_2x64),
4206-
Args);
4207-
case 256:
4208-
return CGF.EmitRuntimeCall(
4209-
OMPBuilder.getOrCreateRuntimeFunction(
4210-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_4x64_fast_sum
4211-
: OMPRTL___kmpc_xteamr_d_4x64),
4212-
Args);
4213-
case 512:
4214-
return CGF.EmitRuntimeCall(
4215-
OMPBuilder.getOrCreateRuntimeFunction(
4216-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_8x64_fast_sum
4217-
: OMPRTL___kmpc_xteamr_d_8x64),
4218-
Args);
4219-
case 1024:
4220-
return CGF.EmitRuntimeCall(
4221-
OMPBuilder.getOrCreateRuntimeFunction(
4222-
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_16x64_fast_sum
4223-
: OMPRTL___kmpc_xteamr_d_16x64),
4224-
Args);
4225-
}
3959+
return CGF.EmitRuntimeCall(
3960+
OMPBuilder.getOrCreateRuntimeFunction(
3961+
CGM.getModule(), IsFast ? OMPRTL___kmpc_xteamr_d_16x64_fast_sum
3962+
: OMPRTL___kmpc_xteamr_d_16x64),
3963+
Args);
42263964
}
42273965
}
42283966
llvm_unreachable("No support for other types currently.");

clang/test/OpenMP/fast_red_codegen.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1128,7 +1128,7 @@ int main()
11281128
// CHECK-NEXT: [[TMP29:%.*]] = load ptr, ptr [[DOTADDR_ASCAST]], align 8
11291129
// CHECK-NEXT: [[TMP30:%.*]] = load ptr, ptr [[DOTADDR1_ASCAST]], align 8
11301130
// CHECK-NEXT: [[TMP31:%.*]] = load double, ptr addrspace(5) [[TMP5]], align 8
1131-
// CHECK-NEXT: call void @__kmpc_xteamr_d_8x64_fast_sum(double [[TMP31]], ptr [[TMP2]], ptr [[TMP29]], ptr [[TMP30]], ptr @__kmpc_rfun_sum_d, ptr @__kmpc_rfun_sum_lds_d, double 0.000000e+00, i64 [[TMP17]], i32 [[TMP16]])
1131+
// CHECK-NEXT: call void @__kmpc_xteamr_d_16x64_fast_sum(double [[TMP31]], ptr [[TMP2]], ptr [[TMP29]], ptr [[TMP30]], ptr @__kmpc_rfun_sum_d, ptr @__kmpc_rfun_sum_lds_d, double 0.000000e+00, i64 [[TMP17]], i32 [[TMP16]])
11321132
// CHECK-NEXT: ret void
11331133
//
11341134
//

clang/test/OpenMP/xteam_red_codegen.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1128,7 +1128,7 @@ int main()
11281128
// CHECK-NEXT: [[TMP29:%.*]] = load ptr, ptr [[DOTADDR_ASCAST]], align 8
11291129
// CHECK-NEXT: [[TMP30:%.*]] = load ptr, ptr [[DOTADDR1_ASCAST]], align 8
11301130
// CHECK-NEXT: [[TMP31:%.*]] = load double, ptr addrspace(5) [[TMP5]], align 8
1131-
// CHECK-NEXT: call void @__kmpc_xteamr_d_8x64(double [[TMP31]], ptr [[TMP2]], ptr [[TMP29]], ptr [[TMP30]], ptr @__kmpc_rfun_sum_d, ptr @__kmpc_rfun_sum_lds_d, double 0.000000e+00, i64 [[TMP17]], i32 [[TMP16]])
1131+
// CHECK-NEXT: call void @__kmpc_xteamr_d_16x64(double [[TMP31]], ptr [[TMP2]], ptr [[TMP29]], ptr [[TMP30]], ptr @__kmpc_rfun_sum_d, ptr @__kmpc_rfun_sum_lds_d, double 0.000000e+00, i64 [[TMP17]], i32 [[TMP16]])
11321132
// CHECK-NEXT: ret void
11331133
//
11341134
//

0 commit comments

Comments
 (0)