Skip to content

[flang] Lower MATMUL to type specific runtime calls. #97547

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions flang/include/flang/Optimizer/Support/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ inline std::string mlirTypeToString(mlir::Type type) {
return result;
}

inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
mlir::Type type, mlir::Location loc,
const llvm::Twine &name) {
inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
mlir::Type type,
mlir::Location loc,
const llvm::Twine &name) {
if (type.isF16())
return "REAL(KIND=2)";
else if (type.isBF16())
Expand Down Expand Up @@ -123,6 +124,14 @@ inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
return "COMPLEX(KIND=10)";
else if (type == fir::ComplexType::get(builder.getContext(), 16))
return "COMPLEX(KIND=16)";
else if (type == fir::LogicalType::get(builder.getContext(), 1))
return "LOGICAL(KIND=1)";
else if (type == fir::LogicalType::get(builder.getContext(), 2))
return "LOGICAL(KIND=2)";
else if (type == fir::LogicalType::get(builder.getContext(), 4))
return "LOGICAL(KIND=4)";
else if (type == fir::LogicalType::get(builder.getContext(), 8))
return "LOGICAL(KIND=8)";
else
fir::emitFatalError(loc, "unsupported type in " + name + ": " +
fir::mlirTypeToString(type));
Expand All @@ -133,10 +142,54 @@ inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
const llvm::Twine &intrinsicName) {
TODO(loc,
"intrinsic: " +
fir::numericMlirTypeToFortran(builder, type, loc, intrinsicName) +
fir::mlirTypeToIntrinsicFortran(builder, type, loc, intrinsicName) +
" in " + intrinsicName);
}

inline void intrinsicTypeTODO2(fir::FirOpBuilder &builder, mlir::Type type1,
mlir::Type type2, mlir::Location loc,
const llvm::Twine &intrinsicName) {
TODO(loc,
"intrinsic: {" +
fir::mlirTypeToIntrinsicFortran(builder, type2, loc, intrinsicName) +
", " +
fir::mlirTypeToIntrinsicFortran(builder, type2, loc, intrinsicName) +
"} in " + intrinsicName);
}

inline std::pair<Fortran::common::TypeCategory, KindMapping::KindTy>
mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
if (type.isF16())
return {Fortran::common::TypeCategory::Real, 2};
else if (type.isBF16())
return {Fortran::common::TypeCategory::Real, 3};
else if (type.isF32())
return {Fortran::common::TypeCategory::Real, 4};
else if (type.isF64())
return {Fortran::common::TypeCategory::Real, 8};
else if (type.isF80())
return {Fortran::common::TypeCategory::Real, 10};
else if (type.isF128())
return {Fortran::common::TypeCategory::Real, 16};
else if (type.isInteger(8))
return {Fortran::common::TypeCategory::Integer, 1};
else if (type.isInteger(16))
return {Fortran::common::TypeCategory::Integer, 2};
else if (type.isInteger(32))
return {Fortran::common::TypeCategory::Integer, 4};
else if (type.isInteger(64))
return {Fortran::common::TypeCategory::Integer, 8};
else if (type.isInteger(128))
return {Fortran::common::TypeCategory::Integer, 16};
else if (auto complexType = mlir::dyn_cast<fir::ComplexType>(type))
return {Fortran::common::TypeCategory::Complex, complexType.getFKind()};
else if (auto logicalType = mlir::dyn_cast<fir::LogicalType>(type))
return {Fortran::common::TypeCategory::Logical, logicalType.getFKind()};
else
fir::emitFatalError(loc,
"unsupported type: " + fir::mlirTypeToString(type));
}

/// Find the fir.type_info that was created for this \p recordType in \p module,
/// if any. \p symbolTable can be provided to speed-up the lookup. This tool
/// will match record type even if they have been "altered" in type conversion
Expand Down
23 changes: 14 additions & 9 deletions flang/include/flang/Runtime/matmul-instances.inc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#error "Define MATMUL_DIRECT_INSTANCE before including this file"
#endif

