[flang] Lower MATMUL to type specific runtime calls. #97547

Merged
merged 2 commits into from
Jul 4, 2024

Conversation

vzakhari
Contributor

@vzakhari vzakhari commented Jul 3, 2024

Lower MATMUL to the new runtime entries added in #97406.

Lower MATMUL to the new runtime entries added in llvm#97406.
@vzakhari vzakhari requested a review from jeanPerier July 3, 2024 09:40
@llvmbot llvmbot added the flang:runtime, flang, and flang:fir-hlfir labels Jul 3, 2024
@llvmbot
Member

llvmbot commented Jul 3, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-runtime

Author: Slava Zakharin (vzakhari)

Changes

Lower MATMUL to the new runtime entries added in #97406.


Patch is 53.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97547.diff

15 Files Affected:

  • (modified) flang/include/flang/Optimizer/Support/Utils.h (+73-4)
  • (modified) flang/include/flang/Runtime/matmul-instances.inc (+14-9)
  • (modified) flang/include/flang/Runtime/matmul-transpose.h (+2)
  • (modified) flang/include/flang/Runtime/matmul.h (+2)
  • (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+8-7)
  • (modified) flang/lib/Optimizer/Builder/Runtime/Transformational.cpp (+92-4)
  • (modified) flang/runtime/matmul-transpose.cpp (+2-51)
  • (modified) flang/runtime/matmul.cpp (+2-51)
  • (modified) flang/test/HLFIR/matmul-lowering.fir (+3-3)
  • (modified) flang/test/HLFIR/mul_transpose.f90 (+3-3)
  • (modified) flang/test/Lower/Intrinsics/matmul.f90 (+2-2)
  • (modified) flang/unittests/Optimizer/Builder/Runtime/RuntimeCallTestBase.h (+9)
  • (modified) flang/unittests/Optimizer/Builder/Runtime/TransformationalTest.cpp (+34-8)
  • (modified) flang/unittests/Runtime/Matmul.cpp (-119)
  • (modified) flang/unittests/Runtime/MatmulTranspose.cpp (-131)
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index ae95a26be1d86..2ffb48335686c 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -84,9 +84,10 @@ inline std::string mlirTypeToString(mlir::Type type) {
   return result;
 }
 
-inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
-                                            mlir::Type type, mlir::Location loc,
-                                            const llvm::Twine &name) {
+inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
+                                              mlir::Type type,
+                                              mlir::Location loc,
+                                              const llvm::Twine &name) {
   if (type.isF16())
     return "REAL(KIND=2)";
   else if (type.isBF16())
@@ -123,6 +124,14 @@ inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
     return "COMPLEX(KIND=10)";
   else if (type == fir::ComplexType::get(builder.getContext(), 16))
     return "COMPLEX(KIND=16)";
+  else if (type == fir::LogicalType::get(builder.getContext(), 1))
+    return "LOGICAL(KIND=1)";
+  else if (type == fir::LogicalType::get(builder.getContext(), 2))
+    return "LOGICAL(KIND=2)";
+  else if (type == fir::LogicalType::get(builder.getContext(), 4))
+    return "LOGICAL(KIND=4)";
+  else if (type == fir::LogicalType::get(builder.getContext(), 8))
+    return "LOGICAL(KIND=8)";
   else
     fir::emitFatalError(loc, "unsupported type in " + name + ": " +
                                  fir::mlirTypeToString(type));
@@ -133,10 +142,70 @@ inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
                               const llvm::Twine &intrinsicName) {
   TODO(loc,
        "intrinsic: " +
-           fir::numericMlirTypeToFortran(builder, type, loc, intrinsicName) +
+           fir::mlirTypeToIntrinsicFortran(builder, type, loc, intrinsicName) +
            " in " + intrinsicName);
 }
 
+inline void intrinsicTypeTODO2(fir::FirOpBuilder &builder, mlir::Type type1,
+                               mlir::Type type2, mlir::Location loc,
+                               const llvm::Twine &intrinsicName) {
+  TODO(loc,
+       "intrinsic: {" +
+           fir::mlirTypeToIntrinsicFortran(builder, type1, loc, intrinsicName) +
+           ", " +
+           fir::mlirTypeToIntrinsicFortran(builder, type2, loc, intrinsicName) +
+           "} in " + intrinsicName);
+}
+
+inline std::pair<Fortran::common::TypeCategory, KindMapping::KindTy>
+mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
+  if (type.isF16())
+    return {Fortran::common::TypeCategory::Real, 2};
+  else if (type.isBF16())
+    return {Fortran::common::TypeCategory::Real, 3};
+  else if (type.isF32())
+    return {Fortran::common::TypeCategory::Real, 4};
+  else if (type.isF64())
+    return {Fortran::common::TypeCategory::Real, 8};
+  else if (type.isF80())
+    return {Fortran::common::TypeCategory::Real, 10};
+  else if (type.isF128())
+    return {Fortran::common::TypeCategory::Real, 16};
+  else if (type.isInteger(8))
+    return {Fortran::common::TypeCategory::Integer, 1};
+  else if (type.isInteger(16))
+    return {Fortran::common::TypeCategory::Integer, 2};
+  else if (type.isInteger(32))
+    return {Fortran::common::TypeCategory::Integer, 4};
+  else if (type.isInteger(64))
+    return {Fortran::common::TypeCategory::Integer, 8};
+  else if (type.isInteger(128))
+    return {Fortran::common::TypeCategory::Integer, 16};
+  else if (type == fir::ComplexType::get(loc.getContext(), 2))
+    return {Fortran::common::TypeCategory::Complex, 2};
+  else if (type == fir::ComplexType::get(loc.getContext(), 3))
+    return {Fortran::common::TypeCategory::Complex, 3};
+  else if (type == fir::ComplexType::get(loc.getContext(), 4))
+    return {Fortran::common::TypeCategory::Complex, 4};
+  else if (type == fir::ComplexType::get(loc.getContext(), 8))
+    return {Fortran::common::TypeCategory::Complex, 8};
+  else if (type == fir::ComplexType::get(loc.getContext(), 10))
+    return {Fortran::common::TypeCategory::Complex, 10};
+  else if (type == fir::ComplexType::get(loc.getContext(), 16))
+    return {Fortran::common::TypeCategory::Complex, 16};
+  else if (type == fir::LogicalType::get(loc.getContext(), 1))
+    return {Fortran::common::TypeCategory::Logical, 1};
+  else if (type == fir::LogicalType::get(loc.getContext(), 2))
+    return {Fortran::common::TypeCategory::Logical, 2};
+  else if (type == fir::LogicalType::get(loc.getContext(), 4))
+    return {Fortran::common::TypeCategory::Logical, 4};
+  else if (type == fir::LogicalType::get(loc.getContext(), 8))
+    return {Fortran::common::TypeCategory::Logical, 8};
+  else
+    fir::emitFatalError(loc,
+                        "unsupported type: " + fir::mlirTypeToString(type));
+}
+
 /// Find the fir.type_info that was created for this \p recordType in \p module,
 /// if any. \p  symbolTable can be provided to speed-up the lookup. This tool
 /// will match record type even if they have been "altered" in type conversion
diff --git a/flang/include/flang/Runtime/matmul-instances.inc b/flang/include/flang/Runtime/matmul-instances.inc
index 970b03339cd5e..32c6ab06d2521 100644
--- a/flang/include/flang/Runtime/matmul-instances.inc
+++ b/flang/include/flang/Runtime/matmul-instances.inc
@@ -17,6 +17,10 @@
 #error "Define MATMUL_DIRECT_INSTANCE before including this file"
 #endif
 
+#ifndef MATMUL_FORCE_ALL_TYPES
+#error "Define MATMUL_FORCE_ALL_TYPES to 0 or 1 before including this file"
+#endif
+
 // clang-format off
 
 #define FOREACH_MATMUL_TYPE_PAIR(macro)         \
@@ -88,7 +92,7 @@
 FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
 
-#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+#if MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
 #define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro)      \
   macro(Integer, 16, Integer, 1)                        \
   macro(Integer, 16, Integer, 2)                        \
@@ -107,7 +111,7 @@ FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
 
-#if LDBL_MANT_DIG == 64
+#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
 MATMUL_INSTANCE(Integer, 16, Real, 10)
 MATMUL_INSTANCE(Integer, 16, Complex, 10)
 MATMUL_INSTANCE(Real, 10, Integer, 16)
@@ -117,7 +121,7 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
 MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
 MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
 #endif
-#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
 MATMUL_INSTANCE(Integer, 16, Real, 16)
 MATMUL_INSTANCE(Integer, 16, Complex, 16)
 MATMUL_INSTANCE(Real, 16, Integer, 16)
@@ -127,9 +131,9 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
 MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
 MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
 #endif
-#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+#endif // MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
 
-#if LDBL_MANT_DIG == 64
+#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
 #define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro)         \
   macro(Integer, 1, Real, 10)                               \
   macro(Integer, 1, Complex, 10)                            \
@@ -171,7 +175,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
 FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)
 
-#if HAS_FLOAT128
+#if MATMUL_FORCE_ALL_TYPES || HAS_FLOAT128
 MATMUL_INSTANCE(Real, 10, Real, 16)
 MATMUL_INSTANCE(Real, 10, Complex, 16)
 MATMUL_INSTANCE(Real, 16, Real, 10)
@@ -189,9 +193,9 @@ MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
 MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
 MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
 #endif
-#endif // LDBL_MANT_DIG == 64
+#endif // MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
 
-#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
 #define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro)         \
   macro(Integer, 1, Real, 16)                               \
   macro(Integer, 1, Complex, 16)                            \
@@ -232,7 +236,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
 
 FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
 FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
-#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#endif // MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
 
 #define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
   macro(Logical, 1, Logical, 1)                 \
@@ -257,5 +261,6 @@ FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
 
 #undef MATMUL_INSTANCE
 #undef MATMUL_DIRECT_INSTANCE
+#undef MATMUL_FORCE_ALL_TYPES
 
 // clang-format on
diff --git a/flang/include/flang/Runtime/matmul-transpose.h b/flang/include/flang/Runtime/matmul-transpose.h
index d0a5005a1a8bd..2d79ca10e0895 100644
--- a/flang/include/flang/Runtime/matmul-transpose.h
+++ b/flang/include/flang/Runtime/matmul-transpose.h
@@ -40,6 +40,8 @@ void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
       Descriptor & result, const Descriptor &x, const Descriptor &y, \
       const char *sourceFile, int line);
 
+#define MATMUL_FORCE_ALL_TYPES 0
+
 #include "matmul-instances.inc"
 
 } // extern "C"
diff --git a/flang/include/flang/Runtime/matmul.h b/flang/include/flang/Runtime/matmul.h
index 1a5e39eb8813f..a72d4a06ee459 100644
--- a/flang/include/flang/Runtime/matmul.h
+++ b/flang/include/flang/Runtime/matmul.h
@@ -39,6 +39,8 @@ void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
       const Descriptor &x, const Descriptor &y, const char *sourceFile, \
       int line);
 
+#define MATMUL_FORCE_ALL_TYPES 0
+
 #include "matmul-instances.inc"
 
 } // extern "C"
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 8dd1904939f3e..a1cef7437fa2d 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -701,18 +701,19 @@ prettyPrintIntrinsicName(fir::FirOpBuilder &builder, mlir::Location loc,
   if (name == "pow") {
     assert(funcType.getNumInputs() == 2 && "power operator has two arguments");
     std::string displayName{" ** "};
-    sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
-                                        displayName)
+    sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
+                                          displayName)
             << displayName
-            << numericMlirTypeToFortran(builder, funcType.getInput(1), loc,
-                                        displayName);
+            << mlirTypeToIntrinsicFortran(builder, funcType.getInput(1), loc,
+                                          displayName);
   } else {
     sstream << name.upper() << "(";
     if (funcType.getNumInputs() > 0)
-      sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
-                                          name);
+      sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
+                                            name);
     for (mlir::Type argType : funcType.getInputs().drop_front()) {
-      sstream << ", " << numericMlirTypeToFortran(builder, argType, loc, name);
+      sstream << ", "
+              << mlirTypeToIntrinsicFortran(builder, argType, loc, name);
     }
     sstream << ")";
   }
diff --git a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
index 6d3d85e8df69f..8f08b01fe0097 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
@@ -329,11 +329,64 @@ void fir::runtime::genEoshiftVector(fir::FirOpBuilder &builder,
   builder.create<fir::CallOp>(loc, eoshiftFunc, args);
 }
 
+/// Define ForcedMatmul<ACAT><AKIND><BCAT><BKIND> models.
+struct ForcedMatmulTypeModel {
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *ctx) {
+      auto boxRefTy =
+          fir::runtime::getModel<Fortran::runtime::Descriptor &>()(ctx);
+      auto boxTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
+      auto strTy = fir::runtime::getModel<const char *>()(ctx);
+      auto intTy = fir::runtime::getModel<int>()(ctx);
+      auto voidTy = fir::runtime::getModel<void>()(ctx);
+      return mlir::FunctionType::get(
+          ctx, {boxRefTy, boxTy, boxTy, strTy, intTy}, {voidTy});
+    };
+  }
+};
+
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND)                              \
+  struct ForcedMatmul##ACAT##AKIND##BCAT##BKIND                                \
+      : public ForcedMatmulTypeModel {                                         \
+    static constexpr const char *name =                                        \
+        ExpandAndQuoteKey(RTNAME(Matmul##ACAT##AKIND##BCAT##BKIND));           \
+  };
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+
+#include "flang/Runtime/matmul-instances.inc"
+
 /// Generate call to Matmul intrinsic runtime routine.
 void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
                              mlir::Value resultBox, mlir::Value matrixABox,
                              mlir::Value matrixBBox) {
-  auto func = fir::runtime::getRuntimeFunc<mkRTKey(Matmul)>(loc, builder);
+  mlir::func::FuncOp func;
+  auto boxATy = matrixABox.getType();
+  auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
+  auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
+  auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
+  auto boxBTy = matrixBBox.getType();
+  auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
+  auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
+  auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);
+
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND)                              \
+  if (!func && aCat == TypeCategory::ACAT && aKind == AKIND &&                 \
+      bCat == TypeCategory::BCAT && bKind == BKIND) {                          \
+    func =                                                                     \
+        fir::runtime::getRuntimeFunc<ForcedMatmul##ACAT##AKIND##BCAT##BKIND>(  \
+            loc, builder);                                                     \
+  }
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+#include "flang/Runtime/matmul-instances.inc"
+
+  if (!func) {
+    fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc, "MATMUL");
+  }
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);
   auto sourceLine =
@@ -344,13 +397,48 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
   builder.create<fir::CallOp>(loc, func, args);
 }
 
-/// Generate call to MatmulTranspose intrinsic runtime routine.
+/// Define ForcedMatmulTranspose<ACAT><AKIND><BCAT><BKIND> models.
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND)                              \
+  struct ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND                       \
+      : public ForcedMatmulTypeModel {                                         \
+    static constexpr const char *name =                                        \
+        ExpandAndQuoteKey(RTNAME(MatmulTranspose##ACAT##AKIND##BCAT##BKIND));  \
+  };
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+
+#include "flang/Runtime/matmul-instances.inc"
+
 void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
                                       mlir::Location loc, mlir::Value resultBox,
                                       mlir::Value matrixABox,
                                       mlir::Value matrixBBox) {
-  auto func =
-      fir::runtime::getRuntimeFunc<mkRTKey(MatmulTranspose)>(loc, builder);
+  mlir::func::FuncOp func;
+  auto boxATy = matrixABox.getType();
+  auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
+  auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
+  auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
+  auto boxBTy = matrixBBox.getType();
+  auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
+  auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
+  auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);
+
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND)                              \
+  if (!func && aCat == TypeCategory::ACAT && aKind == AKIND &&                 \
+      bCat == TypeCategory::BCAT && bKind == BKIND) {                          \
+    func = fir::runtime::getRuntimeFunc<                                       \
+        ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND>(loc, builder);        \
+  }
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+#include "flang/Runtime/matmul-instances.inc"
+
+  if (!func) {
+    fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc,
+                            "MATMUL-TRANSPOSE");
+  }
   auto fTy = func.getFunctionType();
   auto sourceFile = fir::factory::locationToFilename(builder, loc);
   auto sourceLine =
diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp
index 1c998fa8cf6c1..283472650a1c6 100644
--- a/flang/runtime/matmul-transpose.cpp
+++ b/flang/runtime/matmul-transpose.cpp
@@ -343,48 +343,6 @@ inline static RT_API_ATTRS void DoMatmulTranspose(
 
 RT_DIAG_POP
 
-// Maps the dynamic type information from the arguments' descriptors
-// to the right instantiation of DoMatmul() for valid combinations of
-// types.
-template <bool IS_ALLOCATING> struct MatmulTranspose {
-  using ResultDescriptor =
-      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
-  template <TypeCategory XCAT, int XKIND> struct MM1 {
-    template <TypeCategory YCAT, int YKIND> struct MM2 {
-      RT_API_ATTRS void operator()(ResultDescriptor &result,
-          const Descriptor &x, const Descriptor &y,
-          Terminator &terminator) const {
-        if constexpr (constexpr auto resultType{
-                          GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
-          if constexpr (Fortran::common::IsNumericTypeCategory(
-                            resultType->first) ||
-              resultType->first == TypeCategory::Logical) {
-            return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
-                resultType->second, CppTypeFor<XCAT, XKIND>,
-                CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
-          }
-        }
-        terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
-            static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
-      }
-    };
-    RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
-        const Descriptor &y, Terminator &terminator, TypeCategory yCat,
-        int yKind) const {
-      ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
-    }
-  };
-  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
-      const Descriptor &y, const char *sourceFile, int line) const {
-    Terminator terminator{sourceFile, line};
-    auto xCatKind{x.type().GetCategoryAndKind()};
-    auto yCatKind{y.type().GetCategoryAndKind()};
-    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
-    ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
-        x, y, terminator, yCatKind->first, yCatKind->second);
-  }
-};
-
 template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
     int YKIND>
 struct MatmulTransposeHelper {
@@ -414,15 +372,6 @@ namespace Fortran::runtime {
 extern "C" {
 RT_EXT_API_GROUP_BEGIN
 
-void RTDEF(MatmulTranspose)(Descriptor &result, const Descriptor &x,
-    const Descriptor ...
[truncated]
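
The lowering change in Transformational.cpp expands one instance list twice: once to declare a named model per type pair, and once to build the `if` chain that picks the matching runtime entry. The following is a minimal standalone sketch of that pattern, not the actual flang code. Every `DEMO_`/`Demo` name, the shortened pair list, and the `DEMO_HOST_HAS_REAL16` switch are assumptions made for illustration; the real list lives in flang/Runtime/matmul-instances.inc and the real entry names come from `RTNAME(...)`.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for Fortran::common::TypeCategory.
enum class TypeCategory { Integer, Real, Complex, Logical };

// Assumed toggles mirroring the role of MATMUL_FORCE_ALL_TYPES: with 1, every
// pair is listed even if the host compiler could not build it.
#define DEMO_FORCE_ALL_TYPES 1
#define DEMO_HOST_HAS_REAL16 0

#if DEMO_FORCE_ALL_TYPES || DEMO_HOST_HAS_REAL16
#define DEMO_FOREACH_REAL16_PAIR(macro) macro(Real, 16, Real, 16)
#else
#define DEMO_FOREACH_REAL16_PAIR(macro)
#endif

// Shortened, illustrative instance list.
#define DEMO_FOREACH_PAIR(macro)                                               \
  macro(Integer, 4, Integer, 4)                                                \
  macro(Integer, 4, Real, 8)                                                   \
  macro(Real, 8, Real, 8)                                                      \
  DEMO_FOREACH_REAL16_PAIR(macro)

// First expansion: one struct per pair, carrying the entry-point name.
#define DEMO_DEFINE_MODEL(ACAT, AKIND, BCAT, BKIND)                            \
  struct DemoMatmul##ACAT##AKIND##BCAT##BKIND {                                \
    static constexpr const char *name =                                        \
        "DemoMatmul" #ACAT #AKIND #BCAT #BKIND;                                \
  };
DEMO_FOREACH_PAIR(DEMO_DEFINE_MODEL)
#undef DEMO_DEFINE_MODEL

// Second expansion: a chain of ifs that selects the entry for a given pair.
// Returns an empty string when no entry exists for the pair.
std::string selectEntry(TypeCategory aCat, int aKind, TypeCategory bCat,
                        int bKind) {
  assert(aKind > 0 && bKind > 0 && "kinds are positive");
  std::string entry;
#define DEMO_SELECT(ACAT, AKIND, BCAT, BKIND)                                  \
  if (entry.empty() && aCat == TypeCategory::ACAT && aKind == AKIND &&         \
      bCat == TypeCategory::BCAT && bKind == BKIND)                            \
    entry = DemoMatmul##ACAT##AKIND##BCAT##BKIND::name;
  DEMO_FOREACH_PAIR(DEMO_SELECT)
#undef DEMO_SELECT
  return entry;
}
```

In this sketch, `selectEntry(TypeCategory::Integer, 4, TypeCategory::Real, 8)` yields `"DemoMatmulInteger4Real8"`, and a pair missing from the list yields an empty string, which corresponds to the `intrinsicTypeTODO2` fallback in the patch. Forcing all types matters on the compiler side because the compiler may target a platform whose runtime supports kinds that the build host does not.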

@kiranchandramohan
Contributor

@vzakhari Did you see any benefits for 178.galgel/spec2000 ?

@vzakhari
Contributor Author

vzakhari commented Jul 3, 2024

@vzakhari Did you see any benefits for 178.galgel/spec2000 ?

Kiran, I did not measure it, because I do not expect much performance change. I will keep an eye on our performance tracking, and will let you know if there is any noticeable change.

Contributor

@jeanPerier jeanPerier left a comment

Thanks Slava!

Comment on lines 184 to 203
  else if (type == fir::ComplexType::get(loc.getContext(), 2))
    return {Fortran::common::TypeCategory::Complex, 2};
  else if (type == fir::ComplexType::get(loc.getContext(), 3))
    return {Fortran::common::TypeCategory::Complex, 3};
  else if (type == fir::ComplexType::get(loc.getContext(), 4))
    return {Fortran::common::TypeCategory::Complex, 4};
  else if (type == fir::ComplexType::get(loc.getContext(), 8))
    return {Fortran::common::TypeCategory::Complex, 8};
  else if (type == fir::ComplexType::get(loc.getContext(), 10))
    return {Fortran::common::TypeCategory::Complex, 10};
  else if (type == fir::ComplexType::get(loc.getContext(), 16))
    return {Fortran::common::TypeCategory::Complex, 16};
  else if (type == fir::LogicalType::get(loc.getContext(), 1))
    return {Fortran::common::TypeCategory::Logical, 1};
  else if (type == fir::LogicalType::get(loc.getContext(), 2))
    return {Fortran::common::TypeCategory::Logical, 2};
  else if (type == fir::LogicalType::get(loc.getContext(), 4))
    return {Fortran::common::TypeCategory::Logical, 4};
  else if (type == fir::LogicalType::get(loc.getContext(), 8))
    return {Fortran::common::TypeCategory::Logical, 8};
Contributor

nit:

if (auto complexType = mlir::dyn_cast<fir::ComplexType>(type))
  return {Fortran::common::TypeCategory::Complex, complexType.getFKind()};
else if (auto logicalType = mlir::dyn_cast<fir::LogicalType>(type))
  return {Fortran::common::TypeCategory::Logical, logicalType.getFKind()};
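
The shape of this suggestion can be tried outside of MLIR. The sketch below is only an illustration under stated assumptions: `DemoComplexType`, `DemoLogicalType`, `DemoOtherType`, and `categoryAndKind` are hypothetical stand-ins, and `std::get_if` plays the role that `mlir::dyn_cast` plays in the comment above. It shows how one type test per category can replace one equality test per kind.

```cpp
#include <cassert>
#include <optional>
#include <utility>
#include <variant>

// Hypothetical stand-in for Fortran::common::TypeCategory.
enum class TypeCategory { Integer, Real, Complex, Logical };

// Hypothetical stand-ins for the FIR types; each one carries its Fortran kind.
struct DemoComplexType {
  int fkind;
  int getFKind() const { return fkind; }
};
struct DemoLogicalType {
  int fkind;
  int getFKind() const { return fkind; }
};
struct DemoOtherType {};

using DemoType = std::variant<DemoComplexType, DemoLogicalType, DemoOtherType>;

// One test per category: the kind is read from the type instead of being
// compared against every supported value in turn.
std::optional<std::pair<TypeCategory, int>>
categoryAndKind(const DemoType &type) {
  if (const auto *complexType = std::get_if<DemoComplexType>(&type)) {
    assert(complexType->getFKind() > 0 && "kinds are positive");
    return std::make_pair(TypeCategory::Complex, complexType->getFKind());
  }
  if (const auto *logicalType = std::get_if<DemoLogicalType>(&type)) {
    assert(logicalType->getFKind() > 0 && "kinds are positive");
    return std::make_pair(TypeCategory::Logical, logicalType->getFKind());
  }
  return std::nullopt; // not a complex or logical type
}
```

One difference worth keeping in mind: the equality chain accepts only the kinds it lists, while this form passes through whatever kind the type carries, so any kind validation would have to happen elsewhere.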

@vzakhari vzakhari merged commit 8ce1aed into llvm:main Jul 4, 2024
7 checks passed
kbluck pushed a commit to kbluck/llvm-project that referenced this pull request Jul 6, 2024
Lower MATMUL to the new runtime entries added in llvm#97406.