Skip to content

Commit 4eb990b

Browse files
committed
[HLSL] Don't use CreateRuntimeFunction for intrinsics
HLSL uses CreateRuntimeFunction for two intrinsics. This is pretty weird thing to do, and doesn't match what the rest of the file does. I suspect this might be because these are convergent calls, but the intrinsics themselves are already marked convergent, so it's not necessary for clang to manually add the attribute.
1 parent 879a557 commit 4eb990b

File tree

3 files changed

+10
-22
lines changed

3 files changed

+10
-22
lines changed

clang/lib/CodeGen/CGHLSLBuiltins.cpp

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -676,35 +676,23 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
676676
case Builtin::BI__builtin_hlsl_wave_active_sum: {
677677
// Due to the use of variadic arguments, explicitly retreive argument
678678
Value *OpExpr = EmitScalarExpr(E->getArg(0));
679-
llvm::FunctionType *FT = llvm::FunctionType::get(
680-
OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
681679
Intrinsic::ID IID = getWaveActiveSumIntrinsic(
682680
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
683681
E->getArg(0)->getType());
684682

685-
// Get overloaded name
686-
std::string Name =
687-
Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
688-
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
689-
/*Local=*/false,
690-
/*AssumeConvergent=*/true),
683+
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
684+
&CGM.getModule(), IID, {OpExpr->getType()}),
691685
ArrayRef{OpExpr}, "hlsl.wave.active.sum");
692686
}
693687
case Builtin::BI__builtin_hlsl_wave_active_max: {
694688
// Due to the use of variadic arguments, explicitly retreive argument
695689
Value *OpExpr = EmitScalarExpr(E->getArg(0));
696-
llvm::FunctionType *FT = llvm::FunctionType::get(
697-
OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
698690
Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
699691
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
700692
E->getArg(0)->getType());
701693

702-
// Get overloaded name
703-
std::string Name =
704-
Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
705-
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
706-
/*Local=*/false,
707-
/*AssumeConvergent=*/true),
694+
return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
695+
&CGM.getModule(), IID, {OpExpr->getType()}),
708696
ArrayRef{OpExpr}, "hlsl.wave.active.max");
709697
}
710698
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {

clang/test/CodeGenHLSL/builtins/WaveActiveMax.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int test_int(int expr) {
1616
}
1717

1818
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.max.i32([[TY]]) #[[#attr:]]
19-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.max.i32([[TY]]) #[[#attr:]]
19+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.max.i32([[TY]]) #[[#attr:]]
2020

2121
// CHECK-LABEL: test_uint64_t
2222
uint64_t test_uint64_t(uint64_t expr) {
@@ -27,7 +27,7 @@ uint64_t test_uint64_t(uint64_t expr) {
2727
}
2828

2929
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
30-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
30+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.umax.i64([[TY]]) #[[#attr:]]
3131

3232
// Test basic lowering to runtime function call with array and float value.
3333

@@ -40,7 +40,7 @@ float4 test_floatv4(float4 expr) {
4040
}
4141

4242
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
43-
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
43+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.reduce.max.v4f32([[TY1]]) #[[#attr]]
4444

4545
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
4646

clang/test/CodeGenHLSL/builtins/WaveActiveSum.hlsl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ int test_int(int expr) {
1616
}
1717

1818
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.sum.i32([[TY]]) #[[#attr:]]
19-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.sum.i32([[TY]]) #[[#attr:]]
19+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.sum.i32([[TY]]) #[[#attr:]]
2020

2121
// CHECK-LABEL: test_uint64_t
2222
uint64_t test_uint64_t(uint64_t expr) {
@@ -27,7 +27,7 @@ uint64_t test_uint64_t(uint64_t expr) {
2727
}
2828

2929
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.usum.i64([[TY]]) #[[#attr:]]
30-
// CHECK-SPIRV: declare spir_func [[TY]] @llvm.spv.wave.reduce.sum.i64([[TY]]) #[[#attr:]]
30+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.sum.i64([[TY]]) #[[#attr:]]
3131

3232
// Test basic lowering to runtime function call with array and float value.
3333

@@ -40,6 +40,6 @@ float4 test_floatv4(float4 expr) {
4040
}
4141

4242
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.reduce.sum.v4f32([[TY1]]) #[[#attr]]
43-
// CHECK-SPIRV: declare spir_func [[TY1]] @llvm.spv.wave.reduce.sum.v4f32([[TY1]]) #[[#attr]]
43+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.reduce.sum.v4f32([[TY1]]) #[[#attr]]
4444

4545
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}

0 commit comments

Comments
 (0)