Skip to content

Commit c541955

Browse files
committed
[HLSL] Implement WaveActiveSum intrinsic
- add clang builtin to Builtins.td - link builtin in hlsl_intrinsics - add codegen for spirv intrinsic and two directx intrinsics to retain signedness information of the operands in CGBuiltin.cpp - add semantic analysis in SemaHLSL.cpp - add lowering of spirv intrinsic to spirv backend in SPIRVInstructionSelector.cpp - add directx intrinsic expansion to WaveActiveOp in DXILIntrinsicExpansion.cpp - add test cases to illustrate passes
1 parent 23309d7 commit c541955

File tree

13 files changed

+427
-0
lines changed

13 files changed

+427
-0
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4721,6 +4721,12 @@ def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
47214721
let Prototype = "unsigned int(bool)";
47224722
}
47234723

4724+
def HLSLWaveActiveSum : LangBuiltin<"HLSL_LANG"> {
4725+
let Spellings = ["__builtin_hlsl_wave_active_sum"];
4726+
let Attributes = [NoThrow, Const];
4727+
let Prototype = "void (...)";
4728+
}
4729+
47244730
def HLSLWaveGetLaneIndex : LangBuiltin<"HLSL_LANG"> {
47254731
let Spellings = ["__builtin_hlsl_wave_get_lane_index"];
47264732
let Attributes = [NoThrow, Const];

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9232,6 +9232,9 @@ def err_typecheck_cond_incompatible_operands : Error<
92329232
def err_typecheck_expect_scalar_or_vector : Error<
92339233
"invalid operand of type %0 where %1 or "
92349234
"a vector of such type is required">;
9235+
def err_typecheck_expect_scalar_or_vector_not_type : Error<
9236+
"invalid operand of type %0 where %1 or "
9237+
"a vector of such type is not allowed">;
92359238
def err_typecheck_expect_flt_or_vector : Error<
92369239
"invalid operand of type %0 where floating, complex or "
92379240
"a vector of such types is required">;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18631,6 +18631,23 @@ static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
1863118631
return RT.getUDotIntrinsic();
1863218632
}
1863318633

