-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang] Lower MATMUL to type specific runtime calls. #97547
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Lower MATMUL to the new runtime entries added in llvm#97406.
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-runtime Author: Slava Zakharin (vzakhari) Changes: Lower MATMUL to the new runtime entries added in #97406. Patch is 53.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/97547.diff 15 Files Affected:
diff --git a/flang/include/flang/Optimizer/Support/Utils.h b/flang/include/flang/Optimizer/Support/Utils.h
index ae95a26be1d86..2ffb48335686c 100644
--- a/flang/include/flang/Optimizer/Support/Utils.h
+++ b/flang/include/flang/Optimizer/Support/Utils.h
@@ -84,9 +84,10 @@ inline std::string mlirTypeToString(mlir::Type type) {
return result;
}
-inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
- mlir::Type type, mlir::Location loc,
- const llvm::Twine &name) {
+inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
+ mlir::Type type,
+ mlir::Location loc,
+ const llvm::Twine &name) {
if (type.isF16())
return "REAL(KIND=2)";
else if (type.isBF16())
@@ -123,6 +124,14 @@ inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
return "COMPLEX(KIND=10)";
else if (type == fir::ComplexType::get(builder.getContext(), 16))
return "COMPLEX(KIND=16)";
+ else if (type == fir::LogicalType::get(builder.getContext(), 1))
+ return "LOGICAL(KIND=1)";
+ else if (type == fir::LogicalType::get(builder.getContext(), 2))
+ return "LOGICAL(KIND=2)";
+ else if (type == fir::LogicalType::get(builder.getContext(), 4))
+ return "LOGICAL(KIND=4)";
+ else if (type == fir::LogicalType::get(builder.getContext(), 8))
+ return "LOGICAL(KIND=8)";
else
fir::emitFatalError(loc, "unsupported type in " + name + ": " +
fir::mlirTypeToString(type));
@@ -133,10 +142,70 @@ inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
const llvm::Twine &intrinsicName) {
TODO(loc,
"intrinsic: " +
- fir::numericMlirTypeToFortran(builder, type, loc, intrinsicName) +
+ fir::mlirTypeToIntrinsicFortran(builder, type, loc, intrinsicName) +
" in " + intrinsicName);
}
+/// Two-operand counterpart of intrinsicTypeTODO: raise a TODO diagnostic at
+/// \p loc for the intrinsic \p intrinsicName applied to the operand type pair
+/// {\p type1, \p type2}. Each type is rendered with its Fortran spelling via
+/// mlirTypeToIntrinsicFortran, which itself emits a fatal error for a type it
+/// cannot spell.
+inline void intrinsicTypeTODO2(fir::FirOpBuilder &builder, mlir::Type type1,
+                               mlir::Type type2, mlir::Location loc,
+                               const llvm::Twine &intrinsicName) {
+  // The first element of the pair must be type1; printing type2 twice would
+  // misreport which operand combination is unsupported.
+  TODO(loc,
+       "intrinsic: {" +
+           fir::mlirTypeToIntrinsicFortran(builder, type1, loc, intrinsicName) +
+           ", " +
+           fir::mlirTypeToIntrinsicFortran(builder, type2, loc, intrinsicName) +
+           "} in " + intrinsicName);
+}
+
+/// Translate the MLIR type \p type into the matching Fortran
+/// (type category, kind) pair.
+///
+/// Floating-point and integer types are recognized by their bit width; for
+/// fir::ComplexType and fir::LogicalType the kind is read directly from the
+/// type. Any other type is reported at \p loc through fir::emitFatalError.
+inline std::pair<Fortran::common::TypeCategory, KindMapping::KindTy>
+mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
+  if (type.isF16())
+    return {Fortran::common::TypeCategory::Real, 2};
+  else if (type.isBF16())
+    return {Fortran::common::TypeCategory::Real, 3};
+  else if (type.isF32())
+    return {Fortran::common::TypeCategory::Real, 4};
+  else if (type.isF64())
+    return {Fortran::common::TypeCategory::Real, 8};
+  else if (type.isF80())
+    return {Fortran::common::TypeCategory::Real, 10};
+  else if (type.isF128())
+    return {Fortran::common::TypeCategory::Real, 16};
+  else if (type.isInteger(8))
+    return {Fortran::common::TypeCategory::Integer, 1};
+  else if (type.isInteger(16))
+    return {Fortran::common::TypeCategory::Integer, 2};
+  else if (type.isInteger(32))
+    return {Fortran::common::TypeCategory::Integer, 4};
+  else if (type.isInteger(64))
+    return {Fortran::common::TypeCategory::Integer, 8};
+  else if (type.isInteger(128))
+    return {Fortran::common::TypeCategory::Integer, 16};
+  // COMPLEX and LOGICAL types carry their kind, so query it instead of
+  // constructing one candidate type per kind just to compare against it.
+  else if (auto complexType = mlir::dyn_cast<fir::ComplexType>(type))
+    return {Fortran::common::TypeCategory::Complex, complexType.getFKind()};
+  else if (auto logicalType = mlir::dyn_cast<fir::LogicalType>(type))
+    return {Fortran::common::TypeCategory::Logical, logicalType.getFKind()};
+  else
+    fir::emitFatalError(loc,
+                        "unsupported type: " + fir::mlirTypeToString(type));
+}
+
/// Find the fir.type_info that was created for this \p recordType in \p module,
/// if any. \p symbolTable can be provided to speed-up the lookup. This tool
/// will match record type even if they have been "altered" in type conversion
diff --git a/flang/include/flang/Runtime/matmul-instances.inc b/flang/include/flang/Runtime/matmul-instances.inc
index 970b03339cd5e..32c6ab06d2521 100644
--- a/flang/include/flang/Runtime/matmul-instances.inc
+++ b/flang/include/flang/Runtime/matmul-instances.inc
@@ -17,6 +17,10 @@
#error "Define MATMUL_DIRECT_INSTANCE before including this file"
#endif
+#ifndef MATMUL_FORCE_ALL_TYPES
+#error "Define MATMUL_FORCE_ALL_TYPES to 0 or 1 before including this file"
+#endif
+
// clang-format off
#define FOREACH_MATMUL_TYPE_PAIR(macro) \
@@ -88,7 +92,7 @@
FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
-#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+#if MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
#define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro) \
macro(Integer, 16, Integer, 1) \
macro(Integer, 16, Integer, 2) \
@@ -107,7 +111,7 @@ FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
-#if LDBL_MANT_DIG == 64
+#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
MATMUL_INSTANCE(Integer, 16, Real, 10)
MATMUL_INSTANCE(Integer, 16, Complex, 10)
MATMUL_INSTANCE(Real, 10, Integer, 16)
@@ -117,7 +121,7 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
#endif
-#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
MATMUL_INSTANCE(Integer, 16, Real, 16)
MATMUL_INSTANCE(Integer, 16, Complex, 16)
MATMUL_INSTANCE(Real, 16, Integer, 16)
@@ -127,9 +131,9 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
#endif
-#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
+#endif // MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
-#if LDBL_MANT_DIG == 64
+#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro) \
macro(Integer, 1, Real, 10) \
macro(Integer, 1, Complex, 10) \
@@ -171,7 +175,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)
-#if HAS_FLOAT128
+#if MATMUL_FORCE_ALL_TYPES || HAS_FLOAT128
MATMUL_INSTANCE(Real, 10, Real, 16)
MATMUL_INSTANCE(Real, 10, Complex, 16)
MATMUL_INSTANCE(Real, 16, Real, 10)
@@ -189,9 +193,9 @@ MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
#endif
-#endif // LDBL_MANT_DIG == 64
+#endif // MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
-#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro) \
macro(Integer, 1, Real, 16) \
macro(Integer, 1, Complex, 16) \
@@ -232,7 +236,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
-#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
+#endif // MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
#define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
macro(Logical, 1, Logical, 1) \
@@ -257,5 +261,6 @@ FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
#undef MATMUL_INSTANCE
#undef MATMUL_DIRECT_INSTANCE
+#undef MATMUL_FORCE_ALL_TYPES
// clang-format on
diff --git a/flang/include/flang/Runtime/matmul-transpose.h b/flang/include/flang/Runtime/matmul-transpose.h
index d0a5005a1a8bd..2d79ca10e0895 100644
--- a/flang/include/flang/Runtime/matmul-transpose.h
+++ b/flang/include/flang/Runtime/matmul-transpose.h
@@ -40,6 +40,8 @@ void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
Descriptor & result, const Descriptor &x, const Descriptor &y, \
const char *sourceFile, int line);
+#define MATMUL_FORCE_ALL_TYPES 0
+
#include "matmul-instances.inc"
} // extern "C"
diff --git a/flang/include/flang/Runtime/matmul.h b/flang/include/flang/Runtime/matmul.h
index 1a5e39eb8813f..a72d4a06ee459 100644
--- a/flang/include/flang/Runtime/matmul.h
+++ b/flang/include/flang/Runtime/matmul.h
@@ -39,6 +39,8 @@ void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
const Descriptor &x, const Descriptor &y, const char *sourceFile, \
int line);
+#define MATMUL_FORCE_ALL_TYPES 0
+
#include "matmul-instances.inc"
} // extern "C"
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 8dd1904939f3e..a1cef7437fa2d 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -701,18 +701,19 @@ prettyPrintIntrinsicName(fir::FirOpBuilder &builder, mlir::Location loc,
if (name == "pow") {
assert(funcType.getNumInputs() == 2 && "power operator has two arguments");
std::string displayName{" ** "};
- sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
- displayName)
+ sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
+ displayName)
<< displayName
- << numericMlirTypeToFortran(builder, funcType.getInput(1), loc,
- displayName);
+ << mlirTypeToIntrinsicFortran(builder, funcType.getInput(1), loc,
+ displayName);
} else {
sstream << name.upper() << "(";
if (funcType.getNumInputs() > 0)
- sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
- name);
+ sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
+ name);
for (mlir::Type argType : funcType.getInputs().drop_front()) {
- sstream << ", " << numericMlirTypeToFortran(builder, argType, loc, name);
+ sstream << ", "
+ << mlirTypeToIntrinsicFortran(builder, argType, loc, name);
}
sstream << ")";
}
diff --git a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
index 6d3d85e8df69f..8f08b01fe0097 100644
--- a/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
+++ b/flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
@@ -329,11 +329,64 @@ void fir::runtime::genEoshiftVector(fir::FirOpBuilder &builder,
builder.create<fir::CallOp>(loc, eoshiftFunc, args);
}
+/// Define ForcedMatmul<ACAT><AKIND><BCAT><BKIND> models.
+/// This base supplies the function type builder that models deriving from it
+/// inherit: the function type is assembled from the models of
+/// (Descriptor &, const Descriptor &, const Descriptor &, const char *, int)
+/// as inputs and the model of void as the result.
+struct ForcedMatmulTypeModel {
+  static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
+    return [](mlir::MLIRContext *context) {
+      // Mutable descriptor reference (first argument).
+      auto resultDescTy =
+          fir::runtime::getModel<Fortran::runtime::Descriptor &>()(context);
+      // Read-only descriptor reference, used for both following arguments.
+      auto operandDescTy =
+          fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(
+              context);
+      auto fileNameTy = fir::runtime::getModel<const char *>()(context);
+      auto lineNoTy = fir::runtime::getModel<int>()(context);
+      auto noValueTy = fir::runtime::getModel<void>()(context);
+      return mlir::FunctionType::get(
+          context,
+          {resultDescTy, operandDescTy, operandDescTy, fileNameTy, lineNoTy},
+          {noValueTy});
+    };
+  }
+};
+
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
+ struct ForcedMatmul##ACAT##AKIND##BCAT##BKIND \
+ : public ForcedMatmulTypeModel { \
+ static constexpr const char *name = \
+ ExpandAndQuoteKey(RTNAME(Matmul##ACAT##AKIND##BCAT##BKIND)); \
+ };
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+
+#include "flang/Runtime/matmul-instances.inc"
+
/// Generate call to Matmul intrinsic runtime routine.
void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultBox, mlir::Value matrixABox,
mlir::Value matrixBBox) {
- auto func = fir::runtime::getRuntimeFunc<mkRTKey(Matmul)>(loc, builder);
+ mlir::func::FuncOp func;
+ auto boxATy = matrixABox.getType();
+ auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
+ auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
+ auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
+ auto boxBTy = matrixBBox.getType();
+ auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
+ auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
+ auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);
+
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
+ if (!func && aCat == TypeCategory::ACAT && aKind == AKIND && \
+ bCat == TypeCategory::BCAT && bKind == BKIND) { \
+ func = \
+ fir::runtime::getRuntimeFunc<ForcedMatmul##ACAT##AKIND##BCAT##BKIND>( \
+ loc, builder); \
+ }
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+#include "flang/Runtime/matmul-instances.inc"
+
+ if (!func) {
+ fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc, "MATMUL");
+ }
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
auto sourceLine =
@@ -344,13 +397,48 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
builder.create<fir::CallOp>(loc, func, args);
}
-/// Generate call to MatmulTranspose intrinsic runtime routine.
+/// Define ForcedMatmulTranspose<ACAT><AKIND><BCAT><BKIND> models.
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
+ struct ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND \
+ : public ForcedMatmulTypeModel { \
+ static constexpr const char *name = \
+ ExpandAndQuoteKey(RTNAME(MatmulTranspose##ACAT##AKIND##BCAT##BKIND)); \
+ };
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+
+#include "flang/Runtime/matmul-instances.inc"
+
void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Value resultBox,
mlir::Value matrixABox,
mlir::Value matrixBBox) {
- auto func =
- fir::runtime::getRuntimeFunc<mkRTKey(MatmulTranspose)>(loc, builder);
+ mlir::func::FuncOp func;
+ auto boxATy = matrixABox.getType();
+ auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
+ auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
+ auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
+ auto boxBTy = matrixBBox.getType();
+ auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
+ auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
+ auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);
+
+#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
+ if (!func && aCat == TypeCategory::ACAT && aKind == AKIND && \
+ bCat == TypeCategory::BCAT && bKind == BKIND) { \
+ func = fir::runtime::getRuntimeFunc< \
+ ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND>(loc, builder); \
+ }
+
+#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
+#define MATMUL_FORCE_ALL_TYPES 1
+#include "flang/Runtime/matmul-instances.inc"
+
+ if (!func) {
+ fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc,
+ "MATMUL-TRANSPOSE");
+ }
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
auto sourceLine =
diff --git a/flang/runtime/matmul-transpose.cpp b/flang/runtime/matmul-transpose.cpp
index 1c998fa8cf6c1..283472650a1c6 100644
--- a/flang/runtime/matmul-transpose.cpp
+++ b/flang/runtime/matmul-transpose.cpp
@@ -343,48 +343,6 @@ inline static RT_API_ATTRS void DoMatmulTranspose(
RT_DIAG_POP
-// Maps the dynamic type information from the arguments' descriptors
-// to the right instantiation of DoMatmul() for valid combinations of
-// types.
-template <bool IS_ALLOCATING> struct MatmulTranspose {
- using ResultDescriptor =
- std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
- template <TypeCategory XCAT, int XKIND> struct MM1 {
- template <TypeCategory YCAT, int YKIND> struct MM2 {
- RT_API_ATTRS void operator()(ResultDescriptor &result,
- const Descriptor &x, const Descriptor &y,
- Terminator &terminator) const {
- if constexpr (constexpr auto resultType{
- GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
- if constexpr (Fortran::common::IsNumericTypeCategory(
- resultType->first) ||
- resultType->first == TypeCategory::Logical) {
- return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
- resultType->second, CppTypeFor<XCAT, XKIND>,
- CppTypeFor<YCAT, YKIND>>(result, x, y, terminator);
- }
- }
- terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
- static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
- }
- };
- RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
- const Descriptor &y, Terminator &terminator, TypeCategory yCat,
- int yKind) const {
- ApplyType<MM2, void>(yCat, yKind, terminator, result, x, y, terminator);
- }
- };
- RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
- const Descriptor &y, const char *sourceFile, int line) const {
- Terminator terminator{sourceFile, line};
- auto xCatKind{x.type().GetCategoryAndKind()};
- auto yCatKind{y.type().GetCategoryAndKind()};
- RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
- ApplyType<MM1, void>(xCatKind->first, xCatKind->second, terminator, result,
- x, y, terminator, yCatKind->first, yCatKind->second);
- }
-};
-
template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
int YKIND>
struct MatmulTransposeHelper {
@@ -414,15 +372,6 @@ namespace Fortran::runtime {
extern "C" {
RT_EXT_API_GROUP_BEGIN
-void RTDEF(MatmulTranspose)(Descriptor &result, const Descriptor &x,
- const Descriptor ...
[truncated]
|
@vzakhari Did you see any benefits for 178.galgel/spec2000 ? |
Kiran, I did not measure it, because I do not expect much performance change. I will keep an eye on our performance tracking, and will let you know if there is any noticeable change. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Slava!
else if (type == fir::ComplexType::get(loc.getContext(), 2)) | ||
return {Fortran::common::TypeCategory::Complex, 2}; | ||
else if (type == fir::ComplexType::get(loc.getContext(), 3)) | ||
return {Fortran::common::TypeCategory::Complex, 3}; | ||
else if (type == fir::ComplexType::get(loc.getContext(), 4)) | ||
return {Fortran::common::TypeCategory::Complex, 4}; | ||
else if (type == fir::ComplexType::get(loc.getContext(), 8)) | ||
return {Fortran::common::TypeCategory::Complex, 8}; | ||
else if (type == fir::ComplexType::get(loc.getContext(), 10)) | ||
return {Fortran::common::TypeCategory::Complex, 10}; | ||
else if (type == fir::ComplexType::get(loc.getContext(), 16)) | ||
return {Fortran::common::TypeCategory::Complex, 16}; | ||
else if (type == fir::LogicalType::get(loc.getContext(), 1)) | ||
return {Fortran::common::TypeCategory::Logical, 1}; | ||
else if (type == fir::LogicalType::get(loc.getContext(), 2)) | ||
return {Fortran::common::TypeCategory::Logical, 2}; | ||
else if (type == fir::LogicalType::get(loc.getContext(), 4)) | ||
return {Fortran::common::TypeCategory::Logical, 4}; | ||
else if (type == fir::LogicalType::get(loc.getContext(), 8)) | ||
return {Fortran::common::TypeCategory::Logical, 8}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
if(complexType = mlir::dyn_cast<fir::ComplexType>(type))
return {Fortran::common::TypeCategory::Complex, complexType.getFKind()};
else if (logicalType = mlir::dyn_cast<fir::LogicalType>(type))
return {Fortran::common::TypeCategory::Logical, logicalType.getFKind()};
Lower MATMUL to the new runtime entries added in llvm#97406.
Lower MATMUL to the new runtime entries added in #97406.