Skip to content

[HLSL] Implement WaveReadLaneAt intrinsic #111010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Oct 16, 2024
Merged
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4703,6 +4703,12 @@ def HLSLWaveIsFirstLane : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool()";
}

// Builtin backing the HLSL WaveReadLaneAt intrinsic. The prototype is declared
// variadic ("void(...)"), so the argument count, the index type and the type
// of the call expression are established by the matching case in
// SemaHLSL::CheckBuiltinFunctionCall rather than by this record.
def HLSLWaveReadLaneAt : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_read_lane_at"];
let Attributes = [NoThrow, Const];
let Prototype = "void(...)";
}

def HLSLClamp : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_clamp"];
let Attributes = [NoThrow, Const];
Expand Down
16 changes: 16 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18835,6 +18835,22 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
return EmitRuntimeCall(Intrinsic::getDeclaration(&CGM.getModule(), ID));
}
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
// Emits the call for WaveReadLaneAt(expr, index): argument 0 is the value to
// read and argument 1 is the lane index.
// Due to the use of variadic arguments we must explicitly retrieve them and
// create our function type.
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Value *OpIndex = EmitScalarExpr(E->getArg(1));
// The callee returns the same type as the expression operand and is not
// variadic.
llvm::FunctionType *FT = llvm::FunctionType::get(
OpExpr->getType(), ArrayRef{OpExpr->getType(), OpIndex->getType()},
false);

// Get overloaded name. Only the expression type takes part in the overload
// mangling; the index type does not.
std::string name =
Intrinsic::getName(CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
ArrayRef{OpExpr->getType()}, &CGM.getModule());
return EmitRuntimeCall(CGM.CreateRuntimeFunction(FT, name, {}, false, true),
ArrayRef{OpExpr, OpIndex}, "hlsl.wave.read.lane.at");
Copy link
Member

@farzonl farzonl Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The string we used for BI__builtin_hlsl_wave_get_lane_index was __hlsl_wave_get_lane_index. Why would we use periods here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think __hlsl_wave_get_lane_index is the odd one out. The other intrinsics follow the pattern hlsl.name. Now that the name is the single word waveReadLaneAt, I think we can keep the naming consistent by using hlsl.waveReadLaneAt.

I can change the name to hlsl.waveGetLaneIndex in the clean-up pr.

}
case Builtin::BI__builtin_hlsl_elementwise_sign: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
llvm::Type *Xty = Op0->getType();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot)
GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
// WaveReadLaneAt getter; CodeGen calls it as getWaveReadLaneAtIntrinsic().
// NOTE(review): the macro definition is not visible here; presumably it picks
// the dx or spv wave_read_lane_at intrinsic for the current target - confirm.
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_read_lane_at)

//===----------------------------------------------------------------------===//
// End of reserved area for HLSL intrinsic getters.
Expand Down
7 changes: 7 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2015,6 +2015,13 @@ _HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_is_first_lane)
__attribute__((convergent)) bool WaveIsFirstLane();

/// \brief Returns the value of the expression for the given lane index within
/// the current wave.
///
/// The first argument is the expression whose value is read. The second
/// argument is the index of the lane to read it from; it is unsigned, matching
/// the `uint` lane index that HLSL specifies for this intrinsic and the
/// 'unsigned int' type named by the builtin's Sema diagnostic.
template <typename T>
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) T WaveReadLaneAt(T, uint32_t);

//===----------------------------------------------------------------------===//
// sign builtins
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 20 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1956,6 +1956,26 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
  // The builtin is declared with a variadic prototype, so both the arity and
  // the operand types have to be validated here.
  if (SemaRef.checkArgCount(TheCall, 2))
    return true;

  // The lane index must be a scalar integer that can be interpreted as a
  // uint. hasIntegerRepresentation() would also admit integer vectors, which
  // CodeGen would then pass where the intrinsic takes a single i32 index.
  const Expr *IndexArg = TheCall->getArg(1);
  QualType IndexTy = IndexArg->getType();
  if (!IndexTy->isIntegerType()) {
    SemaRef.Diag(IndexArg->getBeginLoc(),
                 diag::err_typecheck_convert_incompatible)
        << IndexTy << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
    return true;
  }

  // The call yields a value of the same type as the expression operand.
  TheCall->setType(TheCall->getArg(0)->getType());
  break;
}
case Builtin::BI__builtin_elementwise_acos:
case Builtin::BI__builtin_elementwise_asin:
case Builtin::BI__builtin_elementwise_atan:
Expand Down
40 changes: 40 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveReadLaneAt.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -x hlsl -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call.

