Skip to content

Commit df1bee0

Browse files
authored
[mlir] Add math to LLVM lowering support for missing trigonometric & hyperbolic ops (#125753)
The patch adds support for math -> LLVM dialect lowering for TanOp, Sinh, Cosh, Tanh
1 parent 88f55d1 commit df1bee0

File tree

3 files changed

+43
-13
lines changed

3 files changed

+43
-13
lines changed

flang/test/Intrinsics/math-codegen.fir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,10 +1803,10 @@ func.func private @sinh(f64) -> f64
18031803
//--- tanh_fast.fir
18041804
// RUN: fir-opt %t/tanh_fast.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/tanh_fast.fir
18051805
// CHECK: @_QPtest_real4
1806-
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @tanhf({{%[A-Za-z0-9._]+}}) : (f32) -> f32
1806+
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.tanh({{%[A-Za-z0-9._]+}}) : (f32) -> f32
18071807

18081808
// CHECK: @_QPtest_real8
1809-
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @tanh({{%[A-Za-z0-9._]+}}) : (f64) -> f64
1809+
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.tanh({{%[A-Za-z0-9._]+}}) : (f64) -> f64
18101810

18111811
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
18121812
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
@@ -1828,10 +1828,10 @@ func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
18281828
//--- tanh_relaxed.fir
18291829
// RUN: fir-opt %t/tanh_relaxed.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/tanh_relaxed.fir
18301830
// CHECK: @_QPtest_real4
1831-
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @tanhf({{%[A-Za-z0-9._]+}}) : (f32) -> f32
1831+
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.tanh({{%[A-Za-z0-9._]+}}) : (f32) -> f32
18321832

18331833
// CHECK: @_QPtest_real8
1834-
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @tanh({{%[A-Za-z0-9._]+}}) : (f64) -> f64
1834+
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.tanh({{%[A-Za-z0-9._]+}}) : (f64) -> f64
18351835

18361836
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
18371837
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
@@ -1880,10 +1880,10 @@ func.func private @tanh(f64) -> f64
18801880
//--- tan_fast.fir
18811881
// RUN: fir-opt %t/tan_fast.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/tan_fast.fir
18821882
// CHECK: @_QPtest_real4
1883-
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @tanf({{%[A-Za-z0-9._]+}}) : (f32) -> f32
1883+
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.tan({{%[A-Za-z0-9._]+}}) : (f32) -> f32
18841884

18851885
// CHECK: @_QPtest_real8
1886-
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @tan({{%[A-Za-z0-9._]+}}) : (f64) -> f64
1886+
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.tan({{%[A-Za-z0-9._]+}}) : (f64) -> f64
18871887

18881888
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
18891889
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}
@@ -1905,10 +1905,10 @@ func.func @_QPtest_real8(%arg0: !fir.ref<f64> {fir.bindc_name = "x"}) -> f64 {
19051905
//--- tan_relaxed.fir
19061906
// RUN: fir-opt %t/tan_relaxed.fir --fir-to-llvm-ir="target=x86_64-unknown-linux-gnu" | FileCheck %t/tan_relaxed.fir
19071907
// CHECK: @_QPtest_real4
1908-
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @tanf({{%[A-Za-z0-9._]+}}) : (f32) -> f32
1908+
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.tan({{%[A-Za-z0-9._]+}}) : (f32) -> f32
19091909

19101910
// CHECK: @_QPtest_real8
1911-
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.call @tan({{%[A-Za-z0-9._]+}}) : (f64) -> f64
1911+
// CHECK: {{%[A-Za-z0-9._]+}} = llvm.intr.tan({{%[A-Za-z0-9._]+}}) : (f64) -> f64
19121912

19131913
func.func @_QPtest_real4(%arg0: !fir.ref<f32> {fir.bindc_name = "x"}) -> f32 {
19141914
%0 = fir.alloca f32 {bindc_name = "test_real4", uniq_name = "_QFtest_real4Etest_real4"}

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ using CeilOpLowering = ConvertFMFMathToLLVMPattern<math::CeilOp, LLVM::FCeilOp>;
3939
using CopySignOpLowering =
4040
ConvertFMFMathToLLVMPattern<math::CopySignOp, LLVM::CopySignOp>;
4141
using CosOpLowering = ConvertFMFMathToLLVMPattern<math::CosOp, LLVM::CosOp>;
42+
using CoshOpLowering = ConvertFMFMathToLLVMPattern<math::CoshOp, LLVM::CoshOp>;
4243
using CtPopFOpLowering =
4344
VectorConvertToLLVMPattern<math::CtPopOp, LLVM::CtPopOp>;
4445
using Exp2OpLowering = ConvertFMFMathToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
@@ -58,9 +59,12 @@ using RoundEvenOpLowering =
5859
using RoundOpLowering =
5960
ConvertFMFMathToLLVMPattern<math::RoundOp, LLVM::RoundOp>;
6061
using SinOpLowering = ConvertFMFMathToLLVMPattern<math::SinOp, LLVM::SinOp>;
62+
using SinhOpLowering = ConvertFMFMathToLLVMPattern<math::SinhOp, LLVM::SinhOp>;
6163
using SqrtOpLowering = ConvertFMFMathToLLVMPattern<math::SqrtOp, LLVM::SqrtOp>;
6264
using FTruncOpLowering =
6365
ConvertFMFMathToLLVMPattern<math::TruncOp, LLVM::FTruncOp>;
66+
using TanOpLowering = ConvertFMFMathToLLVMPattern<math::TanOp, LLVM::TanOp>;
67+
using TanhOpLowering = ConvertFMFMathToLLVMPattern<math::TanhOp, LLVM::TanhOp>;
6468

6569
// A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`.
6670
template <typename MathOp, typename LLVMOp>
@@ -310,6 +314,7 @@ void mlir::populateMathToLLVMConversionPatterns(
310314
CeilOpLowering,
311315
CopySignOpLowering,
312316
CosOpLowering,
317+
CoshOpLowering,
313318
CountLeadingZerosOpLowering,
314319
CountTrailingZerosOpLowering,
315320
CtPopFOpLowering,
@@ -327,8 +332,11 @@ void mlir::populateMathToLLVMConversionPatterns(
327332
RoundOpLowering,
328333
RsqrtOpLowering,
329334
SinOpLowering,
335+
SinhOpLowering,
330336
SqrtOpLowering,
331-
FTruncOpLowering
337+
FTruncOpLowering,
338+
TanOpLowering,
339+
TanhOpLowering
332340
>(converter);
333341
// clang-format on
334342
}

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,33 @@ func.func @rsqrt(%arg0 : f32) {
161161

162162
// -----
163163

164-
// CHECK-LABEL: func @sine(
165-
// CHECK-SAME: f32
166-
func.func @sine(%arg0 : f32) {
167-
// CHECK: llvm.intr.sin(%arg0) : (f32) -> f32
164+
// CHECK-LABEL: func @trigonometrics
165+
// CHECK-SAME: [[ARG0:%.+]]: f32
166+
func.func @trigonometrics(%arg0: f32) {
167+
// CHECK: llvm.intr.sin([[ARG0]]) : (f32) -> f32
168168
%0 = math.sin %arg0 : f32
169+
170+
// CHECK: llvm.intr.cos([[ARG0]]) : (f32) -> f32
171+
%1 = math.cos %arg0 : f32
172+
173+
// CHECK: llvm.intr.tan([[ARG0]]) : (f32) -> f32
174+
%2 = math.tan %arg0 : f32
175+
func.return
176+
}
177+
178+
// -----
179+
180+
// CHECK-LABEL: func @hyperbolics
181+
// CHECK-SAME: [[ARG0:%.+]]: f32
182+
func.func @hyperbolics(%arg0: f32) {
183+
// CHECK: llvm.intr.sinh([[ARG0]]) : (f32) -> f32
184+
%0 = math.sinh %arg0 : f32
185+
186+
// CHECK: llvm.intr.cosh([[ARG0]]) : (f32) -> f32
187+
%1 = math.cosh %arg0 : f32
188+
189+
// CHECK: llvm.intr.tanh([[ARG0]]) : (f32) -> f32
190+
%2 = math.tanh %arg0 : f32
169191
func.return
170192
}
171193

0 commit comments

Comments
 (0)