Skip to content

Commit ce603a0

Browse files
[flang][openmp]Add UserReductionDetails and use in DECLARE REDUCTION (#140066)
This adds another puzzle piece for the support of OpenMP DECLARE REDUCTION functionality. This adds support for operators with derived types, as well as declaring multiple different types with the same name or operator. A new detail class for UserReductionDetials is introduced to hold the list of types supported for a given reduction declaration. Tests for parsing and symbol generation added. Declare reduction is still not supported to lowering, it will generate a "Not yet implemented" fatal error. Fixes #141306 Fixes #97241 Fixes #92832 Fixes #66453 --------- Co-authored-by: Mats Petersson <[email protected]>
1 parent e4447e1 commit ce603a0

25 files changed

+1080
-38
lines changed

flang/include/flang/Semantics/symbol.h

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class raw_ostream;
3030
}
3131
namespace Fortran::parser {
3232
struct Expr;
33+
struct OpenMPDeclareReductionConstruct;
34+
struct OmpMetadirectiveDirective;
3335
}
3436

3537
namespace Fortran::semantics {
@@ -728,14 +730,48 @@ class GenericDetails {
728730
};
729731
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const GenericDetails &);
730732

733+
// Used for OpenMP DECLARE REDUCTION, it holds the information
734+
// needed to resolve which declaration (there could be multiple
735+
// with the same name) to use for a given type.
736+
class UserReductionDetails {
737+
public:
738+
using TypeVector = std::vector<const DeclTypeSpec *>;
739+
using DeclInfo = std::variant<const parser::OpenMPDeclareReductionConstruct *,
740+
const parser::OmpMetadirectiveDirective *>;
741+
using DeclVector = std::vector<DeclInfo>;
742+
743+
UserReductionDetails() = default;
744+
745+
void AddType(const DeclTypeSpec &type) { typeList_.push_back(&type); }
746+
const TypeVector &GetTypeList() const { return typeList_; }
747+
748+
bool SupportsType(const DeclTypeSpec &type) const {
749+
// We have to compare the actual type, not the pointer, as some
750+
// types are not guaranteed to be the same object.
751+
for (auto t : typeList_) {
752+
if (*t == type) {
753+
return true;
754+
}
755+
}
756+
return false;
757+
}
758+
759+
void AddDecl(const DeclInfo &decl) { declList_.emplace_back(decl); }
760+
const DeclVector &GetDeclList() const { return declList_; }
761+
762+
private:
763+
TypeVector typeList_;
764+
DeclVector declList_;
765+
};
766+
731767
class UnknownDetails {};
732768

733769
using Details = std::variant<UnknownDetails, MainProgramDetails, ModuleDetails,
734770
SubprogramDetails, SubprogramNameDetails, EntityDetails,
735771
ObjectEntityDetails, ProcEntityDetails, AssocEntityDetails,
736772
DerivedTypeDetails, UseDetails, UseErrorDetails, HostAssocDetails,
737773
GenericDetails, ProcBindingDetails, NamelistDetails, CommonBlockDetails,
738-
TypeParamDetails, MiscDetails>;
774+
TypeParamDetails, MiscDetails, UserReductionDetails>;
739775
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Details &);
740776
std::string DetailsToString(const Details &);
741777

flang/lib/Parser/unparse.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3368,4 +3368,12 @@ template void Unparse<Program>(llvm::raw_ostream &, const Program &,
33683368
template void Unparse<Expr>(llvm::raw_ostream &, const Expr &,
33693369
const common::LangOptions &, Encoding, bool, bool, preStatementType *,
33703370
AnalyzedObjectsAsFortran *);
3371+
3372+
template void Unparse<parser::OpenMPDeclareReductionConstruct>(
3373+
llvm::raw_ostream &, const parser::OpenMPDeclareReductionConstruct &,
3374+
const common::LangOptions &, Encoding, bool, bool, preStatementType *,
3375+
AnalyzedObjectsAsFortran *);
3376+
template void Unparse<parser::OmpMetadirectiveDirective>(llvm::raw_ostream &,
3377+
const parser::OmpMetadirectiveDirective &, const common::LangOptions &,
3378+
Encoding, bool, bool, preStatementType *, AnalyzedObjectsAsFortran *);
33713379
} // namespace Fortran::parser

