Skip to content

Commit 6d13cc9

Browse files
authored
[HLSL] Implement WaveReadLaneAt intrinsic (#111010)
- create a clang built-in in Builtins.td - add semantic checking in SemaHLSL.cpp - link the WaveReadLaneAt api in hlsl_intrinsics.h - add lowering to spirv backend op GroupNonUniformShuffle with Scope = 2 (Group) in SPIRVInstructionSelector.cpp - add WaveReadLaneAt intrinsic to IntrinsicsDirectX.td and mapping to DXIL.td - add tests for HLSL intrinsic lowering to spirv intrinsic in WaveReadLaneAt.hlsl - add tests for sema checks in WaveReadLaneAt-errors.hlsl - add spir-v backend tests in WaveReadLaneAt.ll - add test to show scalar dxil lowering functionality - note that this doesn't include support for the scalarizer to handle WaveReadLaneAt will be added in a future pr This is the first part #70104
1 parent 210140a commit 6d13cc9

File tree

14 files changed

+413
-0
lines changed

14 files changed

+413
-0
lines changed

clang/include/clang/Basic/Builtins.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4761,6 +4761,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
47614761
let Prototype = "bool()";
47624762
}
47634763

4764+
def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
4765+
let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
4766+
let Attributes = [NoThrow, Const];
4767+
let Prototype = "void(...)";
4768+
}
4769+
47644770
def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
47654771
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
47664772
let Attributes = [NoThrow, Const];

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9230,6 +9230,8 @@ def err_typecheck_cond_incompatible_operands : Error<
92309230
def err_typecheck_expect_scalar_or_vector : Error<
92319231
"invalid operand of type %0 where %1 or "
92329232
"a vector of such type is required">;
9233+
def err_typecheck_expect_any_scalar_or_vector : Error<
9234+
"invalid operand of type %0 where a scalar or vector is required">;
92339235
def err_typecheck_expect_flt_or_vector : Error<
92349236
"invalid operand of type %0 where floating, complex or "
92359237
"a vector of such types is required">;

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18905,6 +18905,24 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
1890518905
return EmitRuntimeCall(
1890618906
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
1890718907
}
18908+
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
18909+
// Due to the use of variadic arguments we must explicitly retreive them and
18910+
// create our function type.
18911+
Value *OpExpr = EmitScalarExpr(E->getArg(0));
18912+
Value *OpIndex = EmitScalarExpr(E->getArg(1));
18913+
llvm::FunctionType *FT = llvm::FunctionType::get(
18914+
OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
18915+
false);
18916+
18917+
// Get overloaded name
18918+
std::string Name =
18919+
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
18920+
ArrayRef{OpExpr->getType()}, &CGM.getModule());
18921+
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, Name, {},
18922+
/*Local=*/false,
18923+
/*AssumeConvergent=*/true),
18924+
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
18925+
}
1890818926
case Builtin::BI__builtin_hlsl_elementwise_sign: {
1890918927
auto *Arg0 = E->getArg(0);
1891018928
Value *Op0 = EmitScalarExpr(Arg0);

clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ class CGHLSLRuntime {
9090
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
9191
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
9292
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
93+
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
9394

9495
//===----------------------------------------------------------------------===//
9596
// End of reserved area for HLSL intrinsic getters.

clang/lib/Headers/hlsl/hlsl_intrinsics.h

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,6 +2097,86 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
20972097
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
20982098
__attribute__((convergent)) bool WaveIsFirstLane();
20992099

2100+
//===----------------------------------------------------------------------===//
2101+
// WaveReadLaneAt builtins
2102+
//===----------------------------------------------------------------------===//
2103+
2104+
// \brief Returns the value of the expression for the given lane index within
2105+
// the specified wave.
2106+
2107+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2108+
__attribute__((convergent)) bool WaveReadLaneAt(bool, int32_t);
2109+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2110+
__attribute__((convergent)) bool2 WaveReadLaneAt(bool2, int32_t);
2111+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2112+
__attribute__((convergent)) bool3 WaveReadLaneAt(bool3, int32_t);
2113+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2114+
__attribute__((convergent)) bool4 WaveReadLaneAt(bool4, int32_t);
2115+
2116+
#ifdef __HLSL_ENABLE_16_BIT
2117+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2118+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2119+
__attribute__((convergent)) int16_t WaveReadLaneAt(int16_t, int32_t);
2120+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2121+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2122+
__attribute__((convergent)) int16_t2 WaveReadLaneAt(int16_t2, int32_t);
2123+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2124+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2125+
__attribute__((convergent)) int16_t3 WaveReadLaneAt(int16_t3, int32_t);
2126+
_HLSL_AVAILABILITY(shadermodel, 6.0)
2127+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2128+
__attribute__((convergent)) int16_t4 WaveReadLaneAt(int16_t4, int32_t);
2129+
#endif
2130+
2131+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
2132+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2133+
__attribute__((convergent)) half WaveReadLaneAt(half, int32_t);
2134+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
2135+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2136+
__attribute__((convergent)) half2 WaveReadLaneAt(half2, int32_t);
2137+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
2138+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2139+
__attribute__((convergent)) half3 WaveReadLaneAt(half3, int32_t);
2140+
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.0)
2141+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2142+
__attribute__((convergent)) half4 WaveReadLaneAt(half4, int32_t);
2143+
2144+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2145+
__attribute__((convergent)) int WaveReadLaneAt(int, int32_t);
2146+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2147+
__attribute__((convergent)) int2 WaveReadLaneAt(int2, int32_t);
2148+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2149+
__attribute__((convergent)) int3 WaveReadLaneAt(int3, int32_t);
2150+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2151+
__attribute__((convergent)) int4 WaveReadLaneAt(int4, int32_t);
2152+
2153+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2154+
__attribute__((convergent)) float WaveReadLaneAt(float, int32_t);
2155+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2156+
__attribute__((convergent)) float2 WaveReadLaneAt(float2, int32_t);
2157+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2158+
__attribute__((convergent)) float3 WaveReadLaneAt(float3, int32_t);
2159+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2160+
__attribute__((convergent)) float4 WaveReadLaneAt(float4, int32_t);
2161+
2162+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2163+
__attribute__((convergent)) int64_t WaveReadLaneAt(int64_t, int32_t);
2164+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2165+
__attribute__((convergent)) int64_t2 WaveReadLaneAt(int64_t2, int32_t);
2166+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2167+
__attribute__((convergent)) int64_t3 WaveReadLaneAt(int64_t3, int32_t);
2168+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2169+
__attribute__((convergent)) int64_t4 WaveReadLaneAt(int64_t4, int32_t);
2170+
2171+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2172+
__attribute__((convergent)) double WaveReadLaneAt(double, int32_t);
2173+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2174+
__attribute__((convergent)) double2 WaveReadLaneAt(double2, int32_t);
2175+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2176+
__attribute__((convergent)) double3 WaveReadLaneAt(double3, int32_t);
2177+
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
2178+
__attribute__((convergent)) double4 WaveReadLaneAt(double4, int32_t);
2179+
21002180
//===----------------------------------------------------------------------===//
21012181
// sign builtins
21022182
//===----------------------------------------------------------------------===//

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1751,6 +1751,22 @@ static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar,
17511751
return false;
17521752
}
17531753

