Skip to content

[flang][runtime] Split MATMUL[_TRANSPOSE] into separate entries. #97406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 261 additions & 0 deletions flang/include/flang/Runtime/matmul-instances.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
//===-- include/flang/Runtime/matmul-instances.inc --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// Helper macros to instantiate MATMUL/MATMUL_TRANSPOSE definitions
// for different data types of the input arguments.
//===----------------------------------------------------------------------===//

// This file is an X-macro list: the includer must supply both generator
// macros before including it. Fail early with a clear message otherwise.
#if !defined(MATMUL_INSTANCE)
#error "Define MATMUL_INSTANCE before including this file"
#endif

#if !defined(MATMUL_DIRECT_INSTANCE)
#error "Define MATMUL_DIRECT_INSTANCE before including this file"
#endif

// clang-format off

// Invokes `macro(XCAT, XKIND, YCAT, YKIND)` once for every ordered operand
// pair drawn from Integer kinds 1/2/4/8, Real kinds 4/8 and Complex kinds 4/8
// (8 x 8 = 64 pairs). The invocations are juxtaposed with no separator, so
// `macro` must expand to a self-terminated declaration or definition.
#define FOREACH_MATMUL_TYPE_PAIR(macro) \
macro(Integer, 1, Integer, 1) \
macro(Integer, 1, Integer, 2) \
macro(Integer, 1, Integer, 4) \
macro(Integer, 1, Integer, 8) \
macro(Integer, 2, Integer, 1) \
macro(Integer, 2, Integer, 2) \
macro(Integer, 2, Integer, 4) \
macro(Integer, 2, Integer, 8) \
macro(Integer, 4, Integer, 1) \
macro(Integer, 4, Integer, 2) \
macro(Integer, 4, Integer, 4) \
macro(Integer, 4, Integer, 8) \
macro(Integer, 8, Integer, 1) \
macro(Integer, 8, Integer, 2) \
macro(Integer, 8, Integer, 4) \
macro(Integer, 8, Integer, 8) \
macro(Integer, 1, Real, 4) \
macro(Integer, 1, Real, 8) \
macro(Integer, 2, Real, 4) \
macro(Integer, 2, Real, 8) \
macro(Integer, 4, Real, 4) \
macro(Integer, 4, Real, 8) \
macro(Integer, 8, Real, 4) \
macro(Integer, 8, Real, 8) \
macro(Integer, 1, Complex, 4) \
macro(Integer, 1, Complex, 8) \
macro(Integer, 2, Complex, 4) \
macro(Integer, 2, Complex, 8) \
macro(Integer, 4, Complex, 4) \
macro(Integer, 4, Complex, 8) \
macro(Integer, 8, Complex, 4) \
macro(Integer, 8, Complex, 8) \
macro(Real, 4, Integer, 1) \
macro(Real, 4, Integer, 2) \
macro(Real, 4, Integer, 4) \
macro(Real, 4, Integer, 8) \
macro(Real, 8, Integer, 1) \
macro(Real, 8, Integer, 2) \
macro(Real, 8, Integer, 4) \
macro(Real, 8, Integer, 8) \
macro(Real, 4, Real, 4) \
macro(Real, 4, Real, 8) \
macro(Real, 8, Real, 4) \
macro(Real, 8, Real, 8) \
macro(Real, 4, Complex, 4) \
macro(Real, 4, Complex, 8) \
macro(Real, 8, Complex, 4) \
macro(Real, 8, Complex, 8) \
macro(Complex, 4, Integer, 1) \
macro(Complex, 4, Integer, 2) \
macro(Complex, 4, Integer, 4) \
macro(Complex, 4, Integer, 8) \
macro(Complex, 8, Integer, 1) \
macro(Complex, 8, Integer, 2) \
macro(Complex, 8, Integer, 4) \
macro(Complex, 8, Integer, 8) \
macro(Complex, 4, Real, 4) \
macro(Complex, 4, Real, 8) \
macro(Complex, 8, Real, 4) \
macro(Complex, 8, Real, 8) \
macro(Complex, 4, Complex, 4) \
macro(Complex, 4, Complex, 8) \
macro(Complex, 8, Complex, 4) \
macro(Complex, 8, Complex, 8) \

