Skip to content

[SPIRV][HLSL] Add lowering of rsqrt to SPIRV #95849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 18, 2024
Merged

Conversation

hekota
Copy link
Member

@hekota hekota commented Jun 17, 2024

Add lowering of rsqrt to SPIRV.

Fixes #88949

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support backend:SPIR-V llvm:ir labels Jun 17, 2024
@llvmbot
Copy link
Member

llvmbot commented Jun 17, 2024

@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-backend-spir-v
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-hlsl

Author: Helena Kotas (hekota)

Changes

Add lowering of rsqrt to SPIRV.

Fixes #88949


Full diff: https://github.com/llvm/llvm-project/pull/95849.diff

5 Files Affected:

  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+1-1)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+22)
  • (added) llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll (+29)
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 511e1fd4016d7..3c233eb3f2dbf 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18331,7 +18331,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
       llvm_unreachable("rsqrt operand must have a float representation");
     return Builder.CreateIntrinsic(
-        /*ReturnType=*/Op0->getType(), Intrinsic::dx_rsqrt,
+        /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(),
         ArrayRef<Value *>{Op0}, nullptr, "dx.rsqrt");
   }
   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 0abe39dedcb96..4036ce711bea1 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -75,6 +75,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Any, any)
   GENERATE_HLSL_INTRINSIC_FUNCTION(Lerp, lerp)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(Rsqrt, rsqrt)
   GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)
 
   //===----------------------------------------------------------------------===//
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 90f12674d0470..683acf4a6ffa9 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -62,4 +62,5 @@ let TargetPrefix = "spv" in {
   def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
   def int_spv_lerp : Intrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>,LLVMMatchType<0>], 
     [IntrNoMem, IntrWillReturn] >;
+  def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]>;
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index db83172f7fa9c..b9e5569029cfd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -173,6 +173,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   bool selectFmix(Register ResVReg, const SPIRVType *ResType,
                   MachineInstr &I) const;
 
+  bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
+                   MachineInstr &I) const;
+
   void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
                    int OpIdx) const;
   void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
@@ -1315,6 +1318,23 @@ bool SPIRVInstructionSelector::selectFmix(Register ResVReg,
       .constrainAllUses(TII, TRI, RBI);
 }
 
+bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
+                                           const SPIRVType *ResType,
+                                           MachineInstr &I) const {
+
+  assert(I.getNumOperands() == 3);
+  assert(I.getOperand(2).isReg());
+  MachineBasicBlock &BB = *I.getParent();
+
+  return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpExtInst))
+      .addDef(ResVReg)
+      .addUse(GR.getSPIRVTypeID(ResType))
+      .addImm(static_cast<uint32_t>(SPIRV::InstructionSet::GLSL_std_450))
+      .addImm(GL::InverseSqrt)
+      .addUse(I.getOperand(2).getReg())
+      .constrainAllUses(TII, TRI, RBI);
+}
+
 bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
                                                 const SPIRVType *ResType,
                                                 MachineInstr &I) const {
@@ -1992,6 +2012,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
     return selectAny(ResVReg, ResType, I);
   case Intrinsic::spv_lerp:
     return selectFmix(ResVReg, ResType, I);
+  case Intrinsic::spv_rsqrt:
+    return selectRsqrt(ResVReg, ResType, I);
   case Intrinsic::spv_lifetime_start:
   case Intrinsic::spv_lifetime_end: {
     unsigned Op = IID == Intrinsic::spv_lifetime_start ? SPIRV::OpLifetimeStart
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll
new file mode 100644
index 0000000000000..1541a5715b952
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/rsqrt.ll
@@ -0,0 +1,29 @@
+; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpExtInstImport "GLSL.std.450"
+
+define noundef float @rsqrt_float(float noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]]
+  %elt.rsqrt = call float @llvm.spv.rsqrt.f32(float %a)
+  ret float %elt.rsqrt
+}
+
+define noundef half @rsqrt_half(half noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]]
+  %elt.rsqrt = call half @llvm.spv.rsqrt.f16(half %a)
+  ret half %elt.rsqrt
+}
+
+define noundef double @rsqrt_double(double noundef %a) {
+entry:
+; CHECK: %[[#]] = OpExtInst %[[#]] %[[#]] InverseSqrt %[[#]]
+  %elt.rsqrt = call double @llvm.spv.rsqrt.f64(double %a)
+  ret double %elt.rsqrt
+}
+
+declare half @llvm.spv.sqrt.f16(half)
+declare float @llvm.spv.sqrt.f32(float)
+declare float @llvm.spv.sqrt.f64(float)

@hekota hekota requested a review from farzonl June 17, 2024 21:59
Copy link
Member

@farzonl farzonl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, code looks right, just a few more tests needed.

Rename intrinsic from dx.rsqrt to to hlsl.rsqrt
Copy link
Contributor

@Keenuts Keenuts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@hekota hekota merged commit 35a2b60 into llvm:main Jun 18, 2024
8 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:SPIR-V clang:codegen IR generation bugs: mangling, exceptions, etc. clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

[SPIRV][HLSL] implement rsqrt lowering
4 participants