Skip to content

[flang][runtime] Support NORM2 for REAL(16) with FortranFloat128Math lib. #83219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 28, 2024

Conversation

vzakhari
Copy link
Contributor

Changed the lowering to call Norm2DimReal16 for REAL(16).
Added the corresponding entry point to FortranFloat128Math,
which required some restructuring in the related templates.

…lib.

Changed the lowering to call Norm2DimReal16 for REAL(16).
Added the corresponding entry point to FortranFloat128Math,
which required some restructuring in the related templates.
@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category flang:fir-hlfir labels Feb 28, 2024
@llvmbot
Copy link
Member

llvmbot commented Feb 28, 2024

@llvm/pr-subscribers-flang-runtime

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Changes

Changed the lowering to call Norm2DimReal16 for REAL(16).
Added the corresponding entry point to FortranFloat128Math,
which required some restructuring in the related templates.


Full diff: https://github.com/llvm/llvm-project/pull/83219.diff

9 Files Affected:

  • (modified) flang/include/flang/Runtime/reduction.h (+4-1)
  • (modified) flang/lib/Optimizer/Builder/Runtime/Reduction.cpp (+24-1)
  • (modified) flang/runtime/Float128Math/CMakeLists.txt (+1)
  • (modified) flang/runtime/Float128Math/math-entries.h (+16)
  • (added) flang/runtime/Float128Math/norm2.cpp (+59)
  • (modified) flang/runtime/extrema.cpp (+12-95)
  • (modified) flang/runtime/reduction-templates.h (+115)
  • (modified) flang/runtime/tools.h (+9-2)
  • (modified) flang/test/Lower/Intrinsics/norm2.f90 (+16)
diff --git a/flang/include/flang/Runtime/reduction.h b/flang/include/flang/Runtime/reduction.h
index 6d62f4016937e0..5b607765857523 100644
--- a/flang/include/flang/Runtime/reduction.h
+++ b/flang/include/flang/Runtime/reduction.h
@@ -364,9 +364,12 @@ double RTDECL(Norm2_8)(
 #if LDBL_MANT_DIG == 64
 long double RTDECL(Norm2_10)(
     const Descriptor &, const char *source, int line, int dim = 0);
-#elif LDBL_MANT_DIG == 113
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
 long double RTDECL(Norm2_16)(
     const Descriptor &, const char *source, int line, int dim = 0);
+void RTDECL(Norm2DimReal16)(
+    Descriptor &, const Descriptor &, int dim, const char *source, int line);
 #endif
 void RTDECL(Norm2Dim)(
     Descriptor &, const Descriptor &, int dim, const char *source, int line);
diff --git a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
index fabbff818b6f0e..66fbaddcbda1aa 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Reduction.cpp
@@ -149,6 +149,22 @@ struct ForcedNorm2Real16 {
   }
 };
 
