Skip to content

Commit 1128a4f

Browse files
authored
[HLSL] Don't use CreateRuntimeFunction for intrinsics (#145334)
HLSL uses CreateRuntimeFunction for three intrinsics. This is pretty unusual thing to do, and doesn't match what the rest of the file does. I suspect this might be because these are convergent calls, but the intrinsics themselves are already marked convergent, so it's not necessary for clang to manually add the attribute. This does lose the spir_func CC on the intrinsic declaration, but again, CC should not be relevant to intrinsics at all.
1 parent 9f7567d commit 1128a4f

File tree

4 files changed

+21
-40
lines changed

4 files changed

+21
-40
lines changed

clang/lib/CodeGen/CGHLSLBuiltins.cpp

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -676,35 +676,23 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
676676
case Builtin::BI__builtin_hlsl_wave_active_sum: {
677677
// Due to the use of variadic arguments, explicitly retreive argument
678678
Value *OpExpr = EmitScalarExpr(E->getArg(0));
679-
llvm::FunctionType *FT = llvm::FunctionType::get(
680-
OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
681679
Intrinsic::ID IID = getWaveActiveSumIntrinsic(
682680
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
683681
E->getArg(0)->getType());
684682

685-
// Get overloaded name
686-
std::string Name =
687-
Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
688-
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
689-
/*Local=*/false,
690-
/*AssumeConvergent=*/true),
683+
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
684+
&CGM.getModule(), IID, {OpExpr->getType()}),
691685
ArrayRef{OpExpr}, "hlsl.wave.active.sum");
692686
}
693687
case Builtin::BI__builtin_hlsl_wave_active_max: {
694688
// Due to the use of variadic arguments, explicitly retreive argument
695689
Value *OpExpr = EmitScalarExpr(E->getArg(0));
696-
llvm::FunctionType *FT = llvm::FunctionType::get(
697-
OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
698690
Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
699691
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
700692
E->getArg(0)->getType());
701693

702-
// Get overloaded name
703-
std::string Name =
704-
Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
705-
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
706-
/*Local=*/false,
707-
/*AssumeConvergent=*/true),
694+
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
695+
&CGM.getModule(), IID, {OpExpr->getType()}),
708696
ArrayRef{OpExpr}, "hlsl.wave.active.max");
709697
}
710698
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
@@ -739,18 +727,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
739727
// create our function type.
740728
Value *OpExpr = EmitScalarExpr(E->getArg(0));
741729
Value *OpIndex = EmitScalarExpr(E->getArg(1));
742-
llvm::FunctionType *FT = llvm::FunctionType::get(
743-
OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
744-
false);
745-
746-
// Get overloaded name
747-
std::string Name =
748-
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
749-
ArrayRef{OpExpr->getType()}, &CGM.getModule());
750-
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
751-
/*Local=*/false,
752-
/*AssumeConvergent=*/true),
753-
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
730+
return EmitRuntimeCall(
731+
Intrinsic::getOrInsertDeclaration(
732+
&CGM.getModule(), CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
733+
{OpExpr->getType()}),
734+
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
754735
}
755736
case Builtin::BI__builtin_hlsl_elementwise_sign: {
756737
auto *Arg0 = E->getArg(0);

clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int test_int(int expr) {
1616
}
1717

1818
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.max.i32([[TY]]) #[[#attr:]]
19-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.max.i32([[TY]]) #[[#attr:]]
19+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.max.i32([[TY]]) #[[#attr:]]
2020

2121
// CHECK-LABEL: test_uint64_t
2222
uint64_t test_uint64_t(uint64_t expr) {
@@ -27,7 +27,7 @@ uint64_t test_uint64_t(uint64_t expr) {
2727
}
2828

2929
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
30-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
30+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
3131

3232
// Test basic lowering to runtime function call with array and float value.
3333

@@ -40,7 +40,7 @@ float4 test_floatv4(float4 expr) {
4040
}
4141

4242
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
43-
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
43+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
4444

4545
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
4646

clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int test_int(int expr) {
1616
}
1717

1818
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.sum.i32([[TY]]) #[[#attr:]]
19-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.sum.i32([[TY]]) #[[#attr:]]
19+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.sum.i32([[TY]]) #[[#attr:]]
2020

2121
// CHECK-LABEL: test_uint64_t
2222
uint64_t test_uint64_t(uint64_t expr) {
@@ -27,7 +27,7 @@ uint64_t test_uint64_t(uint64_t expr) {
2727
}
2828

2929
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.usum.i64([[TY]]) #[[#attr:]]
30-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.sum.i64([[TY]]) #[[#attr:]]
30+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.sum.i64([[TY]]) #[[#attr:]]
3131

3232
// Test basic lowering to runtime function call with array and float value.
3333

@@ -40,6 +40,6 @@ float4 test_floatv4(float4 expr) {
4040
}
4141

4242
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.sum.v4f32([[TY1]]) #[[#attr]]
43-
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.reduce.sum.v4f32([[TY1]]) #[[#attr]]
43+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.reduce.sum.v4f32([[TY1]]) #[[#attr]]
4444

4545
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}

clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ int test_int(int expr, uint idx) {
1717
}
1818

1919
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i32([[TY]], i32) #[[#attr:]]
20-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.i32([[TY]], i32) #[[#attr:]]
20+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i32([[TY]], i32) #[[#attr:]]
2121

2222
// CHECK-LABEL: test_uint
2323
uint test_uint(uint expr, uint idx) {
@@ -38,7 +38,7 @@ int64_t test_int64_t(int64_t expr, uint idx) {
3838
}
3939

4040
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i64([[TY]], i32) #[[#attr:]]
41-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.i64([[TY]], i32) #[[#attr:]]
41+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i64([[TY]], i32) #[[#attr:]]
4242

4343
// CHECK-LABEL: test_uint64_t
4444
uint64_t test_uint64_t(uint64_t expr, uint idx) {
@@ -60,7 +60,7 @@ int16_t test_int16(int16_t expr, uint idx) {
6060
}
6161

6262
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i16([[TY]], i32) #[[#attr:]]
63-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.i16([[TY]], i32) #[[#attr:]]
63+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i16([[TY]], i32) #[[#attr:]]
6464

6565
// CHECK-LABEL: test_uint16
6666
uint16_t test_uint16(uint16_t expr, uint idx) {
@@ -84,7 +84,7 @@ half test_half(half expr, uint idx) {
8484
}
8585

8686
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.f16([[TY]], i32) #[[#attr:]]
87-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.f16([[TY]], i32) #[[#attr:]]
87+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.f16([[TY]], i32) #[[#attr:]]
8888

8989
// CHECK-LABEL: test_double
9090
double test_double(double expr, uint idx) {
@@ -96,7 +96,7 @@ double test_double(double expr, uint idx) {
9696
}
9797

9898
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.f64([[TY]], i32) #[[#attr:]]
99-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.f64([[TY]], i32) #[[#attr:]]
99+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.f64([[TY]], i32) #[[#attr:]]
100100

101101
// CHECK-LABEL: test_floatv4
102102
float4 test_floatv4(float4 expr, uint idx) {
@@ -108,6 +108,6 @@ float4 test_floatv4(float4 expr, uint idx) {
108108
}
109109

110110
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
111-
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
111+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
112112

113113
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}

0 commit comments

Comments
 (0)