#ifndef MATMUL_FORCE_ALL_TYPES
#error "Define MATMUL_FORCE_ALL_TYPES to 0 or 1 before including this file"
#endif

// clang-format off

#define FOREACH_MATMUL_TYPE_PAIR(macro) \
Expand Down Expand Up @@ -88,7 +92,7 @@
FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)

#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
#if MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
#define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro) \
macro(Integer, 16, Integer, 1) \
macro(Integer, 16, Integer, 2) \
Expand All @@ -107,7 +111,7 @@ FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)

#if LDBL_MANT_DIG == 64
#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
MATMUL_INSTANCE(Integer, 16, Real, 10)
MATMUL_INSTANCE(Integer, 16, Complex, 10)
MATMUL_INSTANCE(Real, 10, Integer, 16)
Expand All @@ -117,7 +121,7 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
#endif
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
MATMUL_INSTANCE(Integer, 16, Real, 16)
MATMUL_INSTANCE(Integer, 16, Complex, 16)
MATMUL_INSTANCE(Real, 16, Integer, 16)
Expand All @@ -127,9 +131,9 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
#endif
#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
#endif // MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)

#if LDBL_MANT_DIG == 64
#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro) \
macro(Integer, 1, Real, 10) \
macro(Integer, 1, Complex, 10) \
Expand Down Expand Up @@ -171,7 +175,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)

#if HAS_FLOAT128
#if MATMUL_FORCE_ALL_TYPES || HAS_FLOAT128
MATMUL_INSTANCE(Real, 10, Real, 16)
MATMUL_INSTANCE(Real, 10, Complex, 16)
MATMUL_INSTANCE(Real, 16, Real, 10)
Expand All @@ -189,9 +193,9 @@ MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
#endif
#endif // LDBL_MANT_DIG == 64
#endif // MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64

#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro) \
macro(Integer, 1, Real, 16) \
macro(Integer, 1, Complex, 16) \
Expand Down Expand Up @@ -232,7 +236,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)

FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
#endif // MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)

#define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
macro(Logical, 1, Logical, 1) \
Expand All @@ -257,5 +261,6 @@ FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)

#undef MATMUL_INSTANCE
#undef MATMUL_DIRECT_INSTANCE
#undef MATMUL_FORCE_ALL_TYPES

// clang-format on
2 changes: 2 additions & 0 deletions flang/include/flang/Runtime/matmul-transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
Descriptor & result, const Descriptor &x, const Descriptor &y, \
const char *sourceFile, int line);

#define MATMUL_FORCE_ALL_TYPES 0

#include "matmul-instances.inc"

} // extern "C"
Expand Down
2 changes: 2 additions & 0 deletions flang/include/flang/Runtime/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
const Descriptor &x, const Descriptor &y, const char *sourceFile, \
int line);

#define MATMUL_FORCE_ALL_TYPES 0

#include "matmul-instances.inc"

} // extern "C"
Expand Down
15 changes: 8 additions & 7 deletions flang/lib/Optimizer/Builder/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -701,18 +701,19 @@ prettyPrintIntrinsicName(fir::FirOpBuilder &builder, mlir::Location loc,
if (name == "pow") {
assert(funcType.getNumInputs() == 2 && "power operator has two arguments");
std::string displayName{" ** "};
sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
displayName)
sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
displayName)
<< displayName
<< numericMlirTypeToFortran(builder, funcType.getInput(1), loc,
displayName);
<< mlirTypeToIntrinsicFortran(builder, funcType.getInput(1), loc,
displayName);
} else {
sstream << name.upper() << "(";
if (funcType.getNumInputs() > 0)
sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
name);
sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
name);
for (mlir::Type argType : funcType.getInputs().drop_front()) {
sstream << ", " << numericMlirTypeToFortran(builder, argType, loc, name);
sstream << ", "
<< mlirTypeToIntrinsicFortran(builder, argType, loc, name);
}
sstream << ")";
}
Expand Down
96 changes: 92 additions & 4 deletions flang/lib/Optimizer/Builder/Runtime/Transformational.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,11 +329,64 @@ void fir::runtime::genEoshiftVector(fir::FirOpBuilder &builder,
builder.create<fir::CallOp>(loc, eoshiftFunc, args);
}