1754+
static bool CheckAnyScalarOrVector(Sema *S, CallExpr *TheCall,
1755+
unsigned ArgIndex) {
1756+
assert(TheCall->getNumArgs() >= ArgIndex);
1757+
QualType ArgType = TheCall->getArg(ArgIndex)->getType();
1758+
auto *VTy = ArgType->getAs<VectorType>();
1759+
// not the scalar or vector<scalar>
1760+
if (!(ArgType->isScalarType() ||
1761+
(VTy && VTy->getElementType()->isScalarType()))) {
1762+
S->Diag(TheCall->getArg(0)->getBeginLoc(),
1763+
diag::err_typecheck_expect_any_scalar_or_vector)
1764+
<< ArgType;
1765+
return true;
1766+
}
1767+
return false;
1768+
}
1769+
17541770
static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) {
17551771
assert(TheCall->getNumArgs() == 3);
17561772
Expr *Arg1 = TheCall->getArg(1);
@@ -1993,6 +2009,29 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
19932009
return true;
19942010
break;
19952011
}
2012+
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
2013+
if (SemaRef.checkArgCount(TheCall, 2))
2014+
return true;
2015+
2016+
// Ensure index parameter type can be interpreted as a uint
2017+
ExprResult Index = TheCall->getArg(1);
2018+
QualType ArgTyIndex = Index.get()->getType();
2019+
if (!ArgTyIndex->isIntegerType()) {
2020+
SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
2021+
diag::err_typecheck_convert_incompatible)
2022+
<< ArgTyIndex << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
2023+
return true;
2024+
}
2025+
2026+
// Ensure input expr type is a scalar/vector and the same as the return type
2027+
if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
2028+
return true;
2029+
2030+
ExprResult Expr = TheCall->getArg(0);
2031+
QualType ArgTyExpr = Expr.get()->getType();
2032+
TheCall->setType(ArgTyExpr);
2033+
break;
2034+
}
19962035
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
19972036
if (SemaRef.checkArgCount(TheCall, 0))
19982037
return true;
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -fnative-half-type -triple \
2+
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
3+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
4+
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -fnative-half-type -triple \
5+
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
6+
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
7+
8+
// Test basic lowering to runtime function call for int values.
9+
10+
// CHECK-LABEL: test_int
11+
int test_int(int expr, uint idx) {
12+
// CHECK-SPIRV: %[[#entry_tok0:]] = call token @llvm.experimental.convergence.entry()
13+
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlane.i32([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok0]]) ]
14+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlane.i32([[TY]] %[[#]], i32 %[[#]])
15+
// CHECK: ret [[TY]] %[[RET]]
16+
return WaveReadLaneAt(expr, idx);
17+
}
18+
19+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i32([[TY]], i32) #[[#attr:]]
20+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i32([[TY]], i32) #[[#attr:]]
21+
22+
#ifdef __HLSL_ENABLE_16_BIT
23+
// CHECK-LABEL: test_int16
24+
int16_t test_int16(int16_t expr, uint idx) {
25+
// CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()
26+
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlane.i16([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok1]]) ]
27+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlane.i16([[TY]] %[[#]], i32 %[[#]])
28+
// CHECK: ret [[TY]] %[[RET]]
29+
return WaveReadLaneAt(expr, idx);
30+
}
31+
32+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.i16([[TY]], i32) #[[#attr:]]
33+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.i16([[TY]], i32) #[[#attr:]]
34+
#endif
35+
36+
// Test basic lowering to runtime function call with array and float values.
37+
38+
// CHECK-LABEL: test_half
39+
half test_half(half expr, uint idx) {
40+
// CHECK-SPIRV: %[[#entry_tok2:]] = call token @llvm.experimental.convergence.entry()
41+
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlane.f16([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok2]]) ]
42+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlane.f16([[TY]] %[[#]], i32 %[[#]])
43+
// CHECK: ret [[TY]] %[[RET]]
44+
return WaveReadLaneAt(expr, idx);
45+
}
46+
47+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.f16([[TY]], i32) #[[#attr:]]
48+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.f16([[TY]], i32) #[[#attr:]]
49+
50+
// CHECK-LABEL: test_double
51+
double test_double(double expr, uint idx) {
52+
// CHECK-SPIRV: %[[#entry_tok3:]] = call token @llvm.experimental.convergence.entry()
53+
// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.readlane.f64([[TY]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok3]]) ]
54+
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.readlane.f64([[TY]] %[[#]], i32 %[[#]])
55+
// CHECK: ret [[TY]] %[[RET]]
56+
return WaveReadLaneAt(expr, idx);
57+
}
58+
59+
// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.readlane.f64([[TY]], i32) #[[#attr:]]
60+
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.readlane.f64([[TY]], i32) #[[#attr:]]
61+
62+
// CHECK-LABEL: test_floatv4
63+
float4 test_floatv4(float4 expr, uint idx) {
64+
// CHECK-SPIRV: %[[#entry_tok4:]] = call token @llvm.experimental.convergence.entry()
65+
// CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.readlane.v4f32([[TY1]] %[[#]], i32 %[[#]]) [ "convergencectrl"(token %[[#entry_tok4]]) ]
66+
// CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.readlane.v4f32([[TY1]] %[[#]], i32 %[[#]])
67+
// CHECK: ret [[TY1]] %[[RET1]]
68+
return WaveReadLaneAt(expr, idx);
69+
}
70+
71+
// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
72+
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.readlane.v4f32([[TY1]], i32) #[[#attr]]
73+
74+
// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
2+
3+
bool test_too_few_arg() {
4+
return __builtin_hlsl_wave_read_lane_at();
5+
// expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
6+
}
7+
8+
float2 test_too_few_arg_1(float2 p0) {
9+
return __builtin_hlsl_wave_read_lane_at(p0);
10+
// expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
11+
}
12+
13+
float2 test_too_many_arg(float2 p0) {
14+
return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
15+
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
16+
}
17+
18+
float3 test_index_double_type_check(float3 p0, double idx) {
19+
return __builtin_hlsl_wave_read_lane_at(p0, idx);
20+
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
21+
}
22+
23+
float3 test_index_int3_type_check(float3 p0, int3 idxs) {
24+
return __builtin_hlsl_wave_read_lane_at(p0, idxs);
25+
// expected-error@-1 {{passing 'int3' (aka 'vector<int, 3>') to parameter of incompatible type 'unsigned int'}}
26+
}
27+
28+
struct S { float f; };
29+
30+
float3 test_index_S_type_check(float3 p0, S idx) {
31+
return __builtin_hlsl_wave_read_lane_at(p0, idx);
32+
// expected-error@-1 {{passing 'S' to parameter of incompatible type 'unsigned int'}}
33+
}
34+
35+
S test_expr_struct_type_check(S p0, int idx) {
36+
return __builtin_hlsl_wave_read_lane_at(p0, idx);
37+
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
38+
}

llvm/include/llvm/IR/IntrinsicsDirectX.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_
8686
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
8787
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
8888
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
89+
def int_dx_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
8990
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
9091
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
9192
def int_dx_radians : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ let TargetPrefix = "spv" in {
8484
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
8585
[IntrNoMem, Commutative] >;
8686
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
87+
def int_spv_wave_readlane : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent, IntrNoMem]>;
8788
def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
8889
def int_spv_radians : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
8990

llvm/lib/Target/DirectX/DXIL.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,16 @@ def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
802802
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
803803
}
804804

805+
def WaveReadLaneAt: DXILOp<117, waveReadLaneAt> {
806+
let Doc = "returns the value from the specified lane";
807+
let LLVMIntrinsic = int_dx_wave_readlane;
808+
let arguments = [OverloadTy, Int32Ty];
809+
let result = OverloadTy;
810+
let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int1Ty, Int16Ty, Int32Ty, Int64Ty]>];
811+
let stages = [Stages<DXIL1_0, [all_stages]>];
812+
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
813+
}
814+
805815
def WaveGetLaneIndex : DXILOp<111, waveGetLaneIndex> {
806816
let Doc = "returns the index of the current lane in the wave";
807817
let LLVMIntrinsic = int_dx_wave_getlaneindex;

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
230230
bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType,
231231
MachineInstr &I) const;
232232

233+
bool selectWaveReadLaneAt(Register ResVReg, const SPIRVType *ResType,
234+
MachineInstr &I) const;
235+
233236
bool selectUnmergeValues(MachineInstr &I) const;
234237

235238
void selectHandleFromBinding(Register &ResVReg, const SPIRVType *ResType,
@@ -417,6 +420,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
417420

418421
case TargetOpcode::G_INTRINSIC:
419422
case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
423+
case TargetOpcode::G_INTRINSIC_CONVERGENT:
420424
case TargetOpcode::G_INTRINSIC_CONVERGENT_W_SIDE_EFFECTS:
421425
return selectIntrinsic(ResVReg, ResType, I);
422426
case TargetOpcode::G_BITREVERSE:
@@ -1758,6 +1762,26 @@ bool SPIRVInstructionSelector::selectSign(Register ResVReg,
17581762
return Result;
17591763
}
17601764

1765+
bool SPIRVInstructionSelector::selectWaveReadLaneAt(Register ResVReg,
1766+
const SPIRVType *ResType,
1767+
MachineInstr &I) const {
1768+
assert(I.getNumOperands() == 4);
1769+
assert(I.getOperand(2).isReg());
1770+
assert(I.getOperand(3).isReg());
1771+
MachineBasicBlock &BB = *I.getParent();
1772+
1773+
// IntTy is used to define the execution scope, set to 3 to denote a
1774+
// cross-lane interaction equivalent to a SPIR-V subgroup.
1775+
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
1776+
return BuildMI(BB, I, I.getDebugLoc(),
1777+
TII.get(SPIRV::OpGroupNonUniformShuffle))
1778+
.addDef(ResVReg)
1779+
.addUse(GR.getSPIRVTypeID(ResType))
1780+
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII))
1781+
.addUse(I.getOperand(2).getReg())
1782+
.addUse(I.getOperand(3).getReg());
1783+
}
1784+
17611785
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
17621786
const SPIRVType *ResType,
17631787
MachineInstr &I) const {
@@ -2543,6 +2567,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
25432567
.addUse(GR.getSPIRVTypeID(ResType))
25442568
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
25452569
}
2570+
case Intrinsic::spv_wave_readlane:
2571+
return selectWaveReadLaneAt(ResVReg, ResType, I);
25462572
case Intrinsic::spv_step:
25472573
return selectExtInst(ResVReg, ResType, I, CL::step, GL::Step);
25482574
case Intrinsic::spv_radians:

0 commit comments

Comments
 (0)