
AMDGPU: Replace sqrt OpenCL libcalls with llvm.sqrt #74197


Merged: 1 commit merged into llvm:main from arsenm:amdgpu-libcall-sqrt on Jan 9, 2024

Conversation

arsenm (Contributor) commented Dec 2, 2023

The library implementation is just a wrapper around a call to the intrinsic, but loses metadata. Swap out the call site to the intrinsic so that the lowering can see the !fpmath metadata and fast math flags.

Since d56e0d0, clang started placing !fpmath on OpenCL library sqrt calls. Also don't bother emitting native_sqrt anymore, it's just another wrapper around llvm.sqrt.
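
The transformation, sketched on a minimal IR example (function and metadata names are illustrative; the !fpmath value is whatever clang attached, 2.5 in the default OpenCL case):

```llvm
; Before: OpenCL libcall. The wrapper body hides the fast math flags and
; !fpmath metadata from the sqrt lowering.
declare float @_Z4sqrtf(float)

define float @before(float %x) {
  %r = tail call nnan float @_Z4sqrtf(float %x), !fpmath !0
  ret float %r
}

; After: direct intrinsic call. The flags and !fpmath metadata now reach
; the backend's llvm.sqrt lowering.
declare float @llvm.sqrt.f32(float)

define float @after(float %x) {
  %r = tail call nnan float @llvm.sqrt.f32(float %x), !fpmath !0
  ret float %r
}

!0 = !{float 2.500000e+00}
```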

llvmbot (Member) commented Dec 2, 2023

@llvm/pr-subscribers-backend-amdgpu

Author: Matt Arsenault (arsenm)

Patch is 24.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/74197.diff

3 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp (+3-29)
  • (modified) llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-sqrt.ll (+38-38)
  • (modified) llvm/test/CodeGen/AMDGPU/simplify-libcalls.ll (+2-3)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index 5c66fd2b180f7..245c42af43483 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -87,9 +87,6 @@ class AMDGPULibCalls {
                               Constant *copr0, Constant *copr1);
   bool evaluateCall(CallInst *aCI, const FuncInfo &FInfo);
 
-  // sqrt
-  bool fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B, const FuncInfo &FInfo);
-
   /// Insert a value to sincos function \p Fsincos. Returns (value of sin, value
   /// of cos, sincos call).
   std::tuple<Value *, Value *, Value *> insertSinCos(Value *Arg,
@@ -673,8 +670,6 @@ bool AMDGPULibCalls::fold(CallInst *CI) {
 
     // Specialized optimizations for each function call.
     //
-    // TODO: Handle other simple intrinsic wrappers. Sqrt.
-    //
     // TODO: Handle native functions
     switch (FInfo.getId()) {
     case AMDGPULibFunc::EI_EXP:
@@ -795,7 +790,9 @@ bool AMDGPULibCalls::fold(CallInst *CI) {
     case AMDGPULibFunc::EI_ROOTN:
       return fold_rootn(FPOp, B, FInfo);
     case AMDGPULibFunc::EI_SQRT:
-      return fold_sqrt(FPOp, B, FInfo);
+      // TODO: Allow with strictfp + constrained intrinsic
+      return tryReplaceLibcallWithSimpleIntrinsic(
+          B, CI, Intrinsic::sqrt, true, true, /*AllowStrictFP=*/false);
     case AMDGPULibFunc::EI_COS:
     case AMDGPULibFunc::EI_SIN:
       return fold_sincos(FPOp, B, FInfo);
@@ -1275,29 +1272,6 @@ bool AMDGPULibCalls::tryReplaceLibcallWithSimpleIntrinsic(
   return true;
 }
 
-// fold sqrt -> native_sqrt (x)
-bool AMDGPULibCalls::fold_sqrt(FPMathOperator *FPOp, IRBuilder<> &B,
-                               const FuncInfo &FInfo) {
-  if (!isUnsafeMath(FPOp))
-    return false;
-
-  if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
-      (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
-    Module *M = B.GetInsertBlock()->getModule();
-
-    if (FunctionCallee FPExpr = getNativeFunction(
-            M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
-      Value *opr0 = FPOp->getOperand(0);
-      LLVM_DEBUG(errs() << "AMDIC: " << *FPOp << " ---> "
-                        << "sqrt(" << *opr0 << ")\n");
-      Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
-      replaceCall(FPOp, nval);
-      return true;
-    }
-  }
-  return false;
-}
-
 std::tuple<Value *, Value *, Value *>
 AMDGPULibCalls::insertSinCos(Value *Arg, FastMathFlags FMF, IRBuilder<> &B,
                              FunctionCallee Fsincos) {
diff --git a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-sqrt.ll b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-sqrt.ll
index d1a58a7a0148d..f5b6f2e170777 100644
--- a/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-sqrt.ll
+++ b/llvm/test/CodeGen/AMDGPU/amdgpu-simplify-libcall-sqrt.ll
@@ -27,7 +27,7 @@ declare <16 x half> @_Z4sqrtDv16_Dh(<16 x half>)
 define float @test_sqrt_f32(float %arg) {
 ; CHECK-LABEL: define float @test_sqrt_f32
 ; CHECK-SAME: (float [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]]), !fpmath [[META0:![0-9]+]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @llvm.sqrt.f32(float [[ARG]]), !fpmath [[META0:![0-9]+]]
 ; CHECK-NEXT:    ret float [[SQRT]]
 ;
   %sqrt = tail call float @_Z4sqrtf(float %arg), !fpmath !0
@@ -37,7 +37,7 @@ define float @test_sqrt_f32(float %arg) {
 define <2 x float> @test_sqrt_v2f32(<2 x float> %arg) {
 ; CHECK-LABEL: define <2 x float> @test_sqrt_v2f32
 ; CHECK-SAME: (<2 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]]), !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[ARG]]), !fpmath [[META0]]
 ; CHECK-NEXT:    ret <2 x float> [[SQRT]]
 ;
   %sqrt = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> %arg), !fpmath !0
@@ -47,7 +47,7 @@ define <2 x float> @test_sqrt_v2f32(<2 x float> %arg) {
 define <3 x float> @test_sqrt_v3f32(<3 x float> %arg) {
 ; CHECK-LABEL: define <3 x float> @test_sqrt_v3f32
 ; CHECK-SAME: (<3 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <3 x float> @_Z4sqrtDv3_f(<3 x float> [[ARG]]), !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <3 x float> @llvm.sqrt.v3f32(<3 x float> [[ARG]]), !fpmath [[META0]]
 ; CHECK-NEXT:    ret <3 x float> [[SQRT]]
 ;
   %sqrt = tail call <3 x float> @_Z4sqrtDv3_f(<3 x float> %arg), !fpmath !0
@@ -57,7 +57,7 @@ define <3 x float> @test_sqrt_v3f32(<3 x float> %arg) {
 define <4 x float> @test_sqrt_v4f32(<4 x float> %arg) {
 ; CHECK-LABEL: define <4 x float> @test_sqrt_v4f32
 ; CHECK-SAME: (<4 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <4 x float> @_Z4sqrtDv4_f(<4 x float> [[ARG]]), !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> [[ARG]]), !fpmath [[META0]]
 ; CHECK-NEXT:    ret <4 x float> [[SQRT]]
 ;
   %sqrt = tail call <4 x float> @_Z4sqrtDv4_f(<4 x float> %arg), !fpmath !0
@@ -67,7 +67,7 @@ define <4 x float> @test_sqrt_v4f32(<4 x float> %arg) {
 define <8 x float> @test_sqrt_v8f32(<8 x float> %arg) {
 ; CHECK-LABEL: define <8 x float> @test_sqrt_v8f32
 ; CHECK-SAME: (<8 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <8 x float> @_Z4sqrtDv8_f(<8 x float> [[ARG]]), !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> [[ARG]]), !fpmath [[META0]]
 ; CHECK-NEXT:    ret <8 x float> [[SQRT]]
 ;
   %sqrt = tail call <8 x float> @_Z4sqrtDv8_f(<8 x float> %arg), !fpmath !0
@@ -77,7 +77,7 @@ define <8 x float> @test_sqrt_v8f32(<8 x float> %arg) {
 define <16 x float> @test_sqrt_v16f32(<16 x float> %arg) {
 ; CHECK-LABEL: define <16 x float> @test_sqrt_v16f32
 ; CHECK-SAME: (<16 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <16 x float> @_Z4sqrtDv16_f(<16 x float> [[ARG]]), !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <16 x float> @llvm.sqrt.v16f32(<16 x float> [[ARG]]), !fpmath [[META0]]
 ; CHECK-NEXT:    ret <16 x float> [[SQRT]]
 ;
   %sqrt = tail call <16 x float> @_Z4sqrtDv16_f(<16 x float> %arg), !fpmath !0
@@ -87,7 +87,7 @@ define <16 x float> @test_sqrt_v16f32(<16 x float> %arg) {
 define float @test_sqrt_cr_f32(float %arg) {
 ; CHECK-LABEL: define float @test_sqrt_cr_f32
 ; CHECK-SAME: (float [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @llvm.sqrt.f32(float [[ARG]])
 ; CHECK-NEXT:    ret float [[SQRT]]
 ;
   %sqrt = tail call float @_Z4sqrtf(float %arg)
@@ -97,7 +97,7 @@ define float @test_sqrt_cr_f32(float %arg) {
 define <2 x float> @test_sqrt_cr_v2f32(<2 x float> %arg) {
 ; CHECK-LABEL: define <2 x float> @test_sqrt_cr_v2f32
 ; CHECK-SAME: (<2 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @llvm.sqrt.v2f32(<2 x float> [[ARG]])
 ; CHECK-NEXT:    ret <2 x float> [[SQRT]]
 ;
   %sqrt = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> %arg)
@@ -107,7 +107,7 @@ define <2 x float> @test_sqrt_cr_v2f32(<2 x float> %arg) {
 define <3 x float> @test_sqrt_cr_v3f32(<3 x float> %arg) {
 ; CHECK-LABEL: define <3 x float> @test_sqrt_cr_v3f32
 ; CHECK-SAME: (<3 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <3 x float> @_Z4sqrtDv3_f(<3 x float> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <3 x float> @llvm.sqrt.v3f32(<3 x float> [[ARG]])
 ; CHECK-NEXT:    ret <3 x float> [[SQRT]]
 ;
   %sqrt = tail call <3 x float> @_Z4sqrtDv3_f(<3 x float> %arg)
@@ -117,7 +117,7 @@ define <3 x float> @test_sqrt_cr_v3f32(<3 x float> %arg) {
 define <4 x float> @test_sqrt_cr_v4f32(<4 x float> %arg) {
 ; CHECK-LABEL: define <4 x float> @test_sqrt_cr_v4f32
 ; CHECK-SAME: (<4 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <4 x float> @_Z4sqrtDv4_f(<4 x float> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> [[ARG]])
 ; CHECK-NEXT:    ret <4 x float> [[SQRT]]
 ;
   %sqrt = tail call <4 x float> @_Z4sqrtDv4_f(<4 x float> %arg)
@@ -127,7 +127,7 @@ define <4 x float> @test_sqrt_cr_v4f32(<4 x float> %arg) {
 define <8 x float> @test_sqrt_cr_v8f32(<8 x float> %arg) {
 ; CHECK-LABEL: define <8 x float> @test_sqrt_cr_v8f32
 ; CHECK-SAME: (<8 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <8 x float> @_Z4sqrtDv8_f(<8 x float> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> [[ARG]])
 ; CHECK-NEXT:    ret <8 x float> [[SQRT]]
 ;
   %sqrt = tail call <8 x float> @_Z4sqrtDv8_f(<8 x float> %arg)
@@ -137,7 +137,7 @@ define <8 x float> @test_sqrt_cr_v8f32(<8 x float> %arg) {
 define <16 x float> @test_sqrt_cr_v16f32(<16 x float> %arg) {
 ; CHECK-LABEL: define <16 x float> @test_sqrt_cr_v16f32
 ; CHECK-SAME: (<16 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <16 x float> @_Z4sqrtDv16_f(<16 x float> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <16 x float> @llvm.sqrt.v16f32(<16 x float> [[ARG]])
 ; CHECK-NEXT:    ret <16 x float> [[SQRT]]
 ;
   %sqrt = tail call <16 x float> @_Z4sqrtDv16_f(<16 x float> %arg)
@@ -147,7 +147,7 @@ define <16 x float> @test_sqrt_cr_v16f32(<16 x float> %arg) {
 define double @test_sqrt_f64(double %arg) {
 ; CHECK-LABEL: define double @test_sqrt_f64
 ; CHECK-SAME: (double [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call double @_Z4sqrtd(double [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call double @llvm.sqrt.f64(double [[ARG]])
 ; CHECK-NEXT:    ret double [[SQRT]]
 ;
   %sqrt = tail call double @_Z4sqrtd(double %arg)
@@ -157,7 +157,7 @@ define double @test_sqrt_f64(double %arg) {
 define <2 x double> @test_sqrt_v2f64(<2 x double> %arg) {
 ; CHECK-LABEL: define <2 x double> @test_sqrt_v2f64
 ; CHECK-SAME: (<2 x double> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x double> @_Z4sqrtDv2_d(<2 x double> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x double> @llvm.sqrt.v2f64(<2 x double> [[ARG]])
 ; CHECK-NEXT:    ret <2 x double> [[SQRT]]
 ;
   %sqrt = tail call <2 x double> @_Z4sqrtDv2_d(<2 x double> %arg)
@@ -167,7 +167,7 @@ define <2 x double> @test_sqrt_v2f64(<2 x double> %arg) {
 define <3 x double> @test_sqrt_v3f64(<3 x double> %arg) {
 ; CHECK-LABEL: define <3 x double> @test_sqrt_v3f64
 ; CHECK-SAME: (<3 x double> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <3 x double> @_Z4sqrtDv3_d(<3 x double> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <3 x double> @llvm.sqrt.v3f64(<3 x double> [[ARG]])
 ; CHECK-NEXT:    ret <3 x double> [[SQRT]]
 ;
   %sqrt = tail call <3 x double> @_Z4sqrtDv3_d(<3 x double> %arg)
@@ -177,7 +177,7 @@ define <3 x double> @test_sqrt_v3f64(<3 x double> %arg) {
 define <4 x double> @test_sqrt_v4f64(<4 x double> %arg) {
 ; CHECK-LABEL: define <4 x double> @test_sqrt_v4f64
 ; CHECK-SAME: (<4 x double> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <4 x double> @_Z4sqrtDv4_d(<4 x double> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <4 x double> @llvm.sqrt.v4f64(<4 x double> [[ARG]])
 ; CHECK-NEXT:    ret <4 x double> [[SQRT]]
 ;
   %sqrt = tail call <4 x double> @_Z4sqrtDv4_d(<4 x double> %arg)
@@ -187,7 +187,7 @@ define <4 x double> @test_sqrt_v4f64(<4 x double> %arg) {
 define <8 x double> @test_sqrt_v8f64(<8 x double> %arg) {
 ; CHECK-LABEL: define <8 x double> @test_sqrt_v8f64
 ; CHECK-SAME: (<8 x double> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <8 x double> @_Z4sqrtDv8_d(<8 x double> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <8 x double> @llvm.sqrt.v8f64(<8 x double> [[ARG]])
 ; CHECK-NEXT:    ret <8 x double> [[SQRT]]
 ;
   %sqrt = tail call <8 x double> @_Z4sqrtDv8_d(<8 x double> %arg)
@@ -197,7 +197,7 @@ define <8 x double> @test_sqrt_v8f64(<8 x double> %arg) {
 define <16 x double> @test_sqrt_v16f64(<16 x double> %arg) {
 ; CHECK-LABEL: define <16 x double> @test_sqrt_v16f64
 ; CHECK-SAME: (<16 x double> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <16 x double> @_Z4sqrtDv16_d(<16 x double> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <16 x double> @llvm.sqrt.v16f64(<16 x double> [[ARG]])
 ; CHECK-NEXT:    ret <16 x double> [[SQRT]]
 ;
   %sqrt = tail call <16 x double> @_Z4sqrtDv16_d(<16 x double> %arg)
@@ -207,7 +207,7 @@ define <16 x double> @test_sqrt_v16f64(<16 x double> %arg) {
 define half @test_sqrt_f16(half %arg) {
 ; CHECK-LABEL: define half @test_sqrt_f16
 ; CHECK-SAME: (half [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call half @_Z4sqrtDh(half [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call half @llvm.sqrt.f16(half [[ARG]])
 ; CHECK-NEXT:    ret half [[SQRT]]
 ;
   %sqrt = tail call half @_Z4sqrtDh(half %arg)
@@ -217,7 +217,7 @@ define half @test_sqrt_f16(half %arg) {
 define <2 x half> @test_sqrt_v2f16(<2 x half> %arg) {
 ; CHECK-LABEL: define <2 x half> @test_sqrt_v2f16
 ; CHECK-SAME: (<2 x half> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x half> @_Z4sqrtDv2_Dh(<2 x half> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x half> @llvm.sqrt.v2f16(<2 x half> [[ARG]])
 ; CHECK-NEXT:    ret <2 x half> [[SQRT]]
 ;
   %sqrt = tail call <2 x half> @_Z4sqrtDv2_Dh(<2 x half> %arg)
@@ -227,7 +227,7 @@ define <2 x half> @test_sqrt_v2f16(<2 x half> %arg) {
 define <3 x half> @test_sqrt_v3f16(<3 x half> %arg) {
 ; CHECK-LABEL: define <3 x half> @test_sqrt_v3f16
 ; CHECK-SAME: (<3 x half> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <3 x half> @_Z4sqrtDv3_Dh(<3 x half> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <3 x half> @llvm.sqrt.v3f16(<3 x half> [[ARG]])
 ; CHECK-NEXT:    ret <3 x half> [[SQRT]]
 ;
   %sqrt = tail call <3 x half> @_Z4sqrtDv3_Dh(<3 x half> %arg)
@@ -237,7 +237,7 @@ define <3 x half> @test_sqrt_v3f16(<3 x half> %arg) {
 define <4 x half> @test_sqrt_v4f16(<4 x half> %arg) {
 ; CHECK-LABEL: define <4 x half> @test_sqrt_v4f16
 ; CHECK-SAME: (<4 x half> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <4 x half> @_Z4sqrtDv4_Dh(<4 x half> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <4 x half> @llvm.sqrt.v4f16(<4 x half> [[ARG]])
 ; CHECK-NEXT:    ret <4 x half> [[SQRT]]
 ;
   %sqrt = tail call <4 x half> @_Z4sqrtDv4_Dh(<4 x half> %arg)
@@ -247,7 +247,7 @@ define <4 x half> @test_sqrt_v4f16(<4 x half> %arg) {
 define <8 x half> @test_sqrt_v8f16(<8 x half> %arg) {
 ; CHECK-LABEL: define <8 x half> @test_sqrt_v8f16
 ; CHECK-SAME: (<8 x half> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <8 x half> @_Z4sqrtDv8_Dh(<8 x half> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <8 x half> @llvm.sqrt.v8f16(<8 x half> [[ARG]])
 ; CHECK-NEXT:    ret <8 x half> [[SQRT]]
 ;
   %sqrt = tail call <8 x half> @_Z4sqrtDv8_Dh(<8 x half> %arg)
@@ -257,7 +257,7 @@ define <8 x half> @test_sqrt_v8f16(<8 x half> %arg) {
 define <16 x half> @test_sqrt_v16f16(<16 x half> %arg) {
 ; CHECK-LABEL: define <16 x half> @test_sqrt_v16f16
 ; CHECK-SAME: (<16 x half> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <16 x half> @_Z4sqrtDv16_Dh(<16 x half> [[ARG]])
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <16 x half> @llvm.sqrt.v16f16(<16 x half> [[ARG]])
 ; CHECK-NEXT:    ret <16 x half> [[SQRT]]
 ;
   %sqrt = tail call <16 x half> @_Z4sqrtDv16_Dh(<16 x half> %arg)
@@ -267,7 +267,7 @@ define <16 x half> @test_sqrt_v16f16(<16 x half> %arg) {
 define float @test_sqrt_f32_nobuiltin_callsite(float %arg) {
 ; CHECK-LABEL: define float @test_sqrt_f32_nobuiltin_callsite
 ; CHECK-SAME: (float [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]]) #[[ATTR2:[0-9]+]], !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]]) #[[ATTR3:[0-9]+]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret float [[SQRT]]
 ;
   %sqrt = tail call float @_Z4sqrtf(float %arg) #0, !fpmath !0
@@ -277,7 +277,7 @@ define float @test_sqrt_f32_nobuiltin_callsite(float %arg) {
 define <2 x float> @test_sqrt_v2f32_nobuiltin_callsite(<2 x float> %arg) {
 ; CHECK-LABEL: define <2 x float> @test_sqrt_v2f32_nobuiltin_callsite
 ; CHECK-SAME: (<2 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]]) #[[ATTR2]], !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]]) #[[ATTR3]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret <2 x float> [[SQRT]]
 ;
   %sqrt = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> %arg) #0, !fpmath !0
@@ -287,7 +287,7 @@ define <2 x float> @test_sqrt_v2f32_nobuiltin_callsite(<2 x float> %arg) {
 define float @test_sqrt_cr_f32_nobuiltin_callsite(float %arg) {
 ; CHECK-LABEL: define float @test_sqrt_cr_f32_nobuiltin_callsite
 ; CHECK-SAME: (float [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]]) #[[ATTR2]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]]) #[[ATTR3]]
 ; CHECK-NEXT:    ret float [[SQRT]]
 ;
   %sqrt = tail call float @_Z4sqrtf(float %arg) #0
@@ -297,7 +297,7 @@ define float @test_sqrt_cr_f32_nobuiltin_callsite(float %arg) {
 define <2 x float> @test_sqrt_cr_v2f32_nobuiltin_callsite(<2 x float> %arg) {
 ; CHECK-LABEL: define <2 x float> @test_sqrt_cr_v2f32_nobuiltin_callsite
 ; CHECK-SAME: (<2 x float> [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]]) #[[ATTR2]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]]) #[[ATTR3]]
 ; CHECK-NEXT:    ret <2 x float> [[SQRT]]
 ;
   %sqrt = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> %arg) #0
@@ -308,7 +308,7 @@ define <2 x float> @test_sqrt_cr_v2f32_nobuiltin_callsite(<2 x float> %arg) {
 define float @test_sqrt_f32_nobuiltins(float %arg) #1 {
 ; CHECK-LABEL: define float @test_sqrt_f32_nobuiltins
 ; CHECK-SAME: (float [[ARG:%.*]]) #[[ATTR0:[0-9]+]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]]) #[[ATTR2]], !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]]) #[[ATTR3]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret float [[SQRT]]
 ;
   %sqrt = tail call float @_Z4sqrtf(float %arg) #0, !fpmath !0
@@ -318,7 +318,7 @@ define float @test_sqrt_f32_nobuiltins(float %arg) #1 {
 define <2 x float> @test_sqrt_v2f32_nobuiltins(<2 x float> %arg) #1 {
 ; CHECK-LABEL: define <2 x float> @test_sqrt_v2f32_nobuiltins
 ; CHECK-SAME: (<2 x float> [[ARG:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]]) #[[ATTR2]], !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]]) #[[ATTR3]], !fpmath [[META0]]
 ; CHECK-NEXT:    ret <2 x float> [[SQRT]]
 ;
   %sqrt = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> %arg) #0, !fpmath !0
@@ -328,7 +328,7 @@ define <2 x float> @test_sqrt_v2f32_nobuiltins(<2 x float> %arg) #1 {
 define float @test_sqrt_cr_f32_nobuiltins(float %arg) #1 {
 ; CHECK-LABEL: define float @test_sqrt_cr_f32_nobuiltins
 ; CHECK-SAME: (float [[ARG:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]]) #[[ATTR2]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call float @_Z4sqrtf(float [[ARG]]) #[[ATTR3]]
 ; CHECK-NEXT:    ret float [[SQRT]]
 ;
   %sqrt = tail call float @_Z4sqrtf(float %arg) #0
@@ -338,7 +338,7 @@ define float @test_sqrt_cr_f32_nobuiltins(float %arg) #1 {
 define <2 x float> @test_sqrt_cr_v2f32_nobuiltins(<2 x float> %arg) #1 {
 ; CHECK-LABEL: define <2 x float> @test_sqrt_cr_v2f32_nobuiltins
 ; CHECK-SAME: (<2 x float> [[ARG:%.*]]) #[[ATTR0]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]]) #[[ATTR2]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> [[ARG]]) #[[ATTR3]]
 ; CHECK-NEXT:    ret <2 x float> [[SQRT]]
 ;
   %sqrt = tail call <2 x float> @_Z4sqrtDv2_f(<2 x float> %arg) #0
@@ -348,7 +348,7 @@ define <2 x float> @test_sqrt_cr_v2f32_nobuiltins(<2 x float> %arg) #1 {
 define float @test_sqrt_f32_preserve_flags(float %arg) {
 ; CHECK-LABEL: define float @test_sqrt_f32_preserve_flags
 ; CHECK-SAME: (float [[ARG:%.*]]) {
-; CHECK-NEXT:    [[SQRT:%.*]] = tail call nnan ninf float @_Z4sqrtf(float [[ARG]]), !fpmath [[META0]]
+; CHECK-NEXT:    [[SQRT:%.*]] = tail...
[truncated]

rampitec (Collaborator) left a comment


Any codegen changes except IR? It is hard to reason looking at IR only.

arsenm (Author) commented Dec 12, 2023

> Any codegen changes except IR? It is hard to reason looking at IR only.

The IR change is the point. The only meaningful codegen test would be an end-to-end test that also links the libraries in. There's nothing to really test for codegen, other than that this can now be directly codegened into a working program.

rampitec (Collaborator) commented

> Any codegen changes except IR? It is hard to reason looking at IR only.

> The IR change is the point. The only meaningful codegen tests would be an end-to-end test which also links the libraries in. There's nothing to really test for codegen, other than this can now be directly codegened into a working program.

But do we have such tests which did not change? Can you identify them? I feel like this pass was ported from HSAIL without tests, which is explainable, because the tests were HSAIL tests. I mean, you made the change, so you have probably looked at the ISA to check that it did not become worse. What is that ISA? What is the problem being solved here?

arsenm (Author) commented Dec 12, 2023

> Any codegen changes except IR? It is hard to reason looking at IR only.

> The IR change is the point. The only meaningful codegen tests would be an end-to-end test which also links the libraries in. There's nothing to really test for codegen, other than this can now be directly codegened into a working program.

> But do we have these tests which did not change? Can you identify them? I feel like this pass was ported from HSAIL without tests, which is explainable, because the tests were HSAIL tests. I mean, you did the change, you probably have looked into the ISA which did not became worse. What is that ISA? What is the problem solved here?

The library implementation loses the fast math information, in the form of the fast math flags and !fpmath metadata. If the call is converted to the intrinsic, the backend can directly swap out the implementation. This will also help drop the ugly unsafe-math control library.

rampitec (Collaborator) commented

> Any codegen changes except IR? It is hard to reason looking at IR only.

> The IR change is the point. The only meaningful codegen tests would be an end-to-end test which also links the libraries in. There's nothing to really test for codegen, other than this can now be directly codegened into a working program.

> But do we have these tests which did not change? Can you identify them? I feel like this pass was ported from HSAIL without tests, which is explainable, because the tests were HSAIL tests. I mean, you did the change, you probably have looked into the ISA which did not became worse. What is that ISA? What is the problem solved here?

> The library implementation loses the fast math information, in forms of the fast math flags and !fpmath metadata. If it's converted to the intrinsic, the backend can directly swap out the implementation. This will help drop the ugly unsafe math control library

I agree. I just want to see an impact.

@rampitec rampitec requested a review from b-sumner December 12, 2023 10:29
b-sumner commented
I think this is creating risk, but I am not prepared to oppose it.

rampitec (Collaborator) commented

I have checked our current llvm.sqrt.f32 lowering. To get the same result as native_sqrt would give, the call either needs to have the afn flag, or !fpmath metadata requesting 2ulp or lower accuracy has to be attached to it.

On the other hand, the current folding is done if either the 'fast' flag is set on the call or the "unsafe-fp-math" attribute is set on the caller function. So the question is: will the conditions from the first list be satisfied whenever any one of the conditions from the second list is met? I.e., is there potential for a regression?

For instance, I do not see checks for 'call fast float @llvm.sqrt.f32' in fsqrt.f32.ll.
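
As a sketch, the two conditions described above correspond to these call forms (illustrative only; whether each actually lowers to the native op also depends on the target and denormal mode):

```llvm
; Candidate for the fast native lowering via the afn fast-math flag:
define float @sqrt_afn(float %x) {
  %r = call afn float @llvm.sqrt.f32(float %x)
  ret float %r
}

; Candidate via !fpmath metadata requesting 2ulp accuracy:
define float @sqrt_fpmath(float %x) {
  %r = call float @llvm.sqrt.f32(float %x), !fpmath !0
  ret float %r
}

declare float @llvm.sqrt.f32(float)

!0 = !{float 2.000000e+00}
```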

rampitec (Collaborator) commented

Speaking of the "unsafe-fp-math" attribute, clang will produce it for the following options:

  • -cl-fast-relaxed-math
  • -cl-unsafe-math-optimizations
  • -ffast-math
  • -funsafe-math-optimizations AND -ffp-contract=fast (but not with -ffp-contract=on or off)

Then, if we are not relying on the attribute anymore, either afn or !fpmath >= 2.0 shall be set on the call site for the same list of options. I am not sure if and why a call to @_Z4sqrtf would have any of these.

rampitec (Collaborator) commented

With fast on a call site it gives the expected results; it would be nice to add this to fsqrt.f32.ll:

```llvm
define float @v_sqrt_f32_fast(float %x) {
  %result = call fast float @llvm.sqrt.f32(float %x)
  ret float %result
}

declare float @llvm.sqrt.f32(float)
```

A test with the unsafe-fp-math attribute is present there (v_sqrt_f32__unsafe_attr), but it needs a fix. This gives the expected results though:

```llvm
define float @v_sqrt_f32__unsafe_attr(float %x) #4 {
  %result = call float @llvm.sqrt.f32(float %x)
  ret float %result
}

declare float @llvm.sqrt.f32(float)

attributes #4 = { "unsafe-fp-math"="true" }
```

The difference is that in fsqrt.f32.ll the call has the nsz flag:

```llvm
define float @v_sqrt_f32__unsafe_attr(float %x) #4 {
; GCN-LABEL: v_sqrt_f32__unsafe_attr:
; GCN:       ; %bb.0:
; GCN-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GCN-NEXT:    v_sqrt_f32_e32 v0, v0
; GCN-NEXT:    s_setpc_b64 s[30:31]
  %result = call nsz float @llvm.sqrt.f32(float %x)
  ret float %result
}
```

So I suggest adding the first test (with fast) and removing nsz in the second. The rest seems OK.

arsenm (Author) commented Dec 13, 2023

> I have checked our current llvm.sqrt.f32 lowering. To get the same result as native_sqrt would give either a call needs to have afn attribute, or fpmath metadata has to be attached to the call requesting 2ulp or lower accuracy.

That also depends on the denormal mode. All of the llvm.sqrt.f32 lowering combinations are already handled.

> On the other hand current folding is done if either 'fast' flag is set on the call or "unsafe-fp-math" attribute is set on a caller function.

The current folding is useless. All it does is replace sqrt with native_sqrt calls, and only if you use the off-by-default amdgpu-use-native flag. It's basically dead code; nothing uses the flag.

> So the question is: will conditions from the first list be satisfied if any one of the conditions from the second list is met? I.e. does it have a potential for regression?

No. AMDGPULibCalls was never as refined as it should be. unsafe-fp-math is deprecated-ish and doesn't matter; it's functionally an alias for an uncertain set of the other attributes.

> For instance I do not see checks for 'call fast float @llvm.sqrt.f32' in the fsqrt.f32.ll.

fast is just the union of all the flags, so it's redundant. There's no real point in testing the effect of extra flags; it's just a much less refined variant of the tests that are already there.

arsenm (Author) commented Dec 13, 2023

> So I suggest adding the first test (with fast) and removing nsz in the second. The rest seems OK.

nsz doesn't do anything for sqrt lowering. Signed zero works correctly in the instruction. We have an excess of messy and overlapping sqrt tests as is; I don't think we need to add any new ones. Adding tests with fast will just add to the mess. We should drop the nsz, and possibly anything using the unsafe-fp-math attribute.

rampitec (Collaborator) commented Dec 15, 2023

> nsz doesn't do anything for sqrt lowering.

I know it does not. Why not remove it?

I will feel much more confident if this one is cleaned of the nsz nonsense and an extra 'fast' test is added, because who knows what fast will mean tomorrow? Anyway, formally, to have no regressions fast must resolve to a native op, so this test is formally needed.

Moreover, I wish it to be in this change, so that 10 years from now git blame will point to this discussion and the linked tests.

rampitec (Collaborator) commented

> So I suggest adding the first test (with fast) and removing nsz in the second. The rest seems OK.

> nsz doesn't do anything for sqrt lowering. Signed zero works correctly in the instruction. We have an excess of messy and overlapping sqrt tests as is, I don't think we need to add any new ones. Adding tests with "fast" will just add to the mess. We should drop the nsz and possibly any using the unsafe-fp-math attributes

And if you want to drop the attribute handling, that is a completely different change and a different discussion. Your claim here is no regressions.

arsenm (Author) commented Jan 4, 2024

> Then if we are not relying on the attribute anymore, either afn or !fpmath >= 2.0 shall be set on a call site with the same list of options. I am not sure if and why a call to @_Z4sqrtf would have any of these.

OpenCL by default (i.e., without -cl-fp32-correctly-rounded-divide-sqrt) emits !fpmath 2.5 on these; otherwise it's missing. afn is set by -fapprox-func, and by any of the superset/alias flags like -cl-unsafe-math-optimizations, -cl-fast-relaxed-math, and -ffast-math.
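
For illustration, a sketch of the IR a default OpenCL compile produces for such a call (values are from the discussion above; the exact output depends on the clang version and options):

```llvm
; Default OpenCL: libcall carries the 2.5ulp !fpmath bound.
%r1 = call float @_Z4sqrtf(float %x), !fpmath !0

; With -cl-fast-relaxed-math / -ffast-math, the call additionally
; carries fast-math flags (including afn):
%r2 = call fast float @_Z4sqrtf(float %x), !fpmath !0

!0 = !{float 2.500000e+00}
```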

> nsz doesn't do anything for sqrt lowering.

> I know it does not. Why not remove it?

Because it almost does something, but isn't quite enough to actually simplify the f64 lowering case. This way it's captured if somebody attempts the same thing later.

> I will feel much more confident if this one is fixed from nsz nonsense and extra 'fast' test added, because who knows what fast will mean tomorrow?

fast can only add more flags. We only need to test the minimum that we need today. fast just makes the intention of the test less clear.

> Anyway, formally to have no regressions fast must resolve to a native op, so this test is formally needed.

The point of this change is that the libcall is a pointless wrapper. There is no codegen change. Fast or not is irrelevant; we're swapping out any recognized sqrt call for the raw intrinsic, excluding the odd cases where we should not specially recognize the call.

All of the lowering cases are already tested. Any codegen logic is already captured in the variety of large sqrt lowering tests we already have. It's simply not useful to codegen any of these libcall-recognition tests. An end-to-end test, which does include the library linking, would be useful but belongs in the device-libs tests.

> Moreover, I wish it to be this change, so that 10 years from now git blame will point to this discussion and linked tests.

That's in all of the sqrt lowering patches.

@arsenm arsenm merged commit daecc30 into llvm:main Jan 9, 2024
@arsenm arsenm deleted the amdgpu-libcall-sqrt branch January 9, 2024 08:14
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024