// CHECK-LABEL: test_int
int test_int(int expr, uint idx) {
// Scalar case: the call must become a single call to the target's
// wave.read.lane.at intrinsic overloaded on i32, and its result is returned
// unchanged.
// CHECK-SPIRV: %[[#entry_tok:]] = call token @llvm.experimental.convergence.entry()

// CHECK-SPIRV: %[[RET:.*]] = call [[TY:.*]] @llvm.spv.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.read.lane.at.i32([[TY]] %[[#]], i32 %[[#]])

// CHECK: ret [[TY]] %[[RET]]
return WaveReadLaneAt(expr, idx);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.read.lane.at.i32([[TY]], i32) #[[#attr:]]

// Test basic lowering to runtime function call with a float vector value.

// CHECK-LABEL: test_floatv4
float4 test_floatv4(float4 expr, uint idx) {
// Vector case: the whole float4 is passed to one intrinsic call overloaded on
// v4f32; the index stays a scalar i32.
// CHECK-SPIRV: %[[#entry_tok1:]] = call token @llvm.experimental.convergence.entry()

// CHECK-SPIRV: %[[RET1:.*]] = call [[TY1:.*]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])
// CHECK-DXIL: %[[RET1:.*]] = call [[TY1:.*]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]] %[[#]], i32 %[[#]])

// CHECK: ret [[TY1]] %[[RET1]]
return WaveReadLaneAt(expr, idx);
}

// CHECK-DXIL: declare [[TY1]] @llvm.dx.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]
// CHECK-SPIRV: declare [[TY1]] @llvm.spv.wave.read.lane.at.v4f32([[TY1]], i32) #[[#attr]]

// CHECK: attributes #[[#attr]] = {{{.*}} convergent {{.*}}}
21 changes: 21 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveReadLaneAt-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected

bool test_too_few_arg() {
// Calling the builtin with no arguments must be rejected by the arity check.
return __builtin_hlsl_wave_read_lane_at();
// expected-error@-1 {{too few arguments to function call, expected 2, have 0}}
}

float2 test_too_few_arg_1(float2 p0) {
// Supplying only the expression (no lane index) must be rejected as well.
return __builtin_hlsl_wave_read_lane_at(p0);
// expected-error@-1 {{too few arguments to function call, expected 2, have 1}}
}

float2 test_too_many_arg(float2 p0) {
// A third argument must be rejected by the arity check.
return __builtin_hlsl_wave_read_lane_at(p0, p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 2, have 3}}
}

float3 test_index_type_check(float3 p0, double idx) {
// A floating-point lane index must be diagnosed as incompatible with the
// unsigned integer index the builtin requires.
return __builtin_hlsl_wave_read_lane_at(p0, idx);
// expected-error@-1 {{passing 'double' to parameter of incompatible type 'unsigned int'}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLV
def int_dx_normalize : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty], [IntrNoMem]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>], [IntrNoMem]>;
def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
// Overloaded on the result type; the first operand has the same type as the
// result and the second operand is an i32. Marked convergent.
def int_dx_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>;
def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>;
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,6 @@ let TargetPrefix = "spv" in {
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, Commutative] >;
def int_spv_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>;
// Overloaded on the result type; the first operand has the same type as the
// result and the second operand is an i32. Marked convergent.
def int_spv_wave_read_lane_at : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, llvm_i32_ty], [IntrConvergent]>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wave_read_lane_at would be better as one word.

We should probably also update wave_is_first_lane in a follow-up cleanup PR.

Copy link
Contributor Author

