[flang][AMDGPU] Convert math ops to AMD GPU library calls instead of libm calls #99517


Merged
merged 2 commits into llvm:main on Sep 10, 2024

Conversation

@jsjodin (Contributor) commented Jul 18, 2024

This patch invokes a pass when compiling for an AMDGPU target to lower math operations to AMD GPU library calls instead of libm calls.

@llvmbot (Member) commented Jul 18, 2024

@llvm/pr-subscribers-flang-codegen
@llvm/pr-subscribers-flang-openmp

@llvm/pr-subscribers-backend-amdgpu

Author: Jan Leyonberg (jsjodin)

Changes

This patch invokes a pass when compiling for an AMDGPU target to lower math operations to AMD GPU library calls instead of libm calls.


Full diff: https://github.com/llvm/llvm-project/pull/99517.diff

3 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CMakeLists.txt (+1)
  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+11-1)
  • (added) flang/test/Lower/OpenMP/math-amdgpu.f90 (+184)
diff --git a/flang/lib/Optimizer/CodeGen/CMakeLists.txt b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
index 650448eee1099..646621cb01c15 100644
--- a/flang/lib/Optimizer/CodeGen/CMakeLists.txt
+++ b/flang/lib/Optimizer/CodeGen/CMakeLists.txt
@@ -26,6 +26,7 @@ add_flang_library(FIRCodeGen
   MLIRMathToFuncs
   MLIRMathToLLVM
   MLIRMathToLibm
+  MLIRMathToROCDL
   MLIROpenMPToLLVM
   MLIROpenACCDialect
   MLIRBuiltinToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index f9ea92a843b23..02992857dde06 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -34,6 +34,7 @@
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -3610,6 +3611,14 @@ class FIRToLLVMLowering
     // as passes here.
     mlir::OpPassManager mathConvertionPM("builtin.module");
 
+    bool isAMDGCN = fir::getTargetTriple(mod).isAMDGCN();
+    // If compiling for AMD target some math operations must be lowered to AMD
+    // GPU library calls, the rest can be converted to LLVM intrinsics, which
+    // is handled in the mathToLLVM conversion. The lowering to libm calls is
+    // not needed since all math operations are handled this way.
+    if (isAMDGCN)
+      mathConvertionPM.addPass(mlir::createConvertMathToROCDL());
+
     // Convert math::FPowI operations to inline implementation
     // only if the exponent's width is greater than 32, otherwise,
     // it will be lowered to LLVM intrinsic operation by a later conversion.
@@ -3649,7 +3658,8 @@ class FIRToLLVMLowering
                                                           pattern);
     // Math operations that have not been converted yet must be converted
     // to Libm.
-    mlir::populateMathToLibmConversionPatterns(pattern);
+    if (!isAMDGCN)
+      mlir::populateMathToLibmConversionPatterns(pattern);
     mlir::populateComplexToLLVMConversionPatterns(typeConverter, pattern);
     mlir::populateVectorToLLVMConversionPatterns(typeConverter, pattern);
 
diff --git a/flang/test/Lower/OpenMP/math-amdgpu.f90 b/flang/test/Lower/OpenMP/math-amdgpu.f90
new file mode 100644
index 0000000000000..b455b42d3ed34
--- /dev/null
+++ b/flang/test/Lower/OpenMP/math-amdgpu.f90
@@ -0,0 +1,184 @@
+!REQUIRES: amdgpu-registered-target
+!RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-llvm -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+
+subroutine omp_pow_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_pow_f32(float {{.*}}, float {{.*}})
+  y = x ** x
+end subroutine omp_pow_f32
+
+subroutine omp_pow_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_pow_f64(double {{.*}}, double {{.*}})
+  y = x ** x
+end subroutine omp_pow_f64
+
+subroutine omp_sin_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_sin_f32(float {{.*}})
+  y = sin(x)
+end subroutine omp_sin_f32
+
+subroutine omp_sin_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_sin_f64(double {{.*}})
+  y = sin(x)
+end subroutine omp_sin_f64
+
+subroutine omp_abs_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_fabs_f32(float {{.*}})
+  y = abs(x)
+end subroutine omp_abs_f32
+
+subroutine omp_abs_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_fabs_f64(double {{.*}})
+  y = abs(x)
+end subroutine omp_abs_f64
+
+subroutine omp_atan_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_atan_f32(float {{.*}})
+  y = atan(x)
+end subroutine omp_atan_f32
+
+subroutine omp_atan_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_atan_f64(double {{.*}})
+  y = atan(x)
+end subroutine omp_atan_f64
+
+subroutine omp_atan2_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_atan2_f32(float {{.*}}, float {{.*}})
+  y = atan2(x, x)
+end subroutine omp_atan2_f32
+
+subroutine omp_atan2_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_atan2_f64(double {{.*}}, double {{.*}})
+  y = atan2(x ,x)
+end subroutine omp_atan2_f64
+
+subroutine omp_cos_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_cos_f32(float {{.*}})
+  y = cos(x)
+end subroutine omp_cos_f32
+
+subroutine omp_cos_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_cos_f64(double {{.*}})
+  y = cos(x)
+end subroutine omp_cos_f64
+
+subroutine omp_erf_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_erf_f32(float {{.*}})
+  y = erf(x)
+end subroutine omp_erf_f32
+
+subroutine omp_erf_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_erf_f64(double {{.*}})
+  y = erf(x)
+end subroutine omp_erf_f64
+
+subroutine omp_exp_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_exp_f32(float {{.*}})
+  y = exp(x)
+end subroutine omp_exp_f32
+
+subroutine omp_exp_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_exp_f64(double {{.*}})
+  y = exp(x)
+end subroutine omp_exp_f64
+
+subroutine omp_log_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_log_f32(float {{.*}})
+  y = log(x)
+end subroutine omp_log_f32
+
+subroutine omp_log_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_log_f64(double {{.*}})
+  y = log(x)
+end subroutine omp_log_f64
+
+subroutine omp_log10_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_log10_f32(float {{.*}})
+  y = log10(x)
+end subroutine omp_log10_f32
+
+subroutine omp_log10_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_log10_f64(double {{.*}})
+  y = log10(x)
+end subroutine omp_log10_f64
+
+subroutine omp_sqrt_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_sqrt_f32(float {{.*}})
+  y = sqrt(x)
+end subroutine omp_sqrt_f32
+
+subroutine omp_sqrt_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_sqrt_f64(double {{.*}})
+  y = sqrt(x)
+end subroutine omp_sqrt_f64
+
+subroutine omp_tan_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_tan_f32(float {{.*}})
+  y = tan(x)
+end subroutine omp_tan_f32
+
+subroutine omp_tan_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_tan_f64(double {{.*}})
+  y = tan(x)
+end subroutine omp_tan_f64
+
+subroutine omp_tanh_f32(x, y)
+!$omp declare target
+  real :: x, y
+!CHECK: call float @__ocml_tanh_f32(float {{.*}})
+  y = tanh(x)
+end subroutine omp_tanh_f32
+
+subroutine omp_tanh_f64(x, y)
+!$omp declare target
+  real(8) :: x, y
+!CHECK: call double @__ocml_tanh_f64(double {{.*}})
+  y = tanh(x)
+end subroutine omp_tanh_f64
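
The selection this patch makes can be sketched abstractly. The following is a hypothetical illustration in Python, not flang or MLIR code; the tables and function name are invented for this sketch. On an AMDGCN target, math ops are rewritten to `__ocml_*` device library calls by the ROCDL conversion, while on other targets unconverted ops still fall back to libm names.

```python
# Hypothetical sketch of the callee selection described in this patch.
# The dictionaries and function are illustrative only; the real logic
# lives in MLIR's MathToROCDL and MathToLibm conversions.
OCML_NAMES = {
    ("sin", 32): "__ocml_sin_f32", ("sin", 64): "__ocml_sin_f64",
    ("pow", 32): "__ocml_pow_f32", ("pow", 64): "__ocml_pow_f64",
}
LIBM_NAMES = {
    ("sin", 32): "sinf", ("sin", 64): "sin",
    ("pow", 32): "powf", ("pow", 64): "pow",
}

def lower_math_op(op: str, bits: int, triple: str) -> str:
    """Pick a callee name for a math op, mirroring the patch's branch."""
    if triple.startswith("amdgcn"):
        # The ROCDL conversion handles all math ops here, so the
        # libm lowering is skipped entirely for this target.
        return OCML_NAMES[(op, bits)]
    # Non-AMDGCN targets keep the existing libm fallback.
    return LIBM_NAMES[(op, bits)]

print(lower_math_op("sin", 32, "amdgcn-amd-amdhsa"))  # prints __ocml_sin_f32
```

This mirrors why the test file above checks for `__ocml_*` callees when compiling with the `amdgcn-amd-amdhsa` triple.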

@llvmbot (Member) commented Jul 18, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@arsenm (Contributor) left a comment

I don't think this is the way the system should work, but I guess it's reusing an existing pass

@jsjodin (Contributor, Author) commented Jul 19, 2024

I don't think this is the way the system should work, but I guess it's reusing an existing pass

I haven't found a better way to do this with what exists right now. The problem is that the math ops get converted to LLVM intrinsic calls in MLIR which may not be supported by the backend and there's no mechanism to map intrinsics to functions that implement them right now. There is a discussion thread on discourse, and the main proposal requires changes to the library where an attribute would be added to a function to indicate it implements a certain intrinsic.
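
The gap described here can be sketched in miniature. This is hypothetical Python with an invented support table, not the real AMDGPU backend data; actual intrinsic support is decided during instruction selection in the backend. The point: if a target cannot select an llvm.* math intrinsic and no intrinsic-to-implementing-function mapping exists, the frontend must pre-lower the op to a library call itself.

```python
# Hypothetical sketch of the problem described above. The support set
# is invented for illustration; it is not the real AMDGPU backend table.
SELECTABLE = {
    "amdgcn": {"llvm.fabs.f32", "llvm.fabs.f64", "llvm.sqrt.f32"},
}

def frontend_strategy(intrinsic: str, arch: str) -> str:
    """Decide (as a sketch) whether an op can stay an intrinsic."""
    if intrinsic in SELECTABLE.get(arch, set()):
        return "emit intrinsic"  # backend can select it directly
    # No generic mechanism maps an intrinsic to a function implementing
    # it, so the frontend pre-lowers to a device library call instead.
    return "emit library call"
```

For example, `frontend_strategy("llvm.sin.f64", "amdgcn")` returns `"emit library call"` under this invented table, which is the situation that motivates running the ROCDL conversion in flang.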

@arsenm (Contributor) commented Jul 19, 2024

I don't think this is the way the system should work, but I guess it's reusing an existing pass

I haven't found a better way to do this with what exists right now. The problem is that the math ops get converted to LLVM intrinsic calls in MLIR which may not be supported by the backend and there's no mechanism to map intrinsics to functions that implement them right now. There is a discussion thread on discourse, and the main proposal requires changes to the library where an attribute would be added to a function to indicate it implements a certain intrinsic.

Right, we don't have a real platform definition for how this is supposed to work. Every frontend stitches all of these pieces together slightly differently, and the result is a big mess

subroutine omp_abs_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_fabs_f32(float {{.*}})
Contributor

This one should definitely just go to llvm.fabs

Contributor

... then why is there an OCML function for it? Backwards compatibility?

subroutine omp_abs_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_fabs_f64(double {{.*}})
Contributor

Same. OCML shouldn't even provide these

subroutine omp_exp_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_exp_f32(float {{.*}})
Contributor

This should just call the llvm intrinsic

Contributor

... Yeah, that one does seem to just wrap the intrinsic going off of an llvm-dis

subroutine omp_log_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_log_f32(float {{.*}})
Contributor

Should just use the llvm intrinsic

subroutine omp_sqrt_f64(x, y)
!$omp declare target
real(8) :: x, y
!CHECK: call double @__ocml_sqrt_f64(double {{.*}})
Contributor

Should just use the llvm intrinsic

subroutine omp_sqrt_f32(x, y)
!$omp declare target
real :: x, y
!CHECK: call float @__ocml_sqrt_f32(float {{.*}})
Contributor

Should just use the llvm intrinsic

@jsjodin jsjodin requested a review from krzysz00 July 25, 2024 17:22
@krzysz00 (Contributor) left a comment

@arsenm Overall, is there a problem with using the OCML functions uniformly, even if they're trivial? That leaves the ability for them to become non-trivial in the future

@arsenm (Contributor) commented Jul 25, 2024

@arsenm Overall, is there a problem with using the OCML functions uniformly, even if they're trivial?

Yes, especially when they're trivial. The cases that just map to an llvm intrinsic add system complexity for no benefit.

That leaves the ability for them to become non-trivial in the future

Fundamentally we should be treating it more as a runtime library for the compiler and not a user facing abstraction. Right now it's trying to act like a user library or libm and compiler-rt all at the same time when they have different/incompatible constraints.

@jsjodin (Contributor, Author) commented Jul 26, 2024

@arsenm Overall, is there a problem with using the OCML functions uniformly, even if they're trivial?

Yes, especially when they're trivial. The cases that just map to an llvm intrinsic add system complexity for no benefit.

Having a more uniform interface reduces complexity. That interface should be the llvm intrinsics (or maybe libm calls), but there are holes, and forcing the higher-level transforms to know which intrinsics are and are not implemented in the backend increases complexity, in my opinion.

That leaves the ability for them to become non-trivial in the future

Fundamentally we should be treating it more as a runtime library for the compiler and not a user facing abstraction. Right now it's trying to act like a user library or libm and compiler-rt all at the same time when they have different/incompatible constraints.

If I understand this correctly, the main issue is that clang and flang use OCML at all, when it should be purely internal to LLVM.
OCML is only used within the LLVM projects from what I can tell. It doesn't seem to be user-facing, since it has not been mentioned in the ROCm documentation for a long time, so we can perhaps assume we control where and how it is used. Once a real solution is implemented, we should be able to just delete this pass and let the standard lowering passes do the work. In the meantime, I think it is reasonable to use OCML as the uniform interface.

What are your opinions on adding the "implements" attribute as a long-term solution? This would add an attribute telling the compiler that a function implements an intrinsic, e.g. something like implements(llvm.pow). I was looking into this, but I don't have time to work on it right now. The discussion thread is here:
https://discourse.llvm.org/t/nvptx-codegen-for-llvm-sin-and-friends/58170/33

@arsenm (Contributor) commented Jul 26, 2024

If I understand this correctly the main issue is that clang and flang use OCML at all, and it should be purely internal to LLVM.

Mainly yes. OCML implicitly acts like an implementation detail of language library functions, rather than part of the system that we can reliably do libcall recognition and emission with. As such every language has to wrap this in some other layer (which they all do differently), and we can't reliably expect the functions to exist later at any point. OpenCL hacks around this with separate "prelink" and "postlink" options, and none of the other languages get any libcall optimizations. Unless we move this to some system-like position, we can't do anything language independent with it.

OCML is only used within the LLVM projects from what I can tell. It doesn't seem to be user-facing, since it is not mentioned in the ROCm documentation anymore (not for a long time), so we can perhaps assume we control where and how it is used. Once a real solution is implemented we should be able to just delete this pass and let the standard lowering passes do the work. In the mean time I think it is reasonable to use OCML as the uniform interface.

LLVM is the interface. OCML is an implementation detail.

What are your opinions on adding the "implements" attribute as long term solution? This would add an attribute to tell the compiler that a function implements an intrinsic something like attribute implements(llvm.pow). I was looking into this, but I don't have time to work on it right now. The discussion thread is here: https://discourse.llvm.org/t/nvptx-codegen-for-llvm-sin-and-friends/58170/33

It's not really how I wanted this to go, but the current trend is to have all of the math library functions available as intrinsics. Given that, I don't see much point in adding such an attribute. I was leaning towards removing any intrinsics that no backend can reasonably implement without a runtime call, and that would be complemented by improved libcall-by-name handling. It would make more sense to add implements if we went in the other direction. With this trend, I think it's jumping through hoops to maintain implementation detail names we control and can replace.

We lack a proper platform definition. I think it would be best if we defined amdhsa like a normal operating system, with a well defined set of provided libm functions, using the standard names. The __ocml prefixes are a relic from how OpenCL was implemented long ago. It would make more sense to put any extension functions in an __amd or __amdhsa prefix. Another issue is amdpal and the other triples we use. Different projects have thrown assorted builtin libraries together in different ways, so it would be better to define this bottom up.

@jsjodin (Contributor, Author) commented Aug 12, 2024

@arsenm, if I change the conversion pass to map to the supported llvm intrinsics, which you pointed out, would that be acceptable for now?

@arsenm (Contributor) commented Aug 16, 2024

@arsenm, if I change the conversion pass to map to the supported llvm intrinsics, which you pointed out, would that be acceptable for now?

Yes. Hopefully in the future we will have a sensible intrinsic lowering solution (and I'd like to minimize the number of places that need to be touched to match them), so a common MLIR pass would be preferable

…libm calls

This patch invokes a pass when compiling for an AMDGPU target to lower math
operations to AMD GPU library calls instead of libm calls.
@jsjodin jsjodin force-pushed the jleyonberg/flangmathfuncs branch from 665a76f to 1d7ec8e Compare September 5, 2024 16:06
@jsjodin jsjodin requested review from arsenm and krzysz00 September 5, 2024 16:11
@krzysz00 (Contributor) left a comment

Seems fine to me and I'm trusting you're happy with the test coverage

@jsjodin jsjodin merged commit 4290e34 into llvm:main Sep 10, 2024
8 checks passed