Skip to content

Commit 6c81b4a

Browse files
committed
[flang] Fold transformational bessels when host runtime has bessels
Transformational bessel intrinsic functions require the same math runtime as elemental bessel intrinsics. Currently elemental bessels could be folded if f18 was linked with pgmath (cmake -DLIBPGMATH_DIR option). `j0`, `y0`, ... C libm functions were not used because they are not standard C functions: they are Posix extensions. This patch enable: - Using the Posix bessel host runtime functions when available. - folding the transformational bessel using the elemental version. Differential Revision: https://reviews.llvm.org/D124167
1 parent 9687ca9 commit 6c81b4a

File tree

3 files changed

+124
-2
lines changed

3 files changed

+124
-2
lines changed

flang/lib/Evaluate/fold-real.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,38 @@
1111

1212
namespace Fortran::evaluate {
1313

14+
template <typename T>
15+
static Expr<T> FoldTransformationalBessel(
16+
FunctionRef<T> &&funcRef, FoldingContext &context) {
17+
CHECK(funcRef.arguments().size() == 3);
18+
/// Bessel runtime functions use `int` integer arguments. Convert integer
19+
/// arguments to Int4, any overflow error will be reported during the
20+
/// conversion folding.
21+
using Int4 = Type<TypeCategory::Integer, 4>;
22+
if (auto args{
23+
GetConstantArguments<Int4, Int4, T>(context, funcRef.arguments())}) {
24+
const std::string &name{std::get<SpecificIntrinsic>(funcRef.proc().u).name};
25+
if (auto elementalBessel{GetHostRuntimeWrapper<T, Int4, T>(name)}) {
26+
std::vector<Scalar<T>> results;
27+
int n1{static_cast<int>(
28+
std::get<0>(*args)->GetScalarValue().value().ToInt64())};
29+
int n2{static_cast<int>(
30+
std::get<1>(*args)->GetScalarValue().value().ToInt64())};
31+
Scalar<T> x{std::get<2>(*args)->GetScalarValue().value()};
32+
for (int i{n1}; i <= n2; ++i) {
33+
results.emplace_back((*elementalBessel)(context, Scalar<Int4>{i}, x));
34+
}
35+
return Expr<T>{Constant<T>{
36+
std::move(results), ConstantSubscripts{std::max(n2 - n1 + 1, 0)}}};
37+
} else {
38+
context.messages().Say(
39+
"%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US,
40+
name, T::kind);
41+
}
42+
}
43+
return Expr<T>{std::move(funcRef)};
44+
}
45+
1446
template <int KIND>
1547
Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
1648
FoldingContext &context,
@@ -63,6 +95,8 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
6395
"%s(integer(kind=4), real(kind=%d)) cannot be folded on host"_warn_en_US,
6496
name, KIND);
6597
}
98+
} else {
99+
return FoldTransformationalBessel<T>(std::move(funcRef), context);
66100
}
67101
} else if (name == "abs") { // incl. zabs & cdabs
68102
// Argument can be complex or real
@@ -245,7 +279,6 @@ Expr<Type<TypeCategory::Real, KIND>> FoldIntrinsicFunction(
245279
// TODO: dim, dot_product, fraction, matmul,
246280
// modulo, norm2, rrspacing,
247281
// set_exponent, spacing, transfer,
248-
// bessel_jn (transformational) and bessel_yn (transformational)
249282
return Expr<T>{std::move(funcRef)};
250283
}
251284

flang/lib/Evaluate/intrinsics-library.cpp

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,13 @@ template <typename HostFuncType, HostFuncType func> class FolderFactory {
192192

193193
// Define host runtime libraries that can be used for folding and
194194
// fill their description if they are available.
195-
enum class LibraryVersion { Libm, PgmathFast, PgmathRelaxed, PgmathPrecise };
195+
enum class LibraryVersion {
196+
Libm,
197+
LibmExtensions,
198+
PgmathFast,
199+
PgmathRelaxed,
200+
PgmathPrecise
201+
};
196202
template <typename HostT, LibraryVersion> struct HostRuntimeLibrary {
197203
// When specialized, this class holds a static constexpr table containing
198204
// all the HostRuntimeLibrary for functions of library LibraryVersion
@@ -277,6 +283,64 @@ struct HostRuntimeLibrary<std::complex<HostT>, LibraryVersion::Libm> {
277283
static constexpr HostRuntimeMap map{table};
278284
static_assert(map.Verify(), "map must be sorted");
279285
};
286+
// Note regarding cmath:
287+
// - cmath does not have modulo and erfc_scaled equivalent
288+
// - C++17 defined standard Bessel math functions std::cyl_bessel_j
289+
// and std::cyl_neumann that can be used for Fortran j and y
290+
// bessel functions. However, they are not yet implemented in
291+
// clang libc++ (ok in GNU libstdc++). Instead, the Posix libm
292+
// extensions are used when available below.
293+
294+
#if _POSIX_C_SOURCE >= 200112L || _XOPEN_SOURCE >= 600
295+
/// Define libm extensions
296+
/// Bessel functions are defined in POSIX.1-2001.
297+
298+
template <> struct HostRuntimeLibrary<float, LibraryVersion::LibmExtensions> {
299+
using F = FuncPointer<float, float>;
300+
using FN = FuncPointer<float, int, float>;
301+
static constexpr HostRuntimeFunction table[]{
302+
FolderFactory<F, F{::j0f}>::Create("bessel_j0"),
303+
FolderFactory<F, F{::j1f}>::Create("bessel_j1"),
304+
FolderFactory<FN, FN{::jnf}>::Create("bessel_jn"),
305+
FolderFactory<F, F{::y0f}>::Create("bessel_y0"),
306+
FolderFactory<F, F{::y1f}>::Create("bessel_y1"),
307+
FolderFactory<FN, FN{::ynf}>::Create("bessel_yn"),
308+
};
309+
static constexpr HostRuntimeMap map{table};
310+
static_assert(map.Verify(), "map must be sorted");
311+
};
312+
313+
template <> struct HostRuntimeLibrary<double, LibraryVersion::LibmExtensions> {
314+
using F = FuncPointer<double, double>;
315+
using FN = FuncPointer<double, int, double>;
316+
static constexpr HostRuntimeFunction table[]{
317+
FolderFactory<F, F{::j0}>::Create("bessel_j0"),
318+
FolderFactory<F, F{::j1}>::Create("bessel_j1"),
319+
FolderFactory<FN, FN{::jn}>::Create("bessel_jn"),
320+
FolderFactory<F, F{::y0}>::Create("bessel_y0"),
321+
FolderFactory<F, F{::y1}>::Create("bessel_y1"),
322+
FolderFactory<FN, FN{::yn}>::Create("bessel_yn"),
323+
};
324+
static constexpr HostRuntimeMap map{table};
325+
static_assert(map.Verify(), "map must be sorted");
326+
};
327+
328+
template <>
329+
struct HostRuntimeLibrary<long double, LibraryVersion::LibmExtensions> {
330+
using F = FuncPointer<long double, long double>;
331+
using FN = FuncPointer<long double, int, long double>;
332+
static constexpr HostRuntimeFunction table[]{
333+
FolderFactory<F, F{::j0l}>::Create("bessel_j0"),
334+
FolderFactory<F, F{::j1l}>::Create("bessel_j1"),
335+
FolderFactory<FN, FN{::jnl}>::Create("bessel_jn"),
336+
FolderFactory<F, F{::y0l}>::Create("bessel_y0"),
337+
FolderFactory<F, F{::y1l}>::Create("bessel_y1"),
338+
FolderFactory<FN, FN{::ynl}>::Create("bessel_yn"),
339+
};
340+
static constexpr HostRuntimeMap map{table};
341+
static_assert(map.Verify(), "map must be sorted");
342+
};
343+
#endif
280344

281345
/// Define pgmath description
282346
#if LINK_WITH_LIBPGMATH
@@ -409,6 +473,8 @@ static const HostRuntimeMap *GetHostRuntimeMap(
409473
switch (version) {
410474
case LibraryVersion::Libm:
411475
return GetHostRuntimeMapVersion<LibraryVersion::Libm>(resultType);
476+
case LibraryVersion::LibmExtensions:
477+
return GetHostRuntimeMapVersion<LibraryVersion::LibmExtensions>(resultType);
412478
case LibraryVersion::PgmathPrecise:
413479
return GetHostRuntimeMapVersion<LibraryVersion::PgmathPrecise>(resultType);
414480
case LibraryVersion::PgmathRelaxed:
@@ -454,6 +520,13 @@ static const HostRuntimeFunction *SearchHostRuntime(const std::string &name,
454520
return hostFunction;
455521
}
456522
}
523+
if (const auto *map{
524+
GetHostRuntimeMap(LibraryVersion::LibmExtensions, resultType)}) {
525+
if (const auto *hostFunction{
526+
SearchInHostRuntimeMap(*map, name, resultType, argTypes)}) {
527+
return hostFunction;
528+
}
529+
}
457530
return nullptr;
458531
}
459532

flang/test/Evaluate/folding02.f90

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,22 @@ module m
249249
(-0.93219375976297402797143831776338629424571990966796875_8))
250250
TEST_R8(erfc_scaled, erfc_scaled(0.1_8), &
251251
0.89645697996912654392787089818739332258701324462890625_8)
252+
253+
real(4), parameter :: bessel_jn_transformational(*) = bessel_jn(1,3, 3.2_4)
254+
logical, parameter :: test_bessel_jn_shape = size(bessel_jn_transformational, 1).eq.3
255+
logical, parameter :: test_bessel_jn_t1 = bessel_jn_transformational(1).eq.bessel_jn(1, 3.2_4)
256+
logical, parameter :: test_bessel_jn_t2 = bessel_jn_transformational(2).eq.bessel_jn(2, 3.2_4)
257+
logical, parameter :: test_bessel_jn_t3 = bessel_jn_transformational(3).eq.bessel_jn(3, 3.2_4)
258+
real(4), parameter :: bessel_jn_empty(*) = bessel_jn(3,1, 3.2_4)
259+
logical, parameter :: test_bessel_jn_empty = size(bessel_jn_empty, 1).eq.0
260+
261+
real(4), parameter :: bessel_yn_transformational(*) = bessel_yn(1,3, 1.6_4)
262+
logical, parameter :: test_bessel_yn_shape = size(bessel_yn_transformational, 1).eq.3
263+
logical, parameter :: test_bessel_yn_t1 = bessel_yn_transformational(1).eq.bessel_yn(1, 1.6_4)
264+
logical, parameter :: test_bessel_yn_t2 = bessel_yn_transformational(2).eq.bessel_yn(2, 1.6_4)
265+
logical, parameter :: test_bessel_yn_t3 = bessel_yn_transformational(3).eq.bessel_yn(3, 1.6_4)
266+
real(4), parameter :: bessel_yn_empty(*) = bessel_yn(3,1, 3.2_4)
267+
logical, parameter :: test_bessel_yn_empty = size(bessel_yn_empty, 1).eq.0
252268
#endif
253269

254270
! Test exponentiation by real or complex folding (it is using host runtime)

0 commit comments

Comments
 (0)