@inbelic inbelic Oct 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a commit to change the name to waveReadLaneAt, as described in the comment below. I will take care of the follow-up PR.

def int_spv_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty]>;
}
10 changes: 10 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -801,3 +801,13 @@ def WaveIsFirstLane : DXILOp<110, waveIsFirstLane> {
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

// WaveReadLaneAt, DXIL opcode 117. The op class is waveReadLaneAt (not the
// waveIsFirstLane class of the preceding record): the DXIL lowering test for
// this op expects calls named dx.op.waveReadLaneAt.<overload>.
def WaveReadLaneAt : DXILOp<117, waveReadLaneAt> {
  let Doc = "returns the value from the specified lane";
  let LLVMIntrinsic = int_dx_wave_read_lane_at;
  // Operands: the value to read (overloaded type) and the i32 lane index.
  let arguments = [OverloadTy, Int32Ty];
  let result = OverloadTy;
  let overloads = [Overloads<DXIL1_0, [HalfTy, FloatTy, DoubleTy, Int1Ty, Int16Ty, Int32Ty]>];
  let stages = [Stages<DXIL1_0, [all_stages]>];
  let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}
15 changes: 15 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2653,6 +2653,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(3, I, IntTy, TII));
}
case Intrinsic::spv_wave_read_lane_at: {
// Selects the intrinsic to OpGroupNonUniformShuffle. The instruction has four
// operands; operands 2 and 3 must be registers and are forwarded, in that
// order, to the emitted instruction.
assert(I.getNumOperands() == 4);
assert(I.getOperand(2).isReg());
assert(I.getOperand(3).isReg());

// The execution scope is emitted as the 32-bit integer constant 2 and is
// appended after the two forwarded operands.
// NOTE(review): in the SPIR-V Scope table 2 is Workgroup and 3 is Subgroup
// (3 is what the preceding case passes); confirm 2 is intended here.
// NOTE(review): the SPIR-V specification lists the operands of
// OpGroupNonUniformShuffle as Execution scope, Value, Id - confirm that adding
// the scope last matches the instruction definition. Changing either point
// also requires updating the matching FileCheck lines in the SPIR-V test.
SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
return BuildMI(BB, I, I.getDebugLoc(),
TII.get(SPIRV::OpGroupNonUniformShuffle))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(2).getReg())
.addUse(I.getOperand(3).getReg())
.addUse(GR.getOrCreateConstInt(2, I, IntTy, TII));
}
case Intrinsic::spv_step:
return selectStep(ResVReg, ResType, I);
// Discard intrinsics which we do not expect to actually represent code after
Expand Down
56 changes: 56 additions & 0 deletions llvm/test/CodeGen/DirectX/WaveReadLaneAt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

; Test that for scalar values, WaveReadLaneAt maps down to the DirectX op

; half overload: the intrinsic call must be lowered to the f16 variant of the
; DXIL op with opcode 117, forwarding both operands unchanged.
define noundef half @wave_rla_half(half noundef %expr, i32 noundef %idx) #0 {
entry:
; CHECK: call half @dx.op.waveReadLaneAt.f16(i32 117, half %expr, i32 %idx)
%ret = call half @llvm.dx.wave.read.lane.at.f16(half %expr, i32 %idx)
ret half %ret
}

; float overload: the intrinsic call must be lowered to the f32 variant of the
; DXIL op with opcode 117. The call uses the fully mangled intrinsic name so
; that it refers to the .f32 declaration at the end of this file, as the
; half and integer tests do for their overloads.
define noundef float @wave_rla_float(float noundef %expr, i32 noundef %idx) #0 {
entry:
; CHECK: call float @dx.op.waveReadLaneAt.f32(i32 117, float %expr, i32 %idx)
  %ret = call float @llvm.dx.wave.read.lane.at.f32(float %expr, i32 %idx)
  ret float %ret
}