// Expand the table once per includer-supplied generator macro.
FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)

// Operand pairs in which at least one operand is Integer(16). They are only
// emitted when the compiler predefines __SIZEOF_INT128__ and
// AVOID_NATIVE_UINT128_T is not set to a non-zero value.
//
// The table covers both operand orders for every partner type. The
// Integer(1/2/4/8) x Integer(16) rows mirror the Integer(16) x
// Integer(1/2/4/8) rows, so that a narrower left operand combined with an
// Integer(16) right operand has an entry too.
#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
#define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro) \
macro(Integer, 1, Integer, 16) \
macro(Integer, 2, Integer, 16) \
macro(Integer, 4, Integer, 16) \
macro(Integer, 8, Integer, 16) \
macro(Integer, 16, Integer, 1) \
macro(Integer, 16, Integer, 2) \
macro(Integer, 16, Integer, 4) \
macro(Integer, 16, Integer, 8) \
macro(Integer, 16, Integer, 16) \
macro(Integer, 16, Real, 4) \
macro(Integer, 16, Real, 8) \
macro(Integer, 16, Complex, 4) \
macro(Integer, 16, Complex, 8) \
macro(Real, 4, Integer, 16) \
macro(Real, 8, Integer, 16) \
macro(Complex, 4, Integer, 16) \
macro(Complex, 8, Integer, 16) \

FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)

// Integer(16) combined with kind-10 Real/Complex; gated on LDBL_MANT_DIG == 64.
#if LDBL_MANT_DIG == 64
MATMUL_INSTANCE(Integer, 16, Real, 10)
MATMUL_INSTANCE(Integer, 16, Complex, 10)
MATMUL_INSTANCE(Real, 10, Integer, 16)
MATMUL_INSTANCE(Complex, 10, Integer, 16)
MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 10)
MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
#endif
// Integer(16) combined with kind-16 Real/Complex; gated on a 113-bit
// long double mantissa or HAS_FLOAT128.
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
MATMUL_INSTANCE(Integer, 16, Real, 16)
MATMUL_INSTANCE(Integer, 16, Complex, 16)
MATMUL_INSTANCE(Real, 16, Integer, 16)
MATMUL_INSTANCE(Complex, 16, Integer, 16)
MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 16)
MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
#endif
#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T

// Operand pairs in which at least one operand is Real(10) or Complex(10),
// combined with each other and with Integer 1/2/4/8, Real 4/8 and
// Complex 4/8 (36 pairs). Emitted only when LDBL_MANT_DIG == 64.
// Integer(16) partners are handled separately in the 128-bit integer section.
#if LDBL_MANT_DIG == 64
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro) \
macro(Integer, 1, Real, 10) \
macro(Integer, 1, Complex, 10) \
macro(Integer, 2, Real, 10) \
macro(Integer, 2, Complex, 10) \
macro(Integer, 4, Real, 10) \
macro(Integer, 4, Complex, 10) \
macro(Integer, 8, Real, 10) \
macro(Integer, 8, Complex, 10) \
macro(Real, 4, Real, 10) \
macro(Real, 4, Complex, 10) \
macro(Real, 8, Real, 10) \
macro(Real, 8, Complex, 10) \
macro(Real, 10, Integer, 1) \
macro(Real, 10, Integer, 2) \
macro(Real, 10, Integer, 4) \
macro(Real, 10, Integer, 8) \
macro(Real, 10, Real, 4) \
macro(Real, 10, Real, 8) \
macro(Real, 10, Real, 10) \
macro(Real, 10, Complex, 4) \
macro(Real, 10, Complex, 8) \
macro(Real, 10, Complex, 10) \
macro(Complex, 4, Real, 10) \
macro(Complex, 4, Complex, 10) \
macro(Complex, 8, Real, 10) \
macro(Complex, 8, Complex, 10) \
macro(Complex, 10, Integer, 1) \
macro(Complex, 10, Integer, 2) \
macro(Complex, 10, Integer, 4) \
macro(Complex, 10, Integer, 8) \
macro(Complex, 10, Real, 4) \
macro(Complex, 10, Real, 8) \
macro(Complex, 10, Real, 10) \
macro(Complex, 10, Complex, 4) \
macro(Complex, 10, Complex, 8) \
macro(Complex, 10, Complex, 10) \

FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)

// Mixed kind-10 / kind-16 pairs. With LDBL_MANT_DIG == 64 already established
// by the enclosing #if, kind 16 can only be enabled through HAS_FLOAT128.
#if HAS_FLOAT128
MATMUL_INSTANCE(Real, 10, Real, 16)
MATMUL_INSTANCE(Real, 10, Complex, 16)
MATMUL_INSTANCE(Real, 16, Real, 10)
MATMUL_INSTANCE(Real, 16, Complex, 10)
MATMUL_INSTANCE(Complex, 10, Real, 16)
MATMUL_INSTANCE(Complex, 10, Complex, 16)
MATMUL_INSTANCE(Complex, 16, Real, 10)
MATMUL_INSTANCE(Complex, 16, Complex, 10)
MATMUL_DIRECT_INSTANCE(Real, 10, Real, 16)
MATMUL_DIRECT_INSTANCE(Real, 10, Complex, 16)
MATMUL_DIRECT_INSTANCE(Real, 16, Real, 10)
MATMUL_DIRECT_INSTANCE(Real, 16, Complex, 10)
MATMUL_DIRECT_INSTANCE(Complex, 10, Real, 16)
MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
#endif
#endif // LDBL_MANT_DIG == 64

// Operand pairs in which at least one operand is Real(16) or Complex(16),
// combined with each other and with Integer 1/2/4/8, Real 4/8 and
// Complex 4/8 (36 pairs). Emitted when LDBL_MANT_DIG == 113 or HAS_FLOAT128
// is non-zero. Pairs with kind 10 and with Integer(16) are emitted in their
// own conditional sections.
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro) \
macro(Integer, 1, Real, 16) \
macro(Integer, 1, Complex, 16) \
macro(Integer, 2, Real, 16) \
macro(Integer, 2, Complex, 16) \
macro(Integer, 4, Real, 16) \
macro(Integer, 4, Complex, 16) \
macro(Integer, 8, Real, 16) \
macro(Integer, 8, Complex, 16) \
macro(Real, 4, Real, 16) \
macro(Real, 4, Complex, 16) \
macro(Real, 8, Real, 16) \
macro(Real, 8, Complex, 16) \
macro(Real, 16, Integer, 1) \
macro(Real, 16, Integer, 2) \
macro(Real, 16, Integer, 4) \
macro(Real, 16, Integer, 8) \
macro(Real, 16, Real, 4) \
macro(Real, 16, Real, 8) \
macro(Real, 16, Real, 16) \
macro(Real, 16, Complex, 4) \
macro(Real, 16, Complex, 8) \
macro(Real, 16, Complex, 16) \
macro(Complex, 4, Real, 16) \
macro(Complex, 4, Complex, 16) \
macro(Complex, 8, Real, 16) \
macro(Complex, 8, Complex, 16) \
macro(Complex, 16, Integer, 1) \
macro(Complex, 16, Integer, 2) \
macro(Complex, 16, Integer, 4) \
macro(Complex, 16, Integer, 8) \
macro(Complex, 16, Real, 4) \
macro(Complex, 16, Real, 8) \
macro(Complex, 16, Real, 16) \
macro(Complex, 16, Complex, 4) \
macro(Complex, 16, Complex, 8) \
macro(Complex, 16, Complex, 16) \

FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128

// Logical operands are only paired with other Logical operands: every ordered
// combination of kinds 1/2/4/8 (16 pairs). Emitted unconditionally.
#define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
macro(Logical, 1, Logical, 1) \
macro(Logical, 1, Logical, 2) \
macro(Logical, 1, Logical, 4) \
macro(Logical, 1, Logical, 8) \
macro(Logical, 2, Logical, 1) \
macro(Logical, 2, Logical, 2) \
macro(Logical, 2, Logical, 4) \
macro(Logical, 2, Logical, 8) \
macro(Logical, 4, Logical, 1) \
macro(Logical, 4, Logical, 2) \
macro(Logical, 4, Logical, 4) \
macro(Logical, 4, Logical, 8) \
macro(Logical, 8, Logical, 1) \
macro(Logical, 8, Logical, 2) \
macro(Logical, 8, Logical, 4) \
macro(Logical, 8, Logical, 8) \

FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_INSTANCE)
FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)

#undef MATMUL_INSTANCE
#undef MATMUL_DIRECT_INSTANCE

// clang-format on
17 changes: 17 additions & 0 deletions flang/include/flang/Runtime/matmul-transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#ifndef FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
#define FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
#include "flang/Common/float128.h"
#include "flang/Common/uint128.h"
#include "flang/Runtime/entry-names.h"
namespace Fortran::runtime {
class Descriptor;
Expand All @@ -25,6 +27,21 @@ void RTDECL(MatmulTranspose)(Descriptor &, const Descriptor &,
// and have a valid base address.
void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
const Descriptor &, const char *sourceFile = nullptr, int line = 0);

// Type-specific MATMUL(TRANSPOSE()) entry points. Each entry is named after
// the category and KIND of both operands (X first, then Y); the operands and
// the result are passed as descriptors. Unlike the generic entries, sourceFile
// and line have no default arguments here.
#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDECL(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)( \
      Descriptor &result, const Descriptor &x, const Descriptor &y, \
      const char *sourceFile, int line);
// "Direct" flavor of the type-specific MATMUL(TRANSPOSE()) entries, one per
// operand (category, KIND) pair.
// NOTE(review): `result` is declared as a non-const Descriptor& here, while the
// generic MatmulTransposeDirect entry takes a const Descriptor& -- confirm
// whether the difference is intentional.
#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDECL(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
      Descriptor &result, const Descriptor &x, const Descriptor &y, \
      const char *sourceFile, int line);

#include "matmul-instances.inc"

} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
17 changes: 17 additions & 0 deletions flang/include/flang/Runtime/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#ifndef FORTRAN_RUNTIME_MATMUL_H_
#define FORTRAN_RUNTIME_MATMUL_H_
#include "flang/Common/float128.h"
#include "flang/Common/uint128.h"
#include "flang/Runtime/entry-names.h"
namespace Fortran::runtime {
class Descriptor;
Expand All @@ -24,6 +26,21 @@ void RTDECL(Matmul)(Descriptor &, const Descriptor &, const Descriptor &,
// and have a valid base address.
void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
const Descriptor &, const char *sourceFile = nullptr, int line = 0);

// Type-specific MATMUL entry points. Each entry is named after the category
// and KIND of both operands (X first, then Y); the operands and the result
// are passed as descriptors. Unlike the generic entries, sourceFile and line
// have no default arguments here.
#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDECL(Matmul##XCAT##XKIND##YCAT##YKIND)( \
      Descriptor &result, const Descriptor &x, const Descriptor &y, \
      const char *sourceFile, int line);