+/// Placeholder for real*16 version of Norm2Dim Intrinsic
+struct ForcedNorm2DimReal16 {
+  static constexpr const char *name = ExpandAndQuoteKey(RTNAME(Norm2DimReal16));
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
+      auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
+      return mlir::FunctionType::get(
+          ctx, {fir::ReferenceType::get(boxTy), boxTy, intTy, strTy, intTy},
+          mlir::NoneType::get(ctx));
+    };
+  }
+};
+
 /// Placeholder for real*10 version of Product Intrinsic
 struct ForcedProductReal10 {
   static constexpr const char *name = ExpandAndQuoteKey(RTNAME(ProductReal10));
@@ -876,7 +892,14 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
 void fir::runtime::genNorm2Dim(fir::FirOpBuilder &builder, mlir::Location loc,
                                mlir::Value resultBox, mlir::Value arrayBox,
                                mlir::Value dim) {
-  auto func = fir::runtime::getRuntimeFunc<mkRTKey(Norm2Dim)>(loc, builder);
+  mlir::func::FuncOp func;
+  auto ty = arrayBox.getType();
+  auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
+  auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
+  if (eleTy.isF128())
+    func = fir::runtime::getRuntimeFunc<ForcedNorm2DimReal16>(loc, builder);
+  else
+    func = fir::runtime::getRuntimeFunc<mkRTKey(Norm2Dim)>(loc, builder);
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);
   auto sourceLine =
diff --git a/flang/runtime/Float128Math/CMakeLists.txt b/flang/runtime/Float128Math/CMakeLists.txt
index 8d276e8f122728..f11678cd70b769 100644
--- a/flang/runtime/Float128Math/CMakeLists.txt
+++ b/flang/runtime/Float128Math/CMakeLists.txt
@@ -69,6 +69,7 @@ set(sources
   log.cpp
   log10.cpp
   lround.cpp
+  norm2.cpp
   pow.cpp
   round.cpp
   sin.cpp
diff --git a/flang/runtime/Float128Math/math-entries.h b/flang/runtime/Float128Math/math-entries.h
index 83298674c4971f..a0d81d0cbb5407 100644
--- a/flang/runtime/Float128Math/math-entries.h
+++ b/flang/runtime/Float128Math/math-entries.h
@@ -54,6 +54,7 @@ namespace Fortran::runtime {
   };
 
 // Define fallback callers.
+DEFINE_FALLBACK(Abs)
 DEFINE_FALLBACK(Acos)
 DEFINE_FALLBACK(Acosh)
 DEFINE_FALLBACK(Asin)
@@ -99,6 +100,7 @@ DEFINE_FALLBACK(Yn)
 // Use STD math functions. They provide IEEE-754 128-bit float
 // support either via 'long double' or __float128.
 // The Bessel's functions are not present in STD namespace.
+DEFINE_SIMPLE_ALIAS(Abs, std::abs)
 DEFINE_SIMPLE_ALIAS(Acos, std::acos)
 DEFINE_SIMPLE_ALIAS(Acosh, std::acosh)
 DEFINE_SIMPLE_ALIAS(Asin, std::asin)
@@ -155,6 +157,7 @@ DEFINE_SIMPLE_ALIAS(Yn, ynl)
 #elif HAS_QUADMATHLIB
 // Define wrapper callers for libquadmath.
 #include "quadmath.h"
+DEFINE_SIMPLE_ALIAS(Abs, fabsq)
 DEFINE_SIMPLE_ALIAS(Acos, acosq)
 DEFINE_SIMPLE_ALIAS(Acosh, acoshq)
 DEFINE_SIMPLE_ALIAS(Asin, asinq)
@@ -191,6 +194,19 @@ DEFINE_SIMPLE_ALIAS(Y0, y0q)
 DEFINE_SIMPLE_ALIAS(Y1, y1q)
 DEFINE_SIMPLE_ALIAS(Yn, ynq)
 #endif
+
+extern "C" {
+// Declarations of the entry points that might be referenced
+// within the Float128Math library itself.
+// Note that not all of these entry points are actually
+// defined in this library. Some of them are used just
+// as template parameters to call the corresponding callee directly.
+CppTypeFor<TypeCategory::Real, 16> RTDECL(AbsF128)(
+    CppTypeFor<TypeCategory::Real, 16> x);
+CppTypeFor<TypeCategory::Real, 16> RTDECL(SqrtF128)(
+    CppTypeFor<TypeCategory::Real, 16> x);
+} // extern "C"
+
 } // namespace Fortran::runtime
 
 #endif // FORTRAN_RUNTIME_FLOAT128MATH_MATH_ENTRIES_H_
diff --git a/flang/runtime/Float128Math/norm2.cpp b/flang/runtime/Float128Math/norm2.cpp
new file mode 100644
index 00000000000000..17453bd2d6cbd7
--- /dev/null
+++ b/flang/runtime/Float128Math/norm2.cpp
@@ -0,0 +1,59 @@
+//===-- runtime/Float128Math/norm2.cpp ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "math-entries.h"
+#include "reduction-templates.h"
+#include <cmath>
+
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+
+namespace {
+using namespace Fortran::runtime;
+
+using AccumType = Norm2AccumType<16>;
+
+struct ABSTy {
+  static AccumType compute(AccumType x) {
+    return Sqrt<RTNAME(AbsF128)>::invoke(x);
+  }
+};
+
+struct SQRTTy {
+  static AccumType compute(AccumType x) {
+    return Sqrt<RTNAME(SqrtF128)>::invoke(x);
+  }
+};
+
+using Float128Norm2Accumulator = Norm2Accumulator<16, ABSTy, SQRTTy>;
+} // namespace
+
+namespace Fortran::runtime {
+extern "C" {
+
+CppTypeFor<TypeCategory::Real, 16> RTDEF(Norm2_16)(
+    const Descriptor &x, const char *source, int line, int dim) {
+  auto accumulator{::Float128Norm2Accumulator(x)};
+  return GetTotalReduction<TypeCategory::Real, 16>(
+      x, source, line, dim, nullptr, accumulator, "NORM2");
+}
+
+void RTDEF(Norm2DimReal16)(Descriptor &result, const Descriptor &x, int dim,
+    const char *source, int line) {
+  Terminator terminator{source, line};
+  auto type{x.type().GetCategoryAndKind()};
+  RUNTIME_CHECK(terminator, type);
+  RUNTIME_CHECK(
+      terminator, type->first == TypeCategory::Real && type->second == 16);
+  DoMaxMinNorm2<TypeCategory::Real, 16, ::Float128Norm2Accumulator>(
+      result, x, dim, nullptr, "NORM2", terminator);
+}
+
+} // extern "C"
+} // namespace Fortran::runtime
+
+#endif
diff --git a/flang/runtime/extrema.cpp b/flang/runtime/extrema.cpp
index 3fdc8e159866d1..fc2b4e165cb269 100644
--- a/flang/runtime/extrema.cpp
+++ b/flang/runtime/extrema.cpp
@@ -528,35 +528,6 @@ inline RT_API_ATTRS CppTypeFor<CAT, KIND> TotalNumericMaxOrMin(
       NumericExtremumAccumulator<CAT, KIND, IS_MAXVAL>{x}, intrinsic);
 }
 
