[flang][runtime] Split MATMUL[_TRANSPOSE] into separate entries. #97406

Merged: 1 commit, Jul 3, 2024

Conversation

vzakhari
Contributor

@vzakhari vzakhari commented Jul 2, 2024

Device compilation is much faster for separate MATMUL[_TRANSPOSE]
entries than for a single entry that covers all data types.
The lowering changes and the removal of the generic entries will follow.
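
To make the contrast concrete, here is a minimal sketch, not taken from the flang runtime, of a single generic entry versus separate per-type entries. `Tag`, `SumKernel`, `SumGeneric`, `SumInt32`, and `SumFloat64` are hypothetical names invented for the illustration. The description above does not say why device compilation gets faster; the sketch only shows the structural difference: the generic entry's body references every kernel instantiation, while each separate entry references exactly one.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical runtime type tag (a stand-in for a descriptor's type code).
enum class Tag { Int32, Float64 };

// One templated kernel, instantiated once per element type.
template <typename T> double SumKernel(const void *data, std::size_t n) {
  const T *p = static_cast<const T *>(data);
  double sum = 0.0;
  for (std::size_t i = 0; i < n; ++i) {
    sum += static_cast<double>(p[i]);
  }
  return sum;
}

// Single generic entry: dispatches at run time, so compiling this one
// function requires every kernel instantiation it can reach.
double SumGeneric(Tag tag, const void *data, std::size_t n) {
  switch (tag) {
  case Tag::Int32:
    return SumKernel<std::int32_t>(data, n);
  case Tag::Float64:
    return SumKernel<double>(data, n);
  }
  return 0.0;
}

// Separate entries: each one pulls in a single kernel instantiation.
double SumInt32(const void *data, std::size_t n) {
  return SumKernel<std::int32_t>(data, n);
}
double SumFloat64(const void *data, std::size_t n) {
  return SumKernel<double>(data, n);
}
```

With separate entries, the caller (here, the compiler's lowering) has to pick the right entry from the statically known operand types instead of passing a tag.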

@vzakhari vzakhari requested a review from klausler July 2, 2024 12:06
@llvmbot llvmbot added flang:runtime flang Flang issues not falling into any other category labels Jul 2, 2024
@llvmbot
Member

llvmbot commented Jul 2, 2024

@llvm/pr-subscribers-flang-runtime

Author: Slava Zakharin (vzakhari)

Changes

Device compilation is much faster for separate MATMUL[_TRANSPOSE]
entries than for a single entry that covers all data types.
The lowering changes and the removal of the generic entries will follow.


Patch is 39.66 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97406.diff

7 Files Affected:

  • (added) flang/include/flang/Runtime/matmul-instances.inc (+261)
  • (modified) flang/include/flang/Runtime/matmul-transpose.h (+17)
  • (modified) flang/include/flang/Runtime/matmul.h (+17)
  • (modified) flang/runtime/matmul-transpose.cpp (+42)
  • (modified) flang/runtime/matmul.cpp (+48-2)
  • (modified) flang/unittests/Runtime/Matmul.cpp (+121)
  • (modified) flang/unittests/Runtime/MatmulTranspose.cpp (+140)
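
The new `matmul-instances.inc` file in the patch below is included several times, each time with `MATMUL_INSTANCE` and `MATMUL_DIRECT_INSTANCE` defined differently (declarations in the headers, definitions in the `.cpp` files). A minimal, self-contained sketch of that pattern follows. `FOREACH_TOY_TYPE_PAIR`, `TOY_DECLARE`, `TOY_DEFINE`, and the `Toy...` functions are hypothetical names for illustration only, and the list is kept in the same file rather than in a separate `.inc`.

```cpp
#include <cassert>
#include <string>

// The list of type pairs, written once. Each item is handed to whatever
// per-item macro the user passes in.
#define FOREACH_TOY_TYPE_PAIR(macro) \
  macro(Integer, 4, Real, 8)         \
  macro(Real, 4, Real, 4)

// First expansion: one declaration per pair. Token pasting (##) builds
// names such as ToyInteger4Real8.
#define TOY_DECLARE(XCAT, XKIND, YCAT, YKIND) \
  std::string Toy##XCAT##XKIND##YCAT##YKIND();
FOREACH_TOY_TYPE_PAIR(TOY_DECLARE)
#undef TOY_DECLARE

// Second expansion: one definition per pair. Stringizing (#) lets each
// function report which pair it was generated for.
#define TOY_DEFINE(XCAT, XKIND, YCAT, YKIND)    \
  std::string Toy##XCAT##XKIND##YCAT##YKIND() { \
    return #XCAT #XKIND #YCAT #YKIND;           \
  }
FOREACH_TOY_TYPE_PAIR(TOY_DEFINE)
#undef TOY_DEFINE
```

Keeping the list in one place means a new type pair is added once and appears in every expansion, which is the same reason the patch ends the `.inc` file by undefining both per-item macros so the next includer must define them again.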
diff --git a/flang/include/flang/Runtime/matmul-instances.inc b/flang/include/flang/Runtime/matmul-instances.inc
new file mode 100644
index 0000000000000..970b03339cd5e
--- /dev/null
+++ b/flang/include/flang/Runtime/matmul-instances.inc
@@ -0,0 +1,261 @@
+//===-- include/flang/Runtime/matmul-instances.inc --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Helper macros to instantiate MATMUL/MATMUL_TRANSPOSE definitions
+// for different data types of the input arguments.
+//===----------------------------------------------------------------------===//
+
+#ifndef MATMUL_INSTANCE
+#error "Define MATMUL_INSTANCE before including this file"
+#endif
+
+#ifndef MATMUL_DIRECT_INSTANCE
+#error "Define MATMUL_DIRECT_INSTANCE before including this file"
+#endif
+
+// clang-format off
+
+#define FOREACH_MATMUL_TYPE_PAIR(macro)         \
+  macro(Integer, 1, Integer, 1)                 \
+  macro(Integer, 1, Integer, 2)                 \
+  macro(Integer, 1, Integer, 4)                 \
+  macro(Integer, 1, Integer, 8)                 \
+  macro(Integer, 2, Integer, 1)                 \
+  macro(Integer, 2, Integer, 2)                 \
+  macro(Integer, 2, Integer, 4)                 \
+  macro(Integer, 2, Integer, 8)                 \
+  macro(Integer, 4, Integer, 1)                 \
+  macro(Integer, 4, Integer, 2)                 \
+  macro(Integer, 4, Integer, 4)                 \
+  macro(Integer, 4, Integer, 8)                 \
+  macro(Integer, 8, Integer, 1)                 \
+  macro(Integer, 8, Integer, 2)                 \
+  macro(Integer, 8, Integer, 4)                 \
+  macro(Integer, 8, Integer, 8)                 \
+  macro(Integer, 1, Real, 4)                    \
+  macro(Integer, 1, Real, 8)                    \
+  macro(Integer, 2, Real, 4)                    \
+  macro(Integer, 2, Real, 8)                    \
+  macro(Integer, 4, Real, 4)                    \
+  macro(Integer, 4, Real, 8)                    \
+  macro(Integer, 8, Real, 4)                    \
+  macro(Integer, 8, Real, 8)                    \
+  macro(Integer, 1, Complex, 4)                 \
+  macro(Integer, 1, Complex, 8)                 \
+  macro(Integer, 2, Complex, 4)                 \
+  macro(Integer, 2, Complex, 8)                 \
+  macro(Integer, 4, Complex, 4)                 \
+  macro(Integer, 4, Complex, 8)                 \
+  macro(Integer, 8, Complex, 4)                 \
+  macro(Integer, 8, Complex, 8)                 \
+  macro(Real, 4, Integer, 1)                    \
+  macro(Real, 4, Integer, 2)                    \
+  macro(Real, 4, Integer, 4)                    \
+  macro(Real, 4, Integer, 8)                    \
+  macro(Real, 8, Integer, 1)                    \
+  macro(Real, 8, Integer, 2)                    \
+  macro(Real, 8, Integer, 4)                    \
+  macro(Real, 8, Integer, 8)                    \
+  macro(Real, 4, Real, 4)                       \
+  macro(Real, 4, Real, 8)                       \
+  macro(Real, 8, Real, 4)                       \
+  macro(Real, 8, Real, 8)                       \
+  macro(Real, 4, Complex, 4)                    \
+  macro(Real, 4, Complex, 8)                    \
+  macro(Real, 8, Complex, 4)                    \
+  macro(Real, 8, Complex, 8)                    \
+  macro(Complex, 4, Integer, 1)                 \
+  macro(Complex, 4, Integer, 2)                 \
+  macro(Complex, 4, Integer, 4)                 \
+  macro(Complex, 4, Integer, 8)                 \
+  macro(Complex, 8, Integer, 1)                 \
+  macro(Complex, 8, Integer, 2)                 \
+  macro(Complex, 8, Integer, 4)                 \
+  macro(Complex, 8, Integer, 8)                 \
+  macro(Complex, 4, Real, 4)                    \
+  macro(Complex, 4, Real, 8)                    \
+  macro(Complex, 8, Real, 4)                    \
+  macro(Complex, 8, Real, 8)                    \
+  macro(Complex, 4, Complex, 4)                 \
+  macro(Complex, 4, Complex, 8)                 \
+  macro(Complex, 8, Complex, 4)                 \
+  macro(Complex, 8, Complex, 8)                 \
+
+FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
+
+#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro)      \
+  macro(Integer, 16, Integer, 1)                        \
+  macro(Integer, 16, Integer, 2)                        \
+  macro(Integer, 16, Integer, 4)                        \
+  macro(Integer, 16, Integer, 8)                        \
+  macro(Integer, 16, Integer, 16)                       \
+  macro(Integer, 16, Real, 4)                           \
+  macro(Integer, 16, Real, 8)                           \
+  macro(Integer, 16, Complex, 4)                        \
+  macro(Integer, 16, Complex, 8)                        \
+  macro(Real, 4, Integer, 16)                           \
+  macro(Real, 8, Integer, 16)                           \
+  macro(Complex, 4, Integer, 16)                        \
+  macro(Complex, 8, Integer, 16)                        \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
+
+#if LDBL_MANT_DIG == 64
+MATMUL_INSTANCE(Integer, 16, Real, 10)
+MATMUL_INSTANCE(Integer, 16, Complex, 10)
+MATMUL_INSTANCE(Real, 10, Integer, 16)
+MATMUL_INSTANCE(Complex, 10, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
+#endif
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+MATMUL_INSTANCE(Integer, 16, Real, 16)
+MATMUL_INSTANCE(Integer, 16, Complex, 16)
+MATMUL_INSTANCE(Real, 16, Integer, 16)
+MATMUL_INSTANCE(Complex, 16, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 16)
+MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
+#endif
+#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+
+#if LDBL_MANT_DIG == 64
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro)         \
+  macro(Integer, 1, Real, 10)                               \
+  macro(Integer, 1, Complex, 10)                            \
+  macro(Integer, 2, Real, 10)                               \
+  macro(Integer, 2, Complex, 10)                            \
+  macro(Integer, 4, Real, 10)                               \
+  macro(Integer, 4, Complex, 10)                            \
+  macro(Integer, 8, Real, 10)                               \
+  macro(Integer, 8, Complex, 10)                            \
+  macro(Real, 4, Real, 10)                                  \
+  macro(Real, 4, Complex, 10)                               \
+  macro(Real, 8, Real, 10)                                  \
+  macro(Real, 8, Complex, 10)                               \
+  macro(Real, 10, Integer, 1)                               \
+  macro(Real, 10, Integer, 2)                               \
+  macro(Real, 10, Integer, 4)                               \
+  macro(Real, 10, Integer, 8)                               \
+  macro(Real, 10, Real, 4)                                  \
+  macro(Real, 10, Real, 8)                                  \
+  macro(Real, 10, Real, 10)                                 \
+  macro(Real, 10, Complex, 4)                               \
+  macro(Real, 10, Complex, 8)                               \
+  macro(Real, 10, Complex, 10)                              \
+  macro(Complex, 4, Real, 10)                               \
+  macro(Complex, 4, Complex, 10)                            \
+  macro(Complex, 8, Real, 10)                               \
+  macro(Complex, 8, Complex, 10)                            \
+  macro(Complex, 10, Integer, 1)                            \
+  macro(Complex, 10, Integer, 2)                            \
+  macro(Complex, 10, Integer, 4)                            \
+  macro(Complex, 10, Integer, 8)                            \
+  macro(Complex, 10, Real, 4)                               \
+  macro(Complex, 10, Real, 8)                               \
+  macro(Complex, 10, Real, 10)                              \
+  macro(Complex, 10, Complex, 4)                            \
+  macro(Complex, 10, Complex, 8)                            \
+  macro(Complex, 10, Complex, 10)                           \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)
+
+#if HAS_FLOAT128
+MATMUL_INSTANCE(Real, 10, Real, 16)
+MATMUL_INSTANCE(Real, 10, Complex, 16)
+MATMUL_INSTANCE(Real, 16, Real, 10)
+MATMUL_INSTANCE(Real, 16, Complex, 10)
+MATMUL_INSTANCE(Complex, 10, Real, 16)
+MATMUL_INSTANCE(Complex, 10, Complex, 16)
+MATMUL_INSTANCE(Complex, 16, Real, 10)
+MATMUL_INSTANCE(Complex, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Real, 10, Real, 16)
+MATMUL_DIRECT_INSTANCE(Real, 10, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Real, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Real, 16, Complex, 10)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Real, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
+MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
+#endif
+#endif // LDBL_MANT_DIG == 64
+
+#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro)         \
+  macro(Integer, 1, Real, 16)                               \
+  macro(Integer, 1, Complex, 16)                            \
+  macro(Integer, 2, Real, 16)                               \
+  macro(Integer, 2, Complex, 16)                            \
+  macro(Integer, 4, Real, 16)                               \
+  macro(Integer, 4, Complex, 16)                            \
+  macro(Integer, 8, Real, 16)                               \
+  macro(Integer, 8, Complex, 16)                            \
+  macro(Real, 4, Real, 16)                                  \
+  macro(Real, 4, Complex, 16)                               \
+  macro(Real, 8, Real, 16)                                  \
+  macro(Real, 8, Complex, 16)                               \
+  macro(Real, 16, Integer, 1)                               \
+  macro(Real, 16, Integer, 2)                               \
+  macro(Real, 16, Integer, 4)                               \
+  macro(Real, 16, Integer, 8)                               \
+  macro(Real, 16, Real, 4)                                  \
+  macro(Real, 16, Real, 8)                                  \
+  macro(Real, 16, Real, 16)                                 \
+  macro(Real, 16, Complex, 4)                               \
+  macro(Real, 16, Complex, 8)                               \
+  macro(Real, 16, Complex, 16)                              \
+  macro(Complex, 4, Real, 16)                               \
+  macro(Complex, 4, Complex, 16)                            \
+  macro(Complex, 8, Real, 16)                               \
+  macro(Complex, 8, Complex, 16)                            \
+  macro(Complex, 16, Integer, 1)                            \
+  macro(Complex, 16, Integer, 2)                            \
+  macro(Complex, 16, Integer, 4)                            \
+  macro(Complex, 16, Integer, 8)                            \
+  macro(Complex, 16, Real, 4)                               \
+  macro(Complex, 16, Real, 8)                               \
+  macro(Complex, 16, Real, 16)                              \
+  macro(Complex, 16, Complex, 4)                            \
+  macro(Complex, 16, Complex, 8)                            \
+  macro(Complex, 16, Complex, 16)                           \
+
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
+FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
+#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
+
+#define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
+  macro(Logical, 1, Logical, 1)                 \
+  macro(Logical, 1, Logical, 2)                 \
+  macro(Logical, 1, Logical, 4)                 \
+  macro(Logical, 1, Logical, 8)                 \
+  macro(Logical, 2, Logical, 1)                 \
+  macro(Logical, 2, Logical, 2)                 \
+  macro(Logical, 2, Logical, 4)                 \
+  macro(Logical, 2, Logical, 8)                 \
+  macro(Logical, 4, Logical, 1)                 \
+  macro(Logical, 4, Logical, 2)                 \
+  macro(Logical, 4, Logical, 4)                 \
+  macro(Logical, 4, Logical, 8)                 \
+  macro(Logical, 8, Logical, 1)                 \
+  macro(Logical, 8, Logical, 2)                 \
+  macro(Logical, 8, Logical, 4)                 \
+  macro(Logical, 8, Logical, 8)                 \
+
+FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_INSTANCE)
+FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
+
+#undef MATMUL_INSTANCE
+#undef MATMUL_DIRECT_INSTANCE
+
+// clang-format on
diff --git a/flang/include/flang/Runtime/matmul-transpose.h b/flang/include/flang/Runtime/matmul-transpose.h
index 5eb5896972e0f..d0a5005a1a8bd 100644
--- a/flang/include/flang/Runtime/matmul-transpose.h
+++ b/flang/include/flang/Runtime/matmul-transpose.h
@@ -10,6 +10,8 @@
 
 #ifndef FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
 #define FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
+#include "flang/Common/float128.h"
+#include "flang/Common/uint128.h"
 #include "flang/Runtime/entry-names.h"
 namespace Fortran::runtime {
 class Descriptor;
@@ -25,6 +27,21 @@ void RTDECL(MatmulTranspose)(Descriptor &, const Descriptor &,
 // and have a valid base address.
 void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
     const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+
+// MATMUL(TRANSPOSE()) versions specialized by the categories of the operand
+// types. The KIND and shape information is taken from the argument's
+// descriptors.
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDECL(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line);
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDECL(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
+      Descriptor & result, const Descriptor &x, const Descriptor &y, \
+      const char *sourceFile, int line);
+
+#include "matmul-instances.inc"
+
 } // extern "C"
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
diff --git a/flang/include/flang/Runtime/matmul.h b/flang/include/flang/Runtime/matmul.h
index 40581d44de9e2..1a5e39eb8813f 100644
--- a/flang/include/flang/Runtime/matmul.h
+++ b/flang/include/flang/Runtime/matmul.h
@@ -10,6 +10,8 @@
 
 #ifndef FORTRAN_RUNTIME_MATMUL_H_
 #define FORTRAN_RUNTIME_MATMUL_H_
+#include "flang/Common/float128.h"
+#include "flang/Common/uint128.h"
 #include "flang/Runtime/entry-names.h"
 namespace Fortran::runtime {
 class Descriptor;
@@ -24,6 +26,21 @@ void RTDECL(Matmul)(Descriptor &, const Descriptor &, const Descriptor &,
 // and have a valid base address.
 void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
     const Descriptor &, const char *sourceFile = nullptr, int line = 0);
+
+// MATMUL versions specialized by the categories of the operand types.
+// The KIND and shape information is taken from the argument's
+// descriptors.
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDECL(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line);
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDECL(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line);
+
+#include "matmul-instances.inc"
+
 } // extern "C"
 } // namespace Fortran::runtime
 #endif // FORTRAN_RUNTIME_MATMUL_H_
diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp
index a12d188266f7c..1c998fa8cf6c1 100644
--- a/flang/runtime/matmul-transpose.cpp
+++ b/flang/runtime/matmul-transpose.cpp
@@ -384,6 +384,30 @@ template <bool IS_ALLOCATING> struct MatmulTranspose {
         x, y, terminator, yCatKind->first, yCatKind->second);
   }
 };
+
+template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
+    int YKIND>
+struct MatmulTransposeHelper {
+  using ResultDescriptor =
+      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
+      const Descriptor &y, const char *sourceFile, int line) const {
+    Terminator terminator{sourceFile, line};
+    auto xCatKind{x.type().GetCategoryAndKind()};
+    auto yCatKind{y.type().GetCategoryAndKind()};
+    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
+    RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
+    RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
+    if constexpr (constexpr auto resultType{
+                      GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
+      return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
+          resultType->second, CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
+          result, x, y, terminator);
+    }
+    terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
+        static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
+  }
+};
 } // namespace
 
 namespace Fortran::runtime {
@@ -399,6 +423,24 @@ void RTDEF(MatmulTransposeDirect)(const Descriptor &result, const Descriptor &x,
   MatmulTranspose<false>{}(result, x, y, sourceFile, line);
 }
 
+#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDEF(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
+      const Descriptor &x, const Descriptor &y, const char *sourceFile, \
+      int line) { \
+    MatmulTransposeHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
+        YKIND>{}(result, x, y, sourceFile, line); \
+  }
+
+#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
+  void RTDEF(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
+      Descriptor & result, const Descriptor &x, const Descriptor &y, \
+      const char *sourceFile, int line) { \
+    MatmulTransposeHelper<false, TypeCategory::XCAT, XKIND, \
+        TypeCategory::YCAT, YKIND>{}(result, x, y, sourceFile, line); \
+  }
+
+#include "flang/Runtime/matmul-instances.inc"
+
 RT_EXT_API_GROUP_END
 } // extern "C"
 } // namespace Fortran::runtime
diff --git a/flang/runtime/matmul.cpp b/flang/runtime/matmul.cpp
index 8f9b50a549e1f..504d1aa4dc4a4 100644
--- a/flang/runtime/matmul.cpp
+++ b/flang/runtime/matmul.cpp
@@ -28,7 +28,8 @@
 #include "flang/Runtime/descriptor.h"
 #include <cstring>
 
-namespace Fortran::runtime {
+namespace {
+using namespace Fortran::runtime;
 
 // Suppress the warnings about calling __host__-only std::complex operators,
 // defined in C++ STD header files, from __device__ code.
@@ -455,7 +456,8 @@ template <bool IS_ALLOCATING> struct Matmul {
           Terminator &terminator) const {
         if constexpr (constexpr auto resultType{
                           GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
-          if constexpr (common::IsNumericTypeCategory(resultType->first) ||
+          if constexpr (Fortran::common::IsNumericTypeCategory(
+                            resultType->first) ||
               resultType->first == TypeCategory::Logical) {
             return DoMatmul<IS_ALLOCATING, resultType->first,
                 resultType->second, CppTypeFor<XCAT, XKIND>,
@@ -483,6 +485,32 @@ template <bool IS_ALLOCATING> struct Matmul {
   }
 };
 
+template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
+    int YKIND>
+struct MatmulHelper {
+  using ResultDescriptor =
+      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
+  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
+      const Descriptor &y, const char *sourceFile, int line) const {
+    Terminator terminator{sourceFile, line};
+    auto xCatKind{x.type().GetCategoryAndKind()};
+    auto yCatKind{y...
[truncated]
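
The `MatmulHelper` body is cut off above, so the following sketch is modeled on the fully visible `MatmulTransposeHelper` instead. It shows the same two-step shape: check at run time that the operand categories match the template arguments, then use a compile-time result type to select the kernel. Everything here is a simplified assumption: `Cat`, `Operand`, `ResultType`, and `Helper` are hypothetical stand-ins, the result-type rule is a toy and not flang's `GetResultType`, and exceptions replace `RUNTIME_CHECK` and `Terminator::Crash`. It needs C++17.

```cpp
#include <cassert>
#include <optional>
#include <stdexcept>
#include <utility>

enum class Cat { Integer, Real, Logical };

// Hypothetical stand-in for a descriptor: only the category and kind.
struct Operand {
  Cat cat;
  int kind;
};

// Toy result-type rule (an assumption, not the runtime's real rule).
constexpr std::optional<std::pair<Cat, int>> ResultType(
    Cat x, int xKind, Cat y, int yKind) {
  const int maxKind = xKind > yKind ? xKind : yKind;
  if (x == Cat::Logical && y == Cat::Logical) {
    return std::pair<Cat, int>{Cat::Logical, maxKind};
  }
  if (x == Cat::Logical || y == Cat::Logical) {
    return std::nullopt; // mixing LOGICAL with a numeric type is rejected
  }
  if (x == Cat::Real && y == Cat::Real) {
    return std::pair<Cat, int>{Cat::Real, maxKind};
  }
  if (x == Cat::Real) {
    return std::pair<Cat, int>{Cat::Real, xKind};
  }
  if (y == Cat::Real) {
    return std::pair<Cat, int>{Cat::Real, yKind};
  }
  return std::pair<Cat, int>{Cat::Integer, maxKind};
}

template <Cat XCAT, int XKIND, Cat YCAT, int YKIND> struct Helper {
  // Returns the result category and kind; a real helper would run the
  // multiplication kernel for that type instead.
  std::pair<Cat, int> operator()(const Operand &x, const Operand &y) const {
    // Run-time check: as in the patch, only the categories are compared.
    if (x.cat != XCAT || y.cat != YCAT) {
      throw std::runtime_error("operand category mismatch");
    }
    // Compile-time selection: only the valid branch is instantiated.
    constexpr auto resultType = ResultType(XCAT, XKIND, YCAT, YKIND);
    if constexpr (resultType.has_value()) {
      return *resultType;
    } else {
      throw std::runtime_error("bad operand types");
    }
  }
};
```

Because the result type is a constant expression, an invalid pair compiles down to just the error path, and a valid pair instantiates a single kernel.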

@vzakhari vzakhari merged commit dd22085 into llvm:main Jul 3, 2024
10 checks passed
vzakhari added a commit to vzakhari/llvm-project that referenced this pull request Jul 3, 2024
Lower MATMUL to the new runtime entries added in llvm#97406.
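
The follow-up lowering change is not shown on this page. As a rough sketch of what selecting one of the new entries involves, the helper below assembles an entry name in the order used by the macros in this patch (`Matmul`, then optional `Transpose`, then optional `Direct`, then the two category/kind pairs). `MatmulEntryName` is a hypothetical function, and the string omits whatever prefix `RTDECL` adds to the final symbol, since that is not visible in this excerpt.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: builds the unprefixed entry name for a type pair,
// e.g. "MatmulInteger4Real8" or "MatmulTransposeDirectComplex8Complex8".
std::string MatmulEntryName(bool transpose, bool direct,
    const std::string &xCat, int xKind, const std::string &yCat, int yKind) {
  std::string name{"Matmul"};
  if (transpose) {
    name += "Transpose";
  }
  if (direct) {
    name += "Direct";
  }
  name += xCat + std::to_string(xKind) + yCat + std::to_string(yKind);
  return name;
}
```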
lravenclaw pushed a commit to lravenclaw/llvm-project that referenced this pull request Jul 3, 2024
…m#97406)

vzakhari added a commit that referenced this pull request Jul 4, 2024
Lower MATMUL to the new runtime entries added in #97406.
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
…m#97406)

kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
Lower MATMUL to the new runtime entries added in llvm#97406.