// "Direct" flavor of the type-specific MATMUL entries, one per operand
// (category, KIND) pair.
// NOTE(review): `result` is declared as a non-const Descriptor& here, while the
// generic MatmulDirect entry takes a const Descriptor& -- confirm whether the
// difference is intentional.
#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDECL(MatmulDirect##XCAT##XKIND##YCAT##YKIND)( \
      Descriptor &result, const Descriptor &x, const Descriptor &y, \
      const char *sourceFile, int line);

#include "matmul-instances.inc"

} // extern "C"
} // namespace Fortran::runtime
#endif // FORTRAN_RUNTIME_MATMUL_H_
42 changes: 42 additions & 0 deletions flang/runtime/matmul-transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,30 @@ template <bool IS_ALLOCATING> struct MatmulTranspose {
x, y, terminator, yCatKind->first, yCatKind->second);
}
};

// Functor behind the type-specific MATMUL(TRANSPOSE()) entry points.
// The operand categories and kinds are fixed at compile time by the template
// arguments; IS_ALLOCATING selects whether the result descriptor is taken as
// a mutable Descriptor (true) or as a const Descriptor (false).
template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
    int YKIND>
struct MatmulTransposeHelper {
  using ResultDescriptor =
      std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;

  // Checks that both operand descriptors carry a recognizable type whose
  // category matches XCAT/YCAT, then forwards to DoMatmulTranspose with the
  // statically computed result type. Crashes via the terminator when
  // GetResultType yields no result type for the operand pair.
  // NOTE(review): only the categories are compared against the descriptors;
  // the descriptor KINDs are not checked against XKIND/YKIND -- confirm that
  // callers guarantee they match.
  RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
      const Descriptor &y, const char *sourceFile, int line) const {
    Terminator terminator{sourceFile, line};
    auto xCatKind{x.type().GetCategoryAndKind()};
    auto yCatKind{y.type().GetCategoryAndKind()};
    RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
    RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
    RUNTIME_CHECK(terminator, yCatKind->first == YCAT);

    // The result type depends only on template arguments, so it is resolved
    // at compile time and only one of the two branches is instantiated.
    constexpr auto resultType{GetResultType(XCAT, XKIND, YCAT, YKIND)};
    if constexpr (resultType.has_value()) {
      DoMatmulTranspose<IS_ALLOCATING, resultType->first, resultType->second,
          CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
          result, x, y, terminator);
    } else {
      terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
          static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
    }
  }
};
} // namespace

namespace Fortran::runtime {
Expand All @@ -399,6 +423,24 @@ void RTDEF(MatmulTransposeDirect)(const Descriptor &result, const Descriptor &x,
MatmulTranspose<false>{}(result, x, y, sourceFile, line);
}

// Defines one type-specific MatmulTranspose entry: it forwards all arguments
// to the MatmulTransposeHelper instantiated with IS_ALLOCATING = true for the
// given operand categories and kinds.
#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDEF(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)( \
      Descriptor &result, const Descriptor &x, const Descriptor &y, \
      const char *sourceFile, int line) { \
    using Impl = MatmulTransposeHelper<true, TypeCategory::XCAT, XKIND, \
        TypeCategory::YCAT, YKIND>; \
    Impl{}(result, x, y, sourceFile, line); \
  }

// Defines one type-specific MatmulTransposeDirect entry: it forwards all
// arguments to the MatmulTransposeHelper instantiated with
// IS_ALLOCATING = false, which accepts the result as a const Descriptor.
#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
  void RTDEF(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
      Descriptor &result, const Descriptor &x, const Descriptor &y, \
      const char *sourceFile, int line) { \
    using Impl = MatmulTransposeHelper<false, TypeCategory::XCAT, XKIND, \
        TypeCategory::YCAT, YKIND>; \
    Impl{}(result, x, y, sourceFile, line); \
  }

#include "flang/Runtime/matmul-instances.inc"

RT_EXT_API_GROUP_END
} // extern "C"
} // namespace Fortran::runtime
Loading
Loading