-template <TypeCategory CAT, int KIND, typename ACCUMULATOR>
-static RT_API_ATTRS void DoMaxMinNorm2(Descriptor &result, const Descriptor &x,
-    int dim, const Descriptor *mask, const char *intrinsic,
-    Terminator &terminator) {
-  using Type = CppTypeFor<CAT, KIND>;
-  ACCUMULATOR accumulator{x};
-  if (dim == 0 || x.rank() == 1) {
-    // Total reduction
-
-    // Element size of the destination descriptor is the same
-    // as the element size of the source.
-    result.Establish(x.type(), x.ElementBytes(), nullptr, 0, nullptr,
-        CFI_attribute_allocatable);
-    if (int stat{result.Allocate()}) {
-      terminator.Crash(
-          "%s: could not allocate memory for result; STAT=%d", intrinsic, stat);
-    }
-    DoTotalReduction<Type>(x, dim, mask, accumulator, intrinsic, terminator);
-    accumulator.GetResult(result.OffsetElement<Type>());
-  } else {
-    // Partial reduction
-
-    // Element size of the destination descriptor is the same
-    // as the element size of the source.
-    PartialReduction<ACCUMULATOR, CAT, KIND>(result, x, x.ElementBytes(), dim,
-        mask, terminator, intrinsic, accumulator);
-  }
-}
-
 template <TypeCategory CAT, bool IS_MAXVAL> struct MaxOrMinHelper {
   template <int KIND> struct Functor {
     RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
@@ -802,66 +773,11 @@ RT_EXT_API_GROUP_END
 
 // NORM2
 
-RT_VAR_GROUP_BEGIN
-
-// Use at least double precision for accumulators.
-// Don't use __float128, it doesn't work with abs() or sqrt() yet.
-static constexpr RT_CONST_VAR_ATTRS int largestLDKind {
-#if LDBL_MANT_DIG == 113
-  16
-#elif LDBL_MANT_DIG == 64
-  10
-#else
-  8
-#endif
-};
-
-RT_VAR_GROUP_END
-
-template <int KIND> class Norm2Accumulator {
-public:
-  using Type = CppTypeFor<TypeCategory::Real, KIND>;
-  using AccumType =
-      CppTypeFor<TypeCategory::Real, std::clamp(KIND, 8, largestLDKind)>;
-  explicit RT_API_ATTRS Norm2Accumulator(const Descriptor &array)
-      : array_{array} {}
-  RT_API_ATTRS void Reinitialize() { max_ = sum_ = 0; }
-  template <typename A>
-  RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
-    // m * sqrt(1 + sum((others(:)/m)**2))
-    *p = static_cast<Type>(max_ * std::sqrt(1 + sum_));
-  }
-  RT_API_ATTRS bool Accumulate(Type x) {
-    auto absX{std::abs(static_cast<AccumType>(x))};
-    if (!max_) {
-      max_ = absX;
-    } else if (absX > max_) {
-      auto t{max_ / absX}; // < 1.0
-      auto tsq{t * t};
-      sum_ *= tsq; // scale sum to reflect change to the max
-      sum_ += tsq; // include a term for the previous max
-      max_ = absX;
-    } else { // absX <= max_
-      auto t{absX / max_};
-      sum_ += t * t;
-    }
-    return true;
-  }
-  template <typename A>
-  RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
-    return Accumulate(*array_.Element<A>(at));
-  }
-
-private:
-  const Descriptor &array_;
-  AccumType max_{0}; // value (m) with largest magnitude
-  AccumType sum_{0}; // sum((others(:)/m)**2)
-};
-
 template <int KIND> struct Norm2Helper {
   RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x, int dim,
       const Descriptor *mask, Terminator &terminator) const {
-    DoMaxMinNorm2<TypeCategory::Real, KIND, Norm2Accumulator<KIND>>(
+    DoMaxMinNorm2<TypeCategory::Real, KIND,
+        typename Norm2AccumulatorGetter<KIND>::Type>(
         result, x, dim, mask, "NORM2", terminator);
   }
 };
