-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SPIRV][HLSL] Add lowering of rsqrt
to SPIRV
#95849
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-ir @llvm/pr-subscribers-hlsl Author: Helena Kotas (hekota) ChangesAdd lowering of Fixes #88949 Full diff: https://github.com/llvm/llvm-project/pull/95849.diff 5 Files Affected:
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 511e1fd4016d7..3c233eb3f2dbf 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18331,7 +18331,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
if (!E->getArg(0)->getType()->hasFloatingRepresentation())
llvm_unreachable("rsqrt operand must have a float representation");
return Builder.CreateIntrinsic(
- /*ReturnType=*/Op0->getType(), Intrinsic::dx_rsqrt,
+ /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(),
ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
}
case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 0abe39dedcb96..4036ce711bea1 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -75,6 +75,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
//===----------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 90f12674d0470..683acf4a6ffa9 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -62,4 +62,5 @@ let TargetPrefix = "spv" in {
def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>],
[IntrNoMem, IntrWillReturn] >;
+ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
}
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index db83172f7fa9c..b9e5569029cfd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -173,6 +173,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectFmix(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
int OpIdx) const;
void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -1315,6 +1318,23 @@ bool SPIRVInstructionSelector::selectFmix(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}
+bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
+
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+ .addImm(GL::InverseSqrt)
+ .addUse(I.getOperand(2).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
@@ -1992,6 +2012,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectAny(ResVReg, ResType, I);
case Intrinsic::spv_lerp:
return selectFmix(ResVReg, ResType, I);
+ case Intrinsic::spv_rsqrt:
+ return selectRsqrt(ResVReg, ResType, I);
case Intrinsic::spv_lifetime_start:
case Intrinsic::spv_lifetime_end: {
unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll
new file mode 100644
index 0000000000000..1541a5715b952
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll
@@ -0,0 +1,29 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpExtInstImport "GLSL.std.450"
+
+define noundef float @rsqrt_float(float noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]]
+ %elt.rsqrt = call float @llvm.spv.rsqrt.f32(float %a)
+ ret float %elt.rsqrt
+}
+
+define noundef half @rsqrt_half(half noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]]
+ %elt.rsqrt = call half @llvm.spv.rsqrt.f16(half %a)
+ ret half %elt.rsqrt
+}
+
+define noundef double @rsqrt_double(double noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]]
+ %elt.rsqrt = call double @llvm.spv.rsqrt.f64(double %a)
+ ret double %elt.rsqrt
+}
+
+declare half @llvm.spv.sqrt.f16(half)
+declare float @llvm.spv.sqrt.f32(float)
+declare float @llvm.spv.sqrt.f64(float)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, code looks right, just a few more tests needed.
Rename intrinsic from dx.rsqrt to to hlsl.rsqrt
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Add lowering of `rsqrt` to SPIRV. Fixes llvm#88949
Add lowering of
rsqrt
to SPIRV.Fixes #88949