Skip to content

Commit 9ce18af

Browse files
committed
[HLSL] Don't use CreateRuntimeFunction for intrinsics
HLSL uses CreateRuntimeFunction for two intrinsics. This is pretty weird thing to do, and doesn't match what the rest of the file does. I suspect this might be because these are convergent calls, but the intrinsics themselves are already marked convergent, so it's not necessary for clang to manually add the attribute.
1 parent 879a557 commit 9ce18af

File tree

4 files changed

+21
-40
lines changed

4 files changed

+21
-40
lines changed

clang/lib/CodeGen/CGHLSLBuiltins.cpp

Lines changed: 9 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -676,35 +676,23 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
676676
case Builtin::BI__builtin_hlsl_wave_active_sum: {
677677
// Due to the use of variadic arguments, explicitly retreive argument
678678
Value *OpExpr = EmitScalarExpr(E->getArg(0));
679-
llvm::FunctionType *FT = llvm::FunctionType::get(
680-
OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
681679
Intrinsic::ID IID = getWaveActiveSumIntrinsic(
682680
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
683681
E->getArg(0)->getType());
684682

685-
// Get overloaded name
686-
std::string Name =
687-
Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
688-
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
689-
/*Local=*/false,
690-
/*AssumeConvergent=*/true),
683+
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
684+
&CGM.getModule(), IID, {OpExpr->getType()}),
691685
ArrayRef{OpExpr}, "hlsl.wave.active.sum");
692686
}
693687
case Builtin::BI__builtin_hlsl_wave_active_max: {
694688
// Due to the use of variadic arguments, explicitly retreive argument
695689
Value *OpExpr = EmitScalarExpr(E->getArg(0));
696-
llvm::FunctionType *FT = llvm::FunctionType::get(
697-
OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
698690
Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
699691
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
700692
E->getArg(0)->getType());
701693

702-
// Get overloaded name
703-
std::string Name =
704-
Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
705-
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
706-
/*Local=*/false,
707-
/*AssumeConvergent=*/true),
694+
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
695+
&CGM.getModule(), IID, {OpExpr->getType()}),
708696
ArrayRef{OpExpr}, "hlsl.wave.active.max");
709697
}
710698
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
@@ -739,18 +727,11 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
739727
// create our function type.
740728
Value *OpExpr = EmitScalarExpr(E->getArg(0));
741729
Value *OpIndex = EmitScalarExpr(E->getArg(1));
742-
llvm::FunctionType *FT = llvm::FunctionType::get(
743-
OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
744-
false);
745-
746-
// Get overloaded name
747-
std::string Name =
748-
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
749-
ArrayRef{OpExpr->getType()}, &CGM.getModule());
750-
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
751-
/*Local=*/false,
752-
/*AssumeConvergent=*/true),
753-
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
730+
return EmitRuntimeCall(
731+
Intrinsic::getOrInsertDeclaration(
732+
&CGM.getModule(), CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
733+
{OpExpr->getType()}),
734+
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
754735
}
755736
case Builtin::BI__builtin_hlsl_elementwise_sign: {
756737
auto *Arg0 = E->getArg(0);

clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int test_int(int expr) {
1616
}
1717

1818
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.max.i32([[TY]]) #[[#attr:]]
19-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.max.i32([[TY]]) #[[#attr:]]
19+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.max.i32([[TY]]) #[[#attr:]]
2020

2121
// CHECK-LABEL: test_uint64_t
2222
uint64_t test_uint64_t(uint64_t expr) {
@@ -27,7 +27,7 @@ uint64_t test_uint64_t(uint64_t expr) {
2727
}
2828

2929
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
30-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
30+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
3131

3232
// Test basic lowering to runtime function call with array and float value.
3333

@@ -40,7 +40,7 @@ float4 test_floatv4(float4 expr) {
4040
}
4141

4242
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
43-
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
43+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
4444

4545
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
4646

clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int test_int(int expr) {
1616
}
1717

1818
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.sum.i32([[TY]]) #[[#attr:]]
19-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.sum.i32([[TY]]) #[[#attr:]]
19+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.sum.i32([[TY]]) #[[#attr:]]
2020

2121
// CHECK-LABEL: test_uint64_t
2222
uint64_t test_uint64_t(uint64_t expr) {
@@ -27,7 +27,7 @@ uint64_t test_uint64_t(uint64_t expr) {
2727
}
2828

2929
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.usum.i64([[TY]]) #[[#attr:]]
30-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.sum.i64([[TY]]) #[[#attr:]]
30+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.sum.i64([[TY]]) #[[#attr:]]
3131

3232
// Test basic lowering to runtime function call with array and float value.
3333

@@ -40,6 +40,6 @@ float4 test_floatv4(float4 expr) {
4040
}
4141

4242
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.sum.v4f32([[TY1]]) #[[#attr]]
43-
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.reduce.sum.v4f32([[TY1]]) #[[#attr]]
43+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.reduce.sum.v4f32([[TY1]]) #[[#attr]]
4444

4545
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}

clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ int test_int(int expr, uint idx) {
1717
}
1818

1919
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i32([[TY]], i32) #[[#attr:]]
20-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.i32([[TY]], i32) #[[#attr:]]
20+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i32([[TY]], i32) #[[#attr:]]
2121

2222
// CHECK-LABEL: test_uint
2323
uint test_uint(uint expr, uint idx) {
@@ -38,7 +38,7 @@ int64_t test_int64_t(int64_t expr, uint idx) {
3838
}
3939

4040
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i64([[TY]], i32) #[[#attr:]]
41-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.i64([[TY]], i32) #[[#attr:]]
41+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i64([[TY]], i32) #[[#attr:]]
4242

4343
// CHECK-LABEL: test_uint64_t
4444
uint64_t test_uint64_t(uint64_t expr, uint idx) {
@@ -60,7 +60,7 @@ int16_t test_int16(int16_t expr, uint idx) {
6060
}
6161

6262
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i16([[TY]], i32) #[[#attr:]]
63-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.i16([[TY]], i32) #[[#attr:]]
63+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i16([[TY]], i32) #[[#attr:]]
6464

6565
// CHECK-LABEL: test_uint16
6666
uint16_t test_uint16(uint16_t expr, uint idx) {
@@ -84,7 +84,7 @@ half test_half(half expr, uint idx) {
8484
}
8585

8686
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.f16([[TY]], i32) #[[#attr:]]
87-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.f16([[TY]], i32) #[[#attr:]]
87+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.f16([[TY]], i32) #[[#attr:]]
8888

8989
// CHECK-LABEL: test_double
9090
double test_double(double expr, uint idx) {
@@ -96,7 +96,7 @@ double test_double(double expr, uint idx) {
9696
}
9797

9898
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.f64([[TY]], i32) #[[#attr:]]
99-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.readlane.f64([[TY]], i32) #[[#attr:]]
99+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.f64([[TY]], i32) #[[#attr:]]
100100

101101
// CHECK-LABEL: test_floatv4
102102
float4 test_floatv4(float4 expr, uint idx) {
@@ -108,6 +108,6 @@ float4 test_floatv4(float4 expr, uint idx) {
108108
}
109109

110110
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
111-
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
111+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
112112

113113
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}

0 commit comments

Comments
 (0)