@@ -872,26 +788,27 @@ RT_EXT_API_GROUP_BEGIN
 // TODO: REAL(2 & 3)
 CppTypeFor<TypeCategory::Real, 4> RTDEF(Norm2_4)(
     const Descriptor &x, const char *source, int line, int dim) {
-  return GetTotalReduction<TypeCategory::Real, 4>(
-      x, source, line, dim, nullptr, Norm2Accumulator<4>{x}, "NORM2");
+  return GetTotalReduction<TypeCategory::Real, 4>(x, source, line, dim, nullptr,
+      Norm2AccumulatorGetter<4>::create(x), "NORM2");
 }
 CppTypeFor<TypeCategory::Real, 8> RTDEF(Norm2_8)(
     const Descriptor &x, const char *source, int line, int dim) {
-  return GetTotalReduction<TypeCategory::Real, 8>(
-      x, source, line, dim, nullptr, Norm2Accumulator<8>{x}, "NORM2");
+  return GetTotalReduction<TypeCategory::Real, 8>(x, source, line, dim, nullptr,
+      Norm2AccumulatorGetter<8>::create(x), "NORM2");
 }
 #if LDBL_MANT_DIG == 64
 CppTypeFor<TypeCategory::Real, 10> RTDEF(Norm2_10)(
     const Descriptor &x, const char *source, int line, int dim) {
-  return GetTotalReduction<TypeCategory::Real, 10>(
-      x, source, line, dim, nullptr, Norm2Accumulator<10>{x}, "NORM2");
+  return GetTotalReduction<TypeCategory::Real, 10>(x, source, line, dim,
+      nullptr, Norm2AccumulatorGetter<10>::create(x), "NORM2");
 }
 #endif
 #if LDBL_MANT_DIG == 113
+// The __float128 implementation resides in FortranFloat128Math library.
 CppTypeFor<TypeCategory::Real, 16> RTDEF(Norm2_16)(
     const Descriptor &x, const char *source, int line, int dim) {
-  return GetTotalReduction<TypeCategory::Real, 16>(
-      x, source, line, dim, nullptr, Norm2Accumulator<16>{x}, "NORM2");
+  return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim,
+      nullptr, Norm2AccumulatorGetter<16>::create(x), "NORM2");
 }
 #endif
 