; double overload: the intrinsic call must be lowered to the f64 variant of
; the DXIL op with opcode 117. The call uses the fully mangled intrinsic name
; so that it refers to the .f64 declaration at the end of this file, as the
; half and integer tests do for their overloads.
define noundef double @wave_rla_double(double noundef %expr, i32 noundef %idx) #0 {
entry:
; CHECK: call double @dx.op.waveReadLaneAt.f64(i32 117, double %expr, i32 %idx)
  %ret = call double @llvm.dx.wave.read.lane.at.f64(double %expr, i32 %idx)
  ret double %ret
}

; i1 overload: the intrinsic call must be lowered to the i1 variant of the
; DXIL op with opcode 117.
define noundef i1 @wave_rla_i1(i1 noundef %expr, i32 noundef %idx) #0 {
entry:
; CHECK: call i1 @dx.op.waveReadLaneAt.i1(i32 117, i1 %expr, i32 %idx)
%ret = call i1 @llvm.dx.wave.read.lane.at.i1(i1 %expr, i32 %idx)
ret i1 %ret
}

; i16 overload: the intrinsic call must be lowered to the i16 variant of the
; DXIL op with opcode 117.
define noundef i16 @wave_rla_i16(i16 noundef %expr, i32 noundef %idx) #0 {
entry:
; CHECK: call i16 @dx.op.waveReadLaneAt.i16(i32 117, i16 %expr, i32 %idx)
%ret = call i16 @llvm.dx.wave.read.lane.at.i16(i16 %expr, i32 %idx)
ret i16 %ret
}

; i32 overload: the intrinsic call must be lowered to the i32 variant of the
; DXIL op with opcode 117.
define noundef i32 @wave_rla_i32(i32 noundef %expr, i32 noundef %idx) #0 {
entry:
; CHECK: call i32 @dx.op.waveReadLaneAt.i32(i32 117, i32 %expr, i32 %idx)
%ret = call i32 @llvm.dx.wave.read.lane.at.i32(i32 %expr, i32 %idx)
ret i32 %ret
}

declare half @llvm.dx.wave.read.lane.at.f16(half, i32) #1
declare float @llvm.dx.wave.read.lane.at.f32(float, i32) #1
declare double @llvm.dx.wave.read.lane.at.f64(double, i32) #1

declare i1 @llvm.dx.wave.read.lane.at.i1(i1, i32) #1
declare i16 @llvm.dx.wave.read.lane.at.i16(i16, i32) #1
declare i32 @llvm.dx.wave.read.lane.at.i32(i32, i32) #1

attributes #0 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #1 = { nocallback nofree nosync nounwind willreturn }
28 changes: 28 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveReadLaneAt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32v1.3-vulkan-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}

; Test lowering to spir-v backend

; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 2
; CHECK-DAG: %[[#f32:]] = OpTypeFloat 32
; CHECK-DAG: %[[#expr:]] = OpFunctionParameter %[[#f32]]
; CHECK-DAG: %[[#idx:]] = OpFunctionParameter %[[#uint]]

; Calls the float overload of the spv intrinsic under a convergence entry token
; and expects a single OpGroupNonUniformShuffle whose operands are, in printed
; order, the result type, the expression, the index and the scope constant.
; NOTE(review): this pins the operand order and the scope value 2 produced by
; the instruction selector; the SPIR-V specification lists Execution scope
; before Value and Id, and names scope 3 as Subgroup - confirm both.
define spir_func void @test_1(float %expr, i32 %idx) #0 {
entry:
%0 = call token @llvm.experimental.convergence.entry()
; CHECK: %[[#ret:]] = OpGroupNonUniformShuffle %[[#f32]] %[[#expr]] %[[#idx]] %[[#scope]]
%1 = call float @llvm.spv.wave.read.lane.at(float %expr, i32 %idx) [ "convergencectrl"(token %0) ]
ret void
}

; NOTE(review): nothing in this file calls this function; it looks like a
; leftover from another wave-intrinsic test - confirm whether it can go.
declare i32 @__hlsl_wave_get_lane_index() #1

attributes #0 = { convergent norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #1 = { convergent }

!llvm.module.flags = !{!0, !1}

!0 = !{i32 1, !"wchar_size", i32 4}
!1 = !{i32 4, !"dx.disable_optimizations", i32 1}
Loading