18634+
// Return wave active sum that corresponds to the QT scalar type
18635+
static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
18636+
CGHLSLRuntime &RT, QualType QT) {
18637+
switch (Arch) {
18638+
case llvm::Triple::spirv:
18639+
return llvm::Intrinsic::spv_wave_active_sum;
18640+
case llvm::Triple::dxil: {
18641+
if (QT->isUnsignedIntegerType())
18642+
return llvm::Intrinsic::dx_wave_active_usum;
18643+
return llvm::Intrinsic::dx_wave_active_sum;
18644+
}
18645+
default:
18646+
llvm_unreachable("Intrinsic WaveActiveSum"
18647+
" not supported by target architecture");
18648+
}
18649+
}
18650+
1863418651
Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
1863518652
const CallExpr *E,
1863618653
ReturnValueSlot ReturnValue) {
@@ -18866,6 +18883,23 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1886618883
/*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
1886718884
ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
1886818885
}
18886+
case Builtin::BI__builtin_hlsl_wave_active_sum: {
18887+
// Due to the use of variadic arguments, explicitly retreive argument
18888+
Value *OpExpr = EmitScalarExpr(E->getArg(0));
18889+
llvm::FunctionType *FT = llvm::FunctionType::get(
18890+
OpExpr->getType(), ArrayRef{OpExpr->getType()}, false);
18891+
Intrinsic::ID IID = getWaveActiveSumIntrinsic(
18892+
getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
18893+
E->getArg(0)->getType());
18894+
18895+
// Get overloaded name
18896+
std::string Name =
18897+
Intrinsic::getName(IID, ArrayRef{OpExpr->getType()}, &CGM.getModule());
18898+
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
18899+
/*Local=*/false,
18900+
/*AssumeConvergent=*/true),
18901+
ArrayRef{OpExpr}, "hlsl.wave.active.sum");
18902+
}
1886918903
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
1887018904
return EmitRuntimeCall(CGM.CreateRuntimeFunction(
1887118905
llvm::FunctionType::get(IntTy, {}, false), "__hlsl_wave_get_lane_index",

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2067,6 +2067,105 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
20672067
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
20682068
__attribute__((convergent)) bool WaveIsFirstLane();
20692069

2070+
//===----------------------------------------------------------------------===//
2071+
// WaveActiveSum builtins
2072+
//===----------------------------------------------------------------------===//
2073+
2074+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
2075+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2076+
__attribute((convergent)) half WaveActiveSum(half);
2077+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
2078+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2079+
__attribute((convergent)) half2 WaveActiveSum(half2);
2080+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
2081+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2082+
__attribute((convergent)) half3 WaveActiveSum(half3);
2083+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
2084+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2085+
__attribute((convergent)) half4 WaveActiveSum(half4);
2086+
2087+
#ifdef __HLSL_ENABLE_16_BIT
2088+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2089+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2090+
__attribute((convergent)) int16_t WaveActiveSum(int16_t);
2091+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2092+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2093+
__attribute((convergent)) int16_t2 WaveActiveSum(int16_t2);
2094+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2095+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2096+
__attribute((convergent)) int16_t3 WaveActiveSum(int16_t3);
2097+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2098+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2099+
__attribute((convergent)) int16_t4 WaveActiveSum(int16_t4);
2100+
2101+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2102+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2103+
__attribute((convergent)) uint16_t WaveActiveSum(uint16_t);
2104+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2105+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2106+
__attribute((convergent)) uint16_t2 WaveActiveSum(uint16_t2);
2107+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2108+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2109+
__attribute((convergent)) uint16_t3 WaveActiveSum(uint16_t3);
2110+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2111+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2112+
__attribute((convergent)) uint16_t4 WaveActiveSum(uint16_t4);
2113+
#endif
2114+
2115+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2116+
__attribute((convergent)) int WaveActiveSum(int);
2117+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2118+
__attribute((convergent)) int2 WaveActiveSum(int2);
2119+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2120+
__attribute((convergent)) int3 WaveActiveSum(int3);
2121+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2122+
__attribute((convergent)) int4 WaveActiveSum(int4);
2123+
2124+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2125+
__attribute((convergent)) uint WaveActiveSum(uint);
2126+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2127+
__attribute((convergent)) uint2 WaveActiveSum(uint2);
2128+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2129+
__attribute((convergent)) uint3 WaveActiveSum(uint3);
2130+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2131+
__attribute((convergent)) uint4 WaveActiveSum(uint4);
2132+
2133+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2134+
__attribute((convergent)) int64_t WaveActiveSum(int64_t);
2135+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2136+
__attribute((convergent)) int64_t2 WaveActiveSum(int64_t2);
2137+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2138+
__attribute((convergent)) int64_t3 WaveActiveSum(int64_t3);
2139+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2140+
__attribute((convergent)) int64_t4 WaveActiveSum(int64_t4);
2141+
2142+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2143+
__attribute((convergent)) uint64_t WaveActiveSum(uint64_t);
2144+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2145+
__attribute((convergent)) uint64_t2 WaveActiveSum(uint64_t2);
2146+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2147+
__attribute((convergent)) uint64_t3 WaveActiveSum(uint64_t3);
2148+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2149+
__attribute((convergent)) uint64_t4 WaveActiveSum(uint64_t4);
2150+
2151+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2152+
__attribute((convergent)) float WaveActiveSum(float);
2153+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2154+
__attribute((convergent)) float2 WaveActiveSum(float2);
2155+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2156+
__attribute((convergent)) float3 WaveActiveSum(float3);
2157+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2158+
__attribute((convergent)) float4 WaveActiveSum(float4);
2159+
2160+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2161+
__attribute((convergent)) double WaveActiveSum(double);
2162+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2163+
__attribute((convergent)) double2 WaveActiveSum(double2);
2164+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2165+
__attribute((convergent)) double3 WaveActiveSum(double3);
2166+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_sum)
2167+
__attribute((convergent)) double4 WaveActiveSum(double4);
2168+
20702169
//===----------------------------------------------------------------------===//
20712170
// sign builtins
20722171
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,6 +1751,23 @@ static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
17511751
return false;
17521752
}
17531753