/// Define ForcedMatmul<ACAT><AKIND><BCAT><BKIND> models.
struct ForcedMatmulTypeModel {
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
return [](mlir::MLIRContext *ctx) {
auto boxRefTy =
fir::runtime::getModel<Fortran::runtime::Descriptor &>()(ctx);
auto boxTy =
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
auto strTy = fir::runtime::getModel<const char *>()(ctx);
auto intTy = fir::runtime::getModel<int>()(ctx);
auto voidTy = fir::runtime::getModel<void>()(ctx);
return mlir::FunctionType::get(
ctx, {boxRefTy, boxTy, boxTy, strTy, intTy}, {voidTy});
};
}
};

#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
struct ForcedMatmul##ACAT##AKIND##BCAT##BKIND \
: public ForcedMatmulTypeModel { \
static constexpr const char *name = \
ExpandAndQuoteKey(RTNAME(Matmul##ACAT##AKIND##BCAT##BKIND)); \
};

#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
#define MATMUL_FORCE_ALL_TYPES 1

#include "flang/Runtime/matmul-instances.inc"

/// Generate call to Matmul intrinsic runtime routine.
void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultBox, mlir::Value matrixABox,
mlir::Value matrixBBox) {
auto func = fir::runtime::getRuntimeFunc<mkRTKey(Matmul)>(loc, builder);
mlir::func::FuncOp func;
auto boxATy = matrixABox.getType();
auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
auto boxBTy = matrixBBox.getType();
auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);

#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
if (!func && aCat == TypeCategory::ACAT && aKind == AKIND && \
bCat == TypeCategory::BCAT && bKind == BKIND) { \
func = \
fir::runtime::getRuntimeFunc<ForcedMatmul##ACAT##AKIND##BCAT##BKIND>( \
loc, builder); \
}

#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
#define MATMUL_FORCE_ALL_TYPES 1
#include "flang/Runtime/matmul-instances.inc"

if (!func) {
fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc, "MATMUL");
}
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
auto sourceLine =
Expand All @@ -344,13 +397,48 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
builder.create<fir::CallOp>(loc, func, args);
}

/// Generate call to MatmulTranspose intrinsic runtime routine.
/// Define ForcedMatmulTranspose<ACAT><AKIND><BCAT><BKIND> models.
#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
struct ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND \
: public ForcedMatmulTypeModel { \
static constexpr const char *name = \
ExpandAndQuoteKey(RTNAME(MatmulTranspose##ACAT##AKIND##BCAT##BKIND)); \
};

#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
#define MATMUL_FORCE_ALL_TYPES 1

#include "flang/Runtime/matmul-instances.inc"

void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
mlir::Location loc, mlir::Value resultBox,
mlir::Value matrixABox,
mlir::Value matrixBBox) {
auto func =
fir::runtime::getRuntimeFunc<mkRTKey(MatmulTranspose)>(loc, builder);
mlir::func::FuncOp func;
auto boxATy = matrixABox.getType();
auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
auto boxBTy = matrixBBox.getType();
auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);

#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
if (!func && aCat == TypeCategory::ACAT && aKind == AKIND && \
bCat == TypeCategory::BCAT && bKind == BKIND) { \
func = fir::runtime::getRuntimeFunc< \
ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND>(loc, builder); \
}

#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
#define MATMUL_FORCE_ALL_TYPES 1
#include "flang/Runtime/matmul-instances.inc"

if (!func) {
fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc,
"MATMUL-TRANSPOSE");
}
auto fTy = func.getFunctionType();
auto sourceFile = fir::factory::locationToFilename(builder, loc);
auto sourceLine =
Expand Down
Loading
Loading