Skip to content

[HLSL] Implement rsqrt intrinsic #84820

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4590,6 +4590,12 @@ def HLSLRcp : LangBuiltin<"HLSL_LANG"> {
let Prototype = "void(...)";
}

def HLSLRSqrt : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_rsqrt"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
let Prototype = "void(...)";
}

// Builtins for XRay.
def XRayCustomEvent : Builtin {
let Spellings = ["__xray_customevent"];
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18089,6 +18089,14 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/Op0->getType(), Intrinsic::dx_rcp,
ArrayRef<Value *>{Op0}, nullptr, "dx.rcp");
}
case Builtin::BI__builtin_hlsl_elementwise_rsqrt: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
llvm_unreachable("rsqrt operand must have a float representation");
return Builder.CreateIntrinsic(
/*ReturnType=*/Op0->getType(), Intrinsic::dx_rsqrt,
ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
}
}
return nullptr;
}
Expand Down
33 changes: 33 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,39 @@ double3 rcp(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rcp)
double4 rcp(double4);

//===----------------------------------------------------------------------===//
// rsqrt builtins
//===----------------------------------------------------------------------===//

/// \fn T rsqrt(T x)
/// \brief Returns the reciprocal of the square root of the specified value.
/// ie 1 / sqrt( \a x).
/// \param x The specified input value.
///
/// This function uses the following formula: 1 / sqrt(x).

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
half rsqrt(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
half2 rsqrt(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
half3 rsqrt(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
half4 rsqrt(half4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
float rsqrt(float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
float2 rsqrt(float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
float3 rsqrt(float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_rsqrt)
float4 rsqrt(float4);

//===----------------------------------------------------------------------===//
// round builtins
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 3 additions & 2 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5286,12 +5286,13 @@ bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
return true;
break;
}
case Builtin::BI__builtin_hlsl_elementwise_rsqrt:
case Builtin::BI__builtin_hlsl_elementwise_rcp:
case Builtin::BI__builtin_hlsl_elementwise_frac: {
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
return true;
if (CheckAllArgsHaveFloatRepresentation(this, TheCall))
return true;
if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
return true;
break;
}
case Builtin::BI__builtin_hlsl_lerp: {
Expand Down
53 changes: 53 additions & 0 deletions clang/test/CodeGenHLSL/builtins/rsqrt.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
// RUN: -emit-llvm -disable-llvm-passes -o - | FileCheck %s \
// RUN: --check-prefixes=CHECK,NATIVE_HALF
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm -disable-llvm-passes \
// RUN: -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF

// NATIVE_HALF: define noundef half @
// NATIVE_HALF: %dx.rsqrt = call half @llvm.dx.rsqrt.f16(
// NATIVE_HALF: ret half %dx.rsqrt
// NO_HALF: define noundef float @"?test_rsqrt_half@@YA$halff@$halff@@Z"(
// NO_HALF: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
// NO_HALF: ret float %dx.rsqrt
half test_rsqrt_half(half p0) { return rsqrt(p0); }
// NATIVE_HALF: define noundef <2 x half> @
// NATIVE_HALF: %dx.rsqrt = call <2 x half> @llvm.dx.rsqrt.v2f16
// NATIVE_HALF: ret <2 x half> %dx.rsqrt
// NO_HALF: define noundef <2 x float> @
// NO_HALF: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32(
// NO_HALF: ret <2 x float> %dx.rsqrt
half2 test_rsqrt_half2(half2 p0) { return rsqrt(p0); }
// NATIVE_HALF: define noundef <3 x half> @
// NATIVE_HALF: %dx.rsqrt = call <3 x half> @llvm.dx.rsqrt.v3f16
// NATIVE_HALF: ret <3 x half> %dx.rsqrt
// NO_HALF: define noundef <3 x float> @
// NO_HALF: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32(
// NO_HALF: ret <3 x float> %dx.rsqrt
half3 test_rsqrt_half3(half3 p0) { return rsqrt(p0); }
// NATIVE_HALF: define noundef <4 x half> @
// NATIVE_HALF: %dx.rsqrt = call <4 x half> @llvm.dx.rsqrt.v4f16
// NATIVE_HALF: ret <4 x half> %dx.rsqrt
// NO_HALF: define noundef <4 x float> @
// NO_HALF: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32(
// NO_HALF: ret <4 x float> %dx.rsqrt
half4 test_rsqrt_half4(half4 p0) { return rsqrt(p0); }

// CHECK: define noundef float @
// CHECK: %dx.rsqrt = call float @llvm.dx.rsqrt.f32(
// CHECK: ret float %dx.rsqrt
float test_rsqrt_float(float p0) { return rsqrt(p0); }
// CHECK: define noundef <2 x float> @
// CHECK: %dx.rsqrt = call <2 x float> @llvm.dx.rsqrt.v2f32
// CHECK: ret <2 x float> %dx.rsqrt
float2 test_rsqrt_float2(float2 p0) { return rsqrt(p0); }
// CHECK: define noundef <3 x float> @
// CHECK: %dx.rsqrt = call <3 x float> @llvm.dx.rsqrt.v3f32
// CHECK: ret <3 x float> %dx.rsqrt
float3 test_rsqrt_float3(float3 p0) { return rsqrt(p0); }
// CHECK: define noundef <4 x float> @
// CHECK: %dx.rsqrt = call <4 x float> @llvm.dx.rsqrt.v4f32
// CHECK: ret <4 x float> %dx.rsqrt
float4 test_rsqrt_float4(float4 p0) { return rsqrt(p0); }
2 changes: 1 addition & 1 deletion clang/test/SemaHLSL/BuiltIns/frac-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ float2 test_too_many_arg(float2 p0) {

float builtin_bool_to_float_type_promotion(bool p1) {
return __builtin_hlsl_elementwise_frac(p1);
// expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
// expected-error@-1 {{passing 'bool' to parameter of incompatible type 'float'}}
}

float builtin_frac_int_to_float_promotion(int p1) {
Expand Down
2 changes: 1 addition & 1 deletion clang/test/SemaHLSL/BuiltIns/rcp-errors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ float2 test_too_many_arg(float2 p0) {

float builtin_bool_to_float_type_promotion(bool p1) {
return __builtin_hlsl_elementwise_rcp(p1);
// expected-error@-1 {{1st argument must be a vector, integer or floating point type (was 'bool')}}
// expected-error@-1 {passing 'bool' to parameter of incompatible type 'float'}}
}

float builtin_rcp_int_to_float_promotion(int p1) {
Expand Down
27 changes: 27 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/rsqrt-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm -disable-llvm-passes -verify -verify-ignore-unexpected

float test_too_few_arg() {
return __builtin_hlsl_elementwise_rsqrt();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

float2 test_too_many_arg(float2 p0) {
return __builtin_hlsl_elementwise_rsqrt(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

float builtin_bool_to_float_type_promotion(bool p1) {
return __builtin_hlsl_elementwise_rsqrt(p1);
// expected-error@-1 {{passing 'bool' to parameter of incompatible type 'float'}}
}

float builtin_rsqrt_int_to_float_promotion(int p1) {
return __builtin_hlsl_elementwise_rsqrt(p1);
// expected-error@-1 {{passing 'int' to parameter of incompatible type 'float'}}
}

float2 builtin_rsqrt_int2_to_float2_promotion(int2 p1) {
return __builtin_hlsl_elementwise_rsqrt(p1);
// expected-error@-1 {{passing 'int2' (aka 'vector<int, 2>') to parameter of incompatible type '__attribute__((__vector_size__(2 * sizeof(float)))) float' (vector of 2 'float' values)}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ def int_dx_lerp :
def int_dx_imad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_umad : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_rcp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
}
4 changes: 4 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ def Frac : DXILOpMapping<22, unary, int_dx_frac,
"Returns a fraction from 0 to 1 that represents the "
"decimal part of the input.",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
def RSqrt : DXILOpMapping<25, unary, int_dx_rsqrt,
"Returns the reciprocal of the square root of the specified value."
"rsqrt(x) = 1 / sqrt(x).",
[llvm_halforfloat_ty, LLVMMatchType<0>]>;
def Round : DXILOpMapping<26, unary, int_round,
"Returns the input rounded to the nearest integer"
"within a floating-point type.",
Expand Down
28 changes: 28 additions & 0 deletions llvm/test/CodeGen/DirectX/rsqrt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
; RUN: opt -S -dxil-op-lower < %s | FileCheck %s

; Make sure dxil operation function calls for rsqrt are generated for float and half.

; CHECK-LABEL: rsqrt_float
; CHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}})
define noundef float @rsqrt_float(float noundef %a) {
entry:
%a.addr = alloca float, align 4
store float %a, ptr %a.addr, align 4
%0 = load float, ptr %a.addr, align 4
%dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %0)
ret float %dx.rsqrt
}

; CHECK-LABEL: rsqrt_half
; CHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}})
define noundef half @rsqrt_half(half noundef %a) {
entry:
%a.addr = alloca half, align 2
store half %a, ptr %a.addr, align 2
%0 = load half, ptr %a.addr, align 2
%dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %0)
ret half %dx.rsqrt
}

declare half @llvm.dx.rsqrt.f16(half)
declare float @llvm.dx.rsqrt.f32(float)
14 changes: 14 additions & 0 deletions llvm/test/CodeGen/DirectX/rsqrt_error.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s

; DXIL operation rsqrt does not support double overload type
; CHECK: LLVM ERROR: Invalid Overload Type

; Function Attrs: noinline nounwind optnone
define noundef double @rsqrt_double(double noundef %a) #0 {
entry:
%a.addr = alloca double, align 8
store double %a, ptr %a.addr, align 8
%0 = load double, ptr %a.addr, align 8
%dx.rsqrt = call double @llvm.dx.rsqrt.f64(double %0)
ret double %dx.rsqrt
}