1754+
static bool CheckScalarOrVectorNotType(Sema *S, CallExpr *TheCall,
1755+
QualType Scalar, unsigned ArgIndex) {
1756+
assert(TheCall->getNumArgs() >= ArgIndex);
1757+
QualType ArgType = TheCall->getArg(ArgIndex)->getType();
1758+
auto *VTy = ArgType->getAs<VectorType>();
1759+
// is the scalar or vector<scalar>
1760+
if (S->Context.hasSameUnqualifiedType(ArgType, Scalar) ||
1761+
(VTy &&
1762+
S->Context.hasSameUnqualifiedType(VTy->getElementType(), Scalar))) {
1763+
S->Diag(TheCall->getArg(0)->getBeginLoc(),
1764+
diag::err_typecheck_expect_scalar_or_vector_not_type)
1765+
<< ArgType << Scalar;
1766+
return true;
1767+
}
1768+
return false;
1769+
}
1770+
17541771
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
17551772
assert(TheCall->getNumArgs() == 3);
17561773
Expr *Arg1 = TheCall->getArg(1);
@@ -1985,6 +2002,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
19852002
TheCall->setType(ArgTyA);
19862003
break;
19872004
}
2005+
case Builtin::BI__builtin_hlsl_wave_active_sum: {
2006+
if (SemaRef.checkArgCount(TheCall, 1))
2007+
return true;
2008+
2009+
// Ensure input expr type is a scalar/vector and the same as the return type
2010+
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
2011+
return true;
2012+
if (CheckScalarOrVectorNotType(&SemaRef, TheCall, getASTContext().BoolTy,
2013+
0))
2014+
return true;
2015+
ExprResult Expr = TheCall->getArg(0);
2016+
QualType ArgTyExpr = Expr.get()->getType();
2017+
TheCall->setType(ArgTyExpr);
2018+
break;
2019+
}
19882020
// Note these are llvm builtins that we want to catch invalid intrinsic
19892021
// generation. Normal handling of these builitns will occur elsewhere.
19902022
case Builtin::BI__builtin_elementwise_bitreverse: {
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call.
9+
10+
// CHECK-LABEL: test_int
11+
int test_int(int expr) {
12+
// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()
13+
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.active.sum.i32([[TY]] %[[#]]) [ "convergencectrl"(token %[[#entry_tok]]) ]
14+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.active.sum.i32([[TY]] %[[#]])
15+
// CHECK: ret [[TY]] %[[RET]]
16+
return WaveActiveSum(expr);
17+
}
18+
19+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.active.sum.i32([[TY]]) #[[#attr:]]
20+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.active.sum.i32([[TY]]) #[[#attr:]]
21+
22+
// CHECK-LABEL: test_uint64_t
23+
uint64_t test_uint64_t(uint64_t expr) {
24+
// CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
25+
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.active.sum.i64([[TY]] %[[#]]) [ "convergencectrl"(token %[[#entry_tok1]]) ]
26+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.active.usum.i64([[TY]] %[[#]])
27+
// CHECK: ret [[TY]] %[[RET]]
28+
return WaveActiveSum(expr);
29+
}
30+
31+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.active.usum.i64([[TY]]) #[[#attr:]]
32+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.active.sum.i64([[TY]]) #[[#attr:]]
33+
34+
// Test basic lowering to runtime function call with array and float value.
35+
36+
// CHECK-LABEL: test_floatv4
37+
float4 test_floatv4(float4 expr) {
38+
// CHECK-SPIRV: %[[#entry_tok2:]] = call token @llvm.experimental.convergence.entry()
39+
// CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.active.sum.v4f32([[TY1]] %[[#]]) [ "convergencectrl"(token %[[#entry_tok2]]) ]
40+
// CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.active.sum.v4f32([[TY1]] %[[#]])
41+
// CHECK: ret [[TY1]] %[[RET1]]
42+
return WaveActiveSum(expr);
43+
}
44+
45+
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.active.sum.v4f32([[TY1]]) #[[#attr]]
46+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.active.sum.v4f32([[TY1]]) #[[#attr]]
47+
48+
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
int test_too_few_arg() {
4+
return __builtin_hlsl_wave_active_sum();
5+
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
6+
}
7+
8+
float2 test_too_many_arg(float2 p0) {
9+
return __builtin_hlsl_wave_active_sum(p0, p0);
10+
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
11+
}
12+
13+
bool test_expr_bool_type_check(bool p0) {
14+
return __builtin_hlsl_wave_active_sum(p0);
15+
// expected-error@-1 {{invalid operand of type 'bool' where 'bool' or a vector of such type is not allowed}}
16+
}
17+
18+
bool2 test_expr_bool_vec_type_check(bool2 p0) {
19+
return __builtin_hlsl_wave_active_sum(p0);
20+
// expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>') where 'bool' or a vector of such type is not allowed}}
21+
}
22+
23+
struct S { float f; };
24+
25+
S test_expr_struct_type_check(S p0) {
26+
return __builtin_hlsl_wave_active_sum(p0);
27+
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
28+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
8383
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
8484
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
8585
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
86+
def int_dx_wave_active_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
87+
def int_dx_wave_active_usum : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
8688
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8789
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
8890
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ let TargetPrefix = "spv" in {
8282
DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
8383
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
8484
[IntrNoMem, Commutative] >;
85+
def int_spv_wave_active_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
8586
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
8687
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
8788
def int_spv_radians : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;

llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ static bool isIntrinsicExpansion(Function &F) {
6565
case Intrinsic::dx_sign:
6666
case Intrinsic::dx_step:
6767
case Intrinsic::dx_radians:
68+
case Intrinsic::dx_wave_active_sum:
69+
case Intrinsic::dx_wave_active_usum:
6870
return true;
6971
}
7072
return false;
@@ -451,6 +453,19 @@ static Value *expandRadiansIntrinsic(CallInst *Orig) {
451453
return Builder.CreateFMul(X, PiOver180);
452454
}
453455

456+
template <int OpcodeVal, bool Signed>
457+
static Value *expandWaveActiveOpIntrinsic(CallInst *Orig) {
458+
Value *X = Orig->getOperand(0);
459+
Type *Ty = X->getType();
460+
461+
IRBuilder<> Builder(Orig);
462+
IntegerType *IntTy = IntegerType::get(Builder.getContext(), 8);
463+
Constant *Opcode = ConstantInt::get(IntTy, OpcodeVal);
464+
Constant *SOp = ConstantInt::get(IntTy, Signed ? 0 : 1);
465+
return Builder.CreateIntrinsic(Ty, Intrinsic::dx_wave_active_op,
466+
{X, Opcode, SOp}, nullptr, "dx.active.op");
467+
}
468+
454469
static Intrinsic::ID getMaxForClamp(Type *ElemTy,
455470
Intrinsic::ID ClampIntrinsic) {
456471
if (ClampIntrinsic == Intrinsic::dx_uclamp)
@@ -574,6 +589,12 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
574589
case Intrinsic::dx_radians:
575590
Result = expandRadiansIntrinsic(Orig);
576591
break;
592+
case Intrinsic::dx_wave_active_sum:
593+
Result = expandWaveActiveOpIntrinsic<0, true>(Orig);
594+
break;
595+
case Intrinsic::dx_wave_active_usum:
596+
Result = expandWaveActiveOpIntrinsic<0, false>(Orig);
597+
break;
577598
}
578599
if (Result) {
579600
Orig->replaceAllUsesWith(Result);

0 commit comments

Comments
 (0)