@@ -901,7 +818,7 @@ void RTDEF(Norm2Dim)(Descriptor &result, const Descriptor &x, int dim,
   auto type{x.type().GetCategoryAndKind()};
   RUNTIME_CHECK(terminator, type);
   if (type->first == TypeCategory::Real) {
-    ApplyFloatingPointKind<Norm2Helper, void>(
+    ApplyFloatingPointKind<Norm2Helper, void, true>(
         type->second, terminator, result, x, dim, nullptr, terminator);
   } else {
     terminator.Crash("NORM2: bad type code %d", x.type().raw());
diff --git a/flang/runtime/reduction-templates.h b/flang/runtime/reduction-templates.h
index 7d0f82d59a084d..0891bc021ff753 100644
--- a/flang/runtime/reduction-templates.h
+++ b/flang/runtime/reduction-templates.h
@@ -25,6 +25,7 @@
 #include "tools.h"
 #include "flang/Runtime/cpp-type.h"
 #include "flang/Runtime/descriptor.h"
+#include <algorithm>
 
 namespace Fortran::runtime {
 
@@ -332,5 +333,119 @@ template <typename ACCUMULATOR> struct PartialLocationHelper {
   };
 };
 
+// NORM2 templates
+
+RT_VAR_GROUP_BEGIN
+
+// Use at least double precision for accumulators.
+// Don't use __float128, it doesn't work with abs() or sqrt() yet.
+static constexpr RT_CONST_VAR_ATTRS int Norm2LargestLDKind {
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+  16
+#elif LDBL_MANT_DIG == 64
+  10
+#else
+  8
+#endif
+};
+
+RT_VAR_GROUP_END
+
+template <TypeCategory CAT, int KIND, typename ACCUMULATOR>
+inline RT_API_ATTRS void DoMaxMinNorm2(Descriptor &result, const Descriptor &x,
+    int dim, const Descriptor *mask, const char *intrinsic,
+    Terminator &terminator) {
+  using Type = CppTypeFor<CAT, KIND>;
+  ACCUMULATOR accumulator{x};
+  if (dim == 0 || x.rank() == 1) {
+    // Total reduction
+
+    // Element size of the destination descriptor is the same
+    // as the element size of the source.
+    result.Establish(x.type(), x.ElementBytes(), nullptr, 0, nullptr,
+        CFI_attribute_allocatable);
+    if (int stat{result.Allocate()}) {
+      terminator.Crash(
+          "%s: could not allocate memory for result; STAT=%d", intrinsic, stat);
+    }
+    DoTotalReduction<Type>(x, dim, mask, accumulator, intrinsic, terminator);
+    accumulator.GetResult(result.OffsetElement<Type>());
+  } else {
+    // Partial reduction
+
+    // Element size of the destination descriptor is the same
+    // as the element size of the source.
+    PartialReduction<ACCUMULATOR, CAT, KIND>(result, x, x.ElementBytes(), dim,
+        mask, terminator, intrinsic, accumulator);
+  }
+}
+
+// The data type used by Norm2Accumulator.
+template <int KIND>
+using Norm2AccumType =
+    CppTypeFor<TypeCategory::Real, std::clamp(KIND, 8, Norm2LargestLDKind)>;
+
+template <int KIND, typename ABS, typename SQRT> class Norm2Accumulator {
+public:
+  using Type = CppTypeFor<TypeCategory::Real, KIND>;
+  using AccumType = Norm2AccumType<KIND>;
+  explicit RT_API_ATTRS Norm2Accumulator(const Descriptor &array)
+      : array_{array} {}
+  RT_API_ATTRS void Reinitialize() { max_ = sum_ = 0; }
+  template <typename A>
+  RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
+    // m * sqrt(1 + sum((others(:)/m)**2))
+    *p = static_cast<Type>(max_ * SQRT::compute(1 + sum_));
+  }
+  RT_API_ATTRS bool Accumulate(Type x) {
+    auto absX{ABS::compute(static_cast<AccumType>(x))};
+    if (!max_) {
+      max_ = absX;
+    } else if (absX > max_) {
+      auto t{max_ / absX}; // < 1.0
+      auto tsq{t * t};
+      sum_ *= tsq; // scale sum to reflect change to the max
+      sum_ += tsq; // include a term for the previous max
+      max_ = absX;
+    } else { // absX <= max_
+      auto t{absX / max_};
+      sum_ += t * t;
+    }
+    return true;
+  }
+  template <typename A>
+  RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
+    return Accumulate(*array_.Element<A>(at));
+  }
+
+private:
+  const Descriptor &array_;
+  AccumType max_{0}; // value (m) with largest magnitude
+  AccumType sum_{0}; // sum((others(:)/m)**2)
+};
+
+// Helper class for creating Norm2Accumulator instance
+// based on the given KIND. This helper returns and instance
+// that uses std::abs and std::sqrt for the computations.
+template <int KIND> class Norm2AccumulatorGetter {
+  using AccumType = Norm2AccumType<KIND>;
+
+public:
+  struct ABSTy {
+    static constexpr RT_API_ATTRS AccumType compute(AccumType &&x) {
+      return std::abs(std::forward<AccumType>(x));
+    }
+  };
+  struct SQRTTy {
+    static constexpr RT_API_ATTRS AccumType compute(AccumType &&x) {
+      return std::sqrt(std::forward<AccumType>(x));
+    }
+  };
+
+  using Type = Norm2Accumulator<KIND, ABSTy, SQRTTy>;
+
+  static RT_API_ATTRS Type create(const Descriptor &x) { return Type(x); }
+};
+
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_REDUCTION_TEMPLATES_H_
diff --git a/flang/runtime/tools.h b/flang/runtime/tools.h
index 89e5069995748b..c1f89cadca06e7 100644
--- a/flang/runtime/tools.h
+++ b/flang/runtime/tools.h
@@ -266,7 +266,8 @@ inline RT_API_ATTRS RESULT ApplyIntegerKind(
   }
 }
 