flang/lib/Semantics/assignment.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class AssignmentContext {
4343
void Analyze(const parser::PointerAssignmentStmt &);
4444
void Analyze(const parser::ConcurrentControl &);
4545
int deviceConstructDepth_{0};
46+
SemanticsContext &context() { return context_; }
4647

4748
private:
4849
bool CheckForPureContext(const SomeExpr &rhs, parser::CharBlock rhsSource);
@@ -218,8 +219,17 @@ void AssignmentContext::PopWhereContext() {
218219

219220
AssignmentChecker::~AssignmentChecker() {}
220221

222+
SemanticsContext &AssignmentChecker::context() {
223+
return context_.value().context();
224+
}
225+
221226
AssignmentChecker::AssignmentChecker(SemanticsContext &context)
222227
: context_{new AssignmentContext{context}} {}
228+
229+
void AssignmentChecker::Enter(
230+
const parser::OpenMPDeclareReductionConstruct &x) {
231+
context().set_location(x.source);
232+
}
223233
void AssignmentChecker::Enter(const parser::AssignmentStmt &x) {
224234
context_.value().Analyze(x);
225235
}

flang/lib/Semantics/assignment.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class AssignmentChecker : public virtual BaseChecker {
3737
public:
3838
explicit AssignmentChecker(SemanticsContext &);
3939
~AssignmentChecker();
40+
void Enter(const parser::OpenMPDeclareReductionConstruct &x);
4041
void Enter(const parser::AssignmentStmt &);
4142
void Enter(const parser::PointerAssignmentStmt &);
4243
void Enter(const parser::WhereStmt &);
@@ -54,6 +55,8 @@ class AssignmentChecker : public virtual BaseChecker {
5455
void Enter(const parser::OpenACCLoopConstruct &);
5556
void Leave(const parser::OpenACCLoopConstruct &);
5657

58+
SemanticsContext &context();
59+
5760
private:
5861
common::Indirection<AssignmentContext> context_;
5962
};

flang/lib/Semantics/check-omp-structure.cpp

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "check-omp-structure.h"
1010
#include "definable.h"
11+
#include "resolve-names-utils.h"
1112
#include "flang/Evaluate/check-expression.h"
1213
#include "flang/Evaluate/expression.h"
1314
#include "flang/Evaluate/type.h"
@@ -3520,6 +3521,17 @@ bool OmpStructureChecker::CheckReductionOperator(
35203521
break;
35213522
}
35223523
}
3524+
// User-defined operators are OK if there has been a declared reduction
3525+
// for that. We mangle those names to store the user details.
3526+
if (const auto *definedOp{std::get_if<parser::DefinedOpName>(&dOpr.u)}) {
3527+
std::string mangled{MangleDefinedOperator(definedOp->v.symbol->name())};
3528+
const Scope &scope{definedOp->v.symbol->owner()};
3529+
if (const Symbol *symbol{scope.FindSymbol(mangled)}) {
3530+
if (symbol->detailsIf<UserReductionDetails>()) {
3531+
return true;
3532+
}
3533+
}
3534+
}
35233535
context_.Say(source, "Invalid reduction operator in %s clause."_err_en_US,
35243536
parser::ToUpperCaseLetters(getClauseName(clauseId).str()));
35253537
return false;
@@ -3533,8 +3545,7 @@ bool OmpStructureChecker::CheckReductionOperator(
35333545
valid =
35343546
llvm::is_contained({"max", "min", "iand", "ior", "ieor"}, realName);
35353547
if (!valid) {
3536-
auto *misc{name->symbol->detailsIf<MiscDetails>()};
3537-
valid = misc && misc->kind() == MiscDetails::Kind::ConstructName;
3548+
valid = name->symbol->detailsIf<UserReductionDetails>();
35383549
}
35393550
}
35403551
if (!valid) {
@@ -3614,8 +3625,20 @@ void OmpStructureChecker::CheckReductionObjects(
36143625
}
36153626
}
36163627

3628+
static bool CheckSymbolSupportsType(const Scope &scope,
3629+
const parser::CharBlock &name, const DeclTypeSpec &type) {
3630+
if (const auto *symbol{scope.FindSymbol(name)}) {
3631+
if (const auto *reductionDetails{
3632+
symbol->detailsIf<UserReductionDetails>()}) {
3633+
return reductionDetails->SupportsType(type);
3634+
}
3635+
}
3636+
return false;
3637+
}
3638+
36173639
static bool IsReductionAllowedForType(
3618-
const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type) {
3640+
const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type,
3641+
const Scope &scope, SemanticsContext &context) {
36193642
auto isLogical{[](const DeclTypeSpec &type) -> bool {
36203643
return type.category() == DeclTypeSpec::Logical;
36213644
}};
@@ -3635,27 +3658,40 @@ static bool IsReductionAllowedForType(
36353658
case parser::DefinedOperator::IntrinsicOperator::Multiply:
36363659
case parser::DefinedOperator::IntrinsicOperator::Add:
36373660
case parser::DefinedOperator::IntrinsicOperator::Subtract:
3638-
return type.IsNumeric(TypeCategory::Integer) ||
3661+
if (type.IsNumeric(TypeCategory::Integer) ||
36393662
type.IsNumeric(TypeCategory::Real) ||
3640-
type.IsNumeric(TypeCategory::Complex);
3663+
type.IsNumeric(TypeCategory::Complex))
3664+
return true;
3665+
break;
36413666

36423667
case parser::DefinedOperator::IntrinsicOperator::AND:
36433668
case parser::DefinedOperator::IntrinsicOperator::OR:
36443669
case parser::DefinedOperator::IntrinsicOperator::EQV:
36453670
case parser::DefinedOperator::IntrinsicOperator::NEQV:
3646-
return isLogical(type);
3671+
if (isLogical(type)) {
3672+
return true;
3673+
}
3674+
break;
36473675

36483676
// Reduction identifier is not in OMP5.2 Table 5.2
36493677
default:
36503678
DIE("This should have been caught in CheckIntrinsicOperator");
36513679
return false;
36523680
}
3653-
}
3654-
return true;
3681+
parser::CharBlock name{MakeNameFromOperator(*intrinsicOp, context)};
3682+
return CheckSymbolSupportsType(scope, name, type);
3683+
} else if (const auto *definedOp{
3684+
std::get_if<parser::DefinedOpName>(&dOpr.u)}) {
3685+
return CheckSymbolSupportsType(
3686+
scope, MangleDefinedOperator(definedOp->v.symbol->name()), type);
3687+
}
3688+
llvm_unreachable(
3689+
"A DefinedOperator is either a DefinedOpName or an IntrinsicOperator");
36553690
}};
36563691

36573692
auto checkDesignator{[&](const parser::ProcedureDesignator &procD) {
36583693
const parser::Name *name{std::get_if<parser::Name>(&procD.u)};
3694+
CHECK(name && name->symbol);
36593695
if (name && name->symbol) {
36603696
const SourceName &realName{name->symbol->GetUltimate().name()};
36613697
// OMP5.2: The type [...] of a list item that appears in a
@@ -3664,18 +3700,35 @@ static bool IsReductionAllowedForType(
36643700
// IAND: arguments must be integers: F2023 16.9.100
36653701
// IEOR: arguments must be integers: F2023 16.9.106
36663702
// IOR: arguments must be integers: F2023 16.9.111
3667-
return type.IsNumeric(TypeCategory::Integer);
3703+
if (type.IsNumeric(TypeCategory::Integer)) {
3704+
return true;
3705+
}
36683706
} else if (realName == "max" || realName == "min") {
36693707
// MAX: arguments must be integer, real, or character:
36703708
// F2023 16.9.135
36713709
// MIN: arguments must be integer, real, or character:
36723710
// F2023 16.9.141
3673-
return type.IsNumeric(TypeCategory::Integer) ||
3674-
type.IsNumeric(TypeCategory::Real) || isCharacter(type);
3711+
if (type.IsNumeric(TypeCategory::Integer) ||
3712+
type.IsNumeric(TypeCategory::Real) || isCharacter(type)) {
3713+
return true;
3714+
}
36753715
}
3716+
3717+
// If we get here, it may be a user declared reduction, so check
3718+
// if the symbol has UserReductionDetails, and if so, the type is
3719+
// supported.
3720+
if (const auto *reductionDetails{
3721+
name->symbol->detailsIf<UserReductionDetails>()}) {
3722+
return reductionDetails->SupportsType(type);
3723+
}
3724+
3725+
// We also need to check for mangled names (max, min, iand, ieor and ior)
3726+
// and then check if the type is there.
3727+
parser::CharBlock mangledName{MangleSpecialFunctions(name->source)};
3728+
return CheckSymbolSupportsType(scope, mangledName, type);
36763729
}
3677-
// TODO: user defined reduction operators. Just allow everything for now.
3678-
return true;
3730+
// Everything else is "not matching type".
3731+
return false;
36793732
}};
36803733

36813734
return common::visit(
@@ -3690,7 +3743,8 @@ void OmpStructureChecker::CheckReductionObjectTypes(
36903743

36913744
for (auto &[symbol, source] : symbols) {
36923745
if (auto *type{symbol->GetType()}) {
3693-
if (!IsReductionAllowedForType(ident, *type)) {
3746+
const auto &scope{context_.FindScope(symbol->name())};
3747+
if (!IsReductionAllowedForType(ident, *type, scope, context_)) {
36943748
context_.Say(source,
36953749
"The type of '%s' is incompatible with the reduction operator."_err_en_US,
36963750
symbol->name());

flang/lib/Semantics/mod-file.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,7 @@ void ModFileWriter::PutEntity(llvm::raw_ostream &os, const Symbol &symbol) {
894894
[&](const ObjectEntityDetails &) { PutObjectEntity(os, symbol); },
895895
[&](const ProcEntityDetails &) { PutProcEntity(os, symbol); },
896896
[&](const TypeParamDetails &) { PutTypeParam(os, symbol); },
897+
[&](const UserReductionDetails &) { PutUserReduction(os, symbol); },
897898
[&](const auto &) {
898899
common::die("PutEntity: unexpected details: %s",
899900
DetailsToString(symbol.details()).c_str());
@@ -1043,6 +1044,28 @@ void ModFileWriter::PutTypeParam(llvm::raw_ostream &os, const Symbol &symbol) {
10431044
os << '\n';
10441045
}
10451046

1047+
void ModFileWriter::PutUserReduction(
1048+
llvm::raw_ostream &os, const Symbol &symbol) {
1049+
const auto &details{symbol.get<UserReductionDetails>()};
1050+
// The module content for a OpenMP Declare Reduction is the OpenMP
1051+
// declaration. There may be multiple declarations.
1052+
// Decls are pointers, so do not use a reference.
1053+
for (const auto decl : details.GetDeclList()) {
1054+
common::visit( //
1055+
common::visitors{//
1056+
[&](const parser::OpenMPDeclareReductionConstruct *d) {
1057+
Unparse(os, *d, context_.langOptions());
1058+
},
1059+
[&](const parser::OmpMetadirectiveDirective *m) {
1060+
Unparse(os, *m, context_.langOptions());
1061+
},
1062+
[&](const auto &) {
1063+
DIE("Unknown OpenMP DECLARE REDUCTION content");
1064+
}},
1065+
decl);
1066+
}
1067+
}
1068+
10461069
void PutInit(llvm::raw_ostream &os, const Symbol &symbol, const MaybeExpr &init,
10471070
const parser::Expr *unanalyzed, SemanticsContext &context) {
10481071
if (IsNamedConstant(symbol) || symbol.owner().IsDerivedType()) {

flang/lib/Semantics/mod-file.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class ModFileWriter {
8080
void PutDerivedType(const Symbol &, const Scope * = nullptr);
8181
void PutDECStructure(const Symbol &, const Scope * = nullptr);
8282
void PutTypeParam(llvm::raw_ostream &, const Symbol &);
83+
void PutUserReduction(llvm::raw_ostream &, const Symbol &);
8384
void PutSubprogram(const Symbol &);
8485
void PutGeneric(const Symbol &);
8586
void PutUse(const Symbol &);

flang/lib/Semantics/resolve-names-utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,11 @@ struct SymbolAndTypeMappings;
146146
void MapSubprogramToNewSymbols(const Symbol &oldSymbol, Symbol &newSymbol,
147147
Scope &newScope, SymbolAndTypeMappings * = nullptr);
148148

149+
parser::CharBlock MakeNameFromOperator(
150+
const parser::DefinedOperator::IntrinsicOperator &op,
151+
SemanticsContext &context);
152+
parser::CharBlock MangleSpecialFunctions(const parser::CharBlock &name);
153+
std::string MangleDefinedOperator(const parser::CharBlock &name);
154+
149155
} // namespace Fortran::semantics
150156
#endif // FORTRAN_SEMANTICS_RESOLVE_NAMES_H_

0 commit comments

Comments
 (0)