-template <template <int KIND> class FUNC, typename RESULT, typename... A>
+template <template <int KIND> class FUNC, typename RESULT,
+    bool NEEDSMATH = false, typename... A>
 inline RT_API_ATTRS RESULT ApplyFloatingPointKind(
     int kind, Terminator &terminator, A &&...x) {
   switch (kind) {
@@ -287,7 +288,13 @@ inline RT_API_ATTRS RESULT ApplyFloatingPointKind(
     break;
   case 16:
     if constexpr (HasCppTypeFor<TypeCategory::Real, 16>) {
-      return FUNC<16>{}(std::forward<A>(x)...);
+      // If FUNC implemenation relies on FP math functions,
+      // then we should not be here. The compiler should have
+      // generated a call to an entry in FortranFloat128Math
+      // library.
+      if constexpr (!NEEDSMATH) {
+        return FUNC<16>{}(std::forward<A>(x)...);
+      }
     }
     break;
   }
diff --git a/flang/test/Lower/Intrinsics/norm2.f90 b/flang/test/Lower/Intrinsics/norm2.f90
index f14cad59d5bd3b..0d125e36f6650a 100644
--- a/flang/test/Lower/Intrinsics/norm2.f90
+++ b/flang/test/Lower/Intrinsics/norm2.f90
@@ -76,3 +76,19 @@ subroutine norm2_test_dim_3(a,r)
   ! CHECK-DAG:  %[[addr:.*]] = fir.box_addr %[[box]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.heap<!fir.array<?x?xf32>>
   ! CHECK-DAG:  fir.freemem %[[addr]]
 end subroutine norm2_test_dim_3
+
+! CHECK-LABEL: func @_QPnorm2_test_real16(
+! CHECK-SAME: %[[arg0:.*]]: !fir.box<!fir.array<?x?x?xf128>>{{.*}}, %[[arg1:.*]]: !fir.box<!fir.array<?x?xf128>>{{.*}})
+subroutine norm2_test_real16(a,r)
+  real(16) :: a(:,:,:)
+  real(16) :: r(:,:)
+  ! CHECK-DAG:  %[[dim:.*]] = arith.constant 3 : i32
+  ! CHECK-DAG:  %[[r:.*]] = fir.alloca !fir.box<!fir.heap<!fir.array<?x?xf128>>>
+  ! CHECK-DAG:  %[[res:.*]] = fir.convert %[[r]] : (!fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf128>>>>) -> !fir.ref<!fir.box<none>>
+  ! CHECK:  %[[arr:.*]] = fir.convert %[[arg0]] : (!fir.box<!fir.array<?x?x?xf128>>) -> !fir.box<none>
+  r = norm2(a,dim=3)
+  ! CHECK:  %{{.*}} = fir.call @_FortranANorm2DimReal16(%[[res]], %[[arr]], %[[dim]], %{{.*}}, %{{.*}}) {{.*}} : (!fir.ref<!fir.box<none>>, !fir.box<none>, i32, !fir.ref<i8>, i32) -> none
+  ! CHECK:  %[[box:.*]] = fir.load %[[r]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?x?xf128>>>>
+  ! CHECK-DAG:  %[[addr:.*]] = fir.box_addr %[[box]] : (!fir.box<!fir.heap<!fir.array<?x?xf128>>>) -> !fir.heap<!fir.array<?x?xf128>>
+  ! CHECK-DAG:  fir.freemem %[[addr]]
+end subroutine norm2_test_real16

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +159 to +160
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would advise using fir::runtime::getModel<const char*> and fir::runtime::getModel<int> for strTy and intTy too to make any future work enabling cross compilation easier/centralize the work.

If this is the way it is done for other Forced... helper, you can keep it that way and we should update all usages in one go.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! I created enhancement request #83290 for this. We use direct types creation all over the place.

@vzakhari vzakhari merged commit baf6725 into llvm:main Feb 28, 2024
mylai-mtk pushed a commit to mylai-mtk/llvm-project that referenced this pull request Jul 12, 2024
…lib. (llvm#83219)

Changed the lowering to call Norm2DimReal16 for REAL(16).
Added the corresponding entry point to FortranFloat128Math,
which required some restructuring in the related templates.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:runtime flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants