-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang][OpenMP] Convert repeatable clauses (except Map) in ClauseProc… #81623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-openmp Author: Krzysztof Parzyszek (kparzysz) Changes…essor Rename Leave Patch is 51.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/81623.diff 2 Files Affected:
diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h
index d257da1a709642..e9999974944e88 100644
--- a/flang/include/flang/Evaluate/tools.h
+++ b/flang/include/flang/Evaluate/tools.h
@@ -430,6 +430,29 @@ template <typename A> std::optional<CoarrayRef> ExtractCoarrayRef(const A &x) {
}
}
+struct ExtractSubstringHelper {
+ template <typename T> static std::optional<Substring> visit(T &&) {
+ return std::nullopt;
+ }
+
+ static std::optional<Substring> visit(const Substring &e) { return e; }
+
+ template <typename T>
+ static std::optional<Substring> visit(const Designator<T> &e) {
+ return std::visit([](auto &&s) { return visit(s); }, e.u);
+ }
+
+ template <typename T>
+ static std::optional<Substring> visit(const Expr<T> &e) {
+ return std::visit([](auto &&s) { return visit(s); }, e.u);
+ }
+};
+
+template <typename A>
+std::optional<Substring> ExtractSubstring(const A &x) {
+ return ExtractSubstringHelper::visit(x);
+}
+
// If an expression is simply a whole symbol data designator,
// extract and return that symbol, else null.
template <typename A> const Symbol *UnwrapWholeSymbolDataRef(const A &x) {
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index d7a93db15a4bb8..4b21ab934c9393 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -72,9 +72,9 @@ getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject) {
return sym;
}
-static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
- Fortran::lower::AbstractConverter &converter,
- llvm::SmallVectorImpl<mlir::Value> &operands) {
+static void genObjectList2(const Fortran::parser::OmpObjectList &objectList,
+ Fortran::lower::AbstractConverter &converter,
+ llvm::SmallVectorImpl<mlir::Value> &operands) {
auto addOperands = [&](Fortran::lower::SymbolRef sym) {
const mlir::Value variable = converter.getSymbolAddress(sym);
if (variable) {
@@ -93,27 +93,6 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
}
}
-static void gatherFuncAndVarSyms(
- const Fortran::parser::OmpObjectList &objList,
- mlir::omp::DeclareTargetCaptureClause clause,
- llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
- for (const Fortran::parser::OmpObject &ompObject : objList.v) {
- Fortran::common::visit(
- Fortran::common::visitors{
- [&](const Fortran::parser::Designator &designator) {
- if (const Fortran::parser::Name *name =
- Fortran::semantics::getDesignatorNameIfDataRef(
- designator)) {
- symbolAndClause.emplace_back(clause, *name->symbol);
- }
- },
- [&](const Fortran::parser::Name &name) {
- symbolAndClause.emplace_back(clause, *name.symbol);
- }},
- ompObject.u);
- }
-}
-
static Fortran::lower::pft::Evaluation *
getCollapsedLoopEval(Fortran::lower::pft::Evaluation &eval, int collapseValue) {
// Return the Evaluation of the innermost collapsed loop, or the current one
@@ -1257,6 +1236,32 @@ List<Clause> makeList(const parser::OmpClauseList &clauses,
}
} // namespace omp
+static void genObjectList(const omp::ObjectList &objects,
+ Fortran::lower::AbstractConverter &converter,
+ llvm::SmallVectorImpl<mlir::Value> &operands) {
+ for (const omp::Object &object : objects) {
+ const Fortran::semantics::Symbol *sym = object.sym;
+ assert(sym && "Expected Symbol");
+ if (mlir::Value variable = converter.getSymbolAddress(*sym)) {
+ operands.push_back(variable);
+ } else {
+ if (const auto *details =
+ sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
+ operands.push_back(converter.getSymbolAddress(details->symbol()));
+ converter.copySymbolBinding(details->symbol(), *sym);
+ }
+ }
+ }
+}
+
+static void gatherFuncAndVarSyms(
+ const omp::ObjectList &objects,
+ mlir::omp::DeclareTargetCaptureClause clause,
+ llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+ for (const omp::Object &object : objects)
+ symbolAndClause.emplace_back(clause, *object.sym);
+}
+
//===----------------------------------------------------------------------===//
// DataSharingProcessor
//===----------------------------------------------------------------------===//
@@ -1718,9 +1723,8 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Value> &dependOperands) const;
bool
processEnter(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
- bool
- processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier directiveName,
- mlir::Value &result) const;
+ bool processIf(omp::clause::If::DirectiveNameModifier directiveName,
+ mlir::Value &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
@@ -1815,6 +1819,26 @@ class ClauseProcessor {
/// if at least one instance was found.
template <typename T>
bool findRepeatableClause(
+ std::function<void(const T &, const Fortran::parser::CharBlock &source)>
+ callbackFn) const {
+ bool found = false;
+ ClauseIterator nextIt, endIt = clauses.end();
+ for (ClauseIterator it = clauses.begin(); it != endIt; it = nextIt) {
+ nextIt = findClause<T>(it, endIt);
+
+ if (nextIt != endIt) {
+ callbackFn(std::get<T>(nextIt->u), nextIt->source);
+ found = true;
+ ++nextIt;
+ }
+ }
+ return found;
+ }
+
+ /// Call `callbackFn` for each occurrence of the given clause. Return `true`
+ /// if at least one instance was found.
+ template <typename T>
+ bool findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
@@ -1880,9 +1904,9 @@ class ReductionProcessor {
IEOR
};
static ReductionIdentifier
- getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
+ getReductionType(const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
- getRealName(pd).ToString())
+ getRealName(pd.v.sym).ToString())
.Case("max", ReductionIdentifier::MAX)
.Case("min", ReductionIdentifier::MIN)
.Case("iand", ReductionIdentifier::IAND)
@@ -1894,35 +1918,33 @@ class ReductionProcessor {
}
static ReductionIdentifier getReductionType(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+ omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Add:
return ReductionIdentifier::ADD;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
return ReductionIdentifier::SUBTRACT;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
return ReductionIdentifier::MULTIPLY;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case omp::clause::DefinedOperator::IntrinsicOperator::AND:
return ReductionIdentifier::AND;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
return ReductionIdentifier::EQV;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case omp::clause::DefinedOperator::IntrinsicOperator::OR:
return ReductionIdentifier::OR;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
return ReductionIdentifier::NEQV;
default:
llvm_unreachable("unexpected intrinsic operator in reduction");
}
}
- static bool supportedIntrinsicProcReduction(
- const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- if (!name->symbol->GetUltimate().attrs().test(
- Fortran::semantics::Attr::INTRINSIC))
+ static bool
+ supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd) {
+ Fortran::semantics::Symbol *sym = pd.v.sym;
+ if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
return false;
- auto redType = llvm::StringSwitch<bool>(getRealName(name).ToString())
+ auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
.Case("max", true)
.Case("min", true)
.Case("iand", true)
@@ -1933,15 +1955,13 @@ class ReductionProcessor {
}
static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::Name *name) {
- return name->symbol->GetUltimate().name();
+ getRealName(const Fortran::semantics::Symbol *symbol) {
+ return symbol->GetUltimate().name();
}
static const Fortran::semantics::SourceName
- getRealName(const Fortran::parser::ProcedureDesignator &pd) {
- const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
- assert(name && "Invalid Reduction Intrinsic.");
- return getRealName(name);
+ getRealName(const omp::clause::ProcedureDesignator &pd) {
+ return getRealName(pd.v.sym);
}
static std::string getReductionName(llvm::StringRef name, mlir::Type ty) {
@@ -1951,25 +1971,25 @@ class ReductionProcessor {
.str();
}
- static std::string getReductionName(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type ty) {
+ static std::string
+ getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
+ mlir::Type ty) {
std::string reductionName;
switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Add:
reductionName = "add_reduction";
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
reductionName = "multiply_reduction";
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ case omp::clause::DefinedOperator::IntrinsicOperator::AND:
return "and_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
return "eqv_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ case omp::clause::DefinedOperator::IntrinsicOperator::OR:
return "or_reduction";
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
return "neqv_reduction";
default:
reductionName = "other_reduction";
@@ -2213,7 +2233,7 @@ class ReductionProcessor {
static void
addReductionDecl(mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpReductionClause &reduction,
+ const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
@@ -2221,13 +2241,12 @@ class ReductionProcessor {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::omp::ReductionDeclareOp decl;
const auto &redOperator{
- std::get<Fortran::parser::OmpReductionOperator>(reduction.t)};
- const auto &objectList{
- std::get<Fortran::parser::OmpObjectList>(reduction.t)};
+ std::get<omp::clause::ReductionOperator>(reduction.t)};
+ const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
if (const auto &redDefinedOp =
- std::get_if<Fortran::parser::DefinedOperator>(&redOperator.u)) {
+ std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
const auto &intrinsicOp{
- std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
+ std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
redDefinedOp->u)};
ReductionIdentifier redId = getReductionType(intrinsicOp);
switch (redId) {
@@ -2243,10 +2262,41 @@ class ReductionProcessor {
"Reduction of some intrinsic operators is not supported");
break;
}
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.sym) {
+ if (reductionSymbols)
+ reductionSymbols->push_back(symbol);
+ mlir::Value symVal = converter.getSymbolAddress(*symbol);
+ if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
+ symVal = declOp.getBase();
+ mlir::Type redType =
+ symVal.getType().cast<fir::ReferenceType>().getEleTy();
+ reductionVars.push_back(symVal);
+ if (redType.isa<fir::LogicalType>())
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(intrinsicOp, firOpBuilder.getI1Type()), redId,
+ redType, currentLocation);
+ else if (redType.isIntOrIndexOrFloat()) {
+ decl = createReductionDecl(firOpBuilder,
+ getReductionName(intrinsicOp, redType),
+ redId, redType, currentLocation);
+ } else {
+ TODO(currentLocation, "Reduction of some types is not supported");
+ }
+ reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
+ firOpBuilder.getContext(), decl.getSymName()));
+ }
+ }
+ } else if (const auto *reductionIntrinsic =
+ std::get_if<omp::clause::ProcedureDesignator>(
+ &redOperator.u)) {
+ if (ReductionProcessor::supportedIntrinsicProcReduction(
+ *reductionIntrinsic)) {
+ ReductionProcessor::ReductionIdentifier redId =
+ ReductionProcessor::getReductionType(*reductionIntrinsic);
+ for (const omp::Object &object : objectList) {
+ if (const Fortran::semantics::Symbol *symbol = object.sym) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
@@ -2255,55 +2305,18 @@ class ReductionProcessor {
mlir::Type redType =
symVal.getType().cast<fir::ReferenceType>().getEleTy();
reductionVars.push_back(symVal);
- if (redType.isa<fir::LogicalType>())
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
- redId, redType, currentLocation);
- else if (redType.isIntOrIndexOrFloat()) {
- decl = createReductionDecl(firOpBuilder,
- getReductionName(intrinsicOp, redType),
- redId, redType, currentLocation);
- } else {
- TODO(currentLocation, "Reduction of some types is not supported");
- }
+ assert(redType.isIntOrIndexOrFloat() &&
+ "Unsupported reduction type");
+ decl = createReductionDecl(
+ firOpBuilder,
+ getReductionName(getRealName(*reductionIntrinsic).ToString(),
+ redType),
+ redId, redType, currentLocation);
reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
firOpBuilder.getContext(), decl.getSymName()));
}
}
}
- } else if (const auto *reductionIntrinsic =
- std::get_if<Fortran::parser::ProcedureDesignator>(
- &redOperator.u)) {
- if (ReductionProcessor::supportedIntrinsicProcReduction(
- *reductionIntrinsic)) {
- ReductionProcessor::ReductionIdentifier redId =
- ReductionProcessor::getReductionType(*reductionIntrinsic);
- for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
- if (const auto *name{
- Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
- if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
- if (reductionSymbols)
- reductionSymbols->push_back(symbol);
- mlir::Value symVal = converter.getSymbolAddress(*symbol);
- if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
- symVal = declOp.getBase();
- mlir::Type redType =
- symVal.getType().cast<fir::ReferenceType>().getEleTy();
- reductionVars.push_back(symVal);
- assert(redType.isIntOrIndexOrFloat() &&
- "Unsupported reduction type");
- decl = createReductionDecl(
- firOpBuilder,
- getReductionName(getRealName(*reductionIntrinsic).ToString(),
- redType),
- redId, redType, currentLocation);
- reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
- firOpBuilder.getContext(), decl.getSymName()));
- }
- }
- }
- }
}
}
};
@@ -2365,7 +2378,7 @@ getSimdModifier(const omp::clause::Schedule &clause) {
static void
genAllocateClause(Fortran::lower::AbstractConverter &converter,
- const Fortran::parser::OmpAllocateClause &ompAllocateClause,
+ const omp::clause::Allocate &clause,
llvm::SmallVectorImpl<mlir::Value> &allocatorOperands,
llvm::SmallVectorImpl<mlir::Value> &allocateOperands) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
@@ -2373,21 +2386,18 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
Fortran::lower::StatementContext stmtCtx;
mlir::Value allocatorOperand;
- const Fortran::parser::OmpObjectList &ompObjectList =
- std::get<Fortran::parser::OmpObjectList>(ompAllocateClause.t);
- const auto &allocateModifier = std::get<
- std::optional<Fortran::parser::OmpAllocateClause::AllocateModifier>>(
- ompAllocateClause.t);
+ const omp::ObjectList &objectList = std::get<omp::ObjectList>(clause.t);
+ const auto &modifier =
+ std::get<std::optional<omp::clause::Allocate::Modifier>>(clause.t);
// If the allocate modifier is present, check if we only use the allocator
// submodifier. ALIGN in this context is unimplemented
const bool onlyAllocator =
- allocateModifier &&
- std::holds_alternative<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
+ modifier &&
+ std::holds_alternative<omp::clause::Allocate::Modifier::Allocator>(
+ modifier->u);
- if (allocateModifier && !onlyAllocator) {
+ if (modifier && !onlyAllocator) {
TODO(currentLocation, "OmpAllocateClause ALIGN modifier");
}
@@ -2395,20 +2405,17 @@ genAllocateClause(Fortran::lower::AbstractConverter &converter,
// to list of allocators, otherwise, add default allocator to
// list of allocators.
if (onlyAllocator) {
- const auto &allocatorValue = std::get<
- Fortran::parser::OmpAllocateClause::AllocateModifier::Allocator>(
- allocateModifier->u);
- allocatorOperand = fir...
[truncated]
|
bdf3050
to
57c70c5
Compare
be33fa2
to
841f10e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This LGTM, I just have a question on a part of the patch that I wasn't able to fully understand.
Introduce a set of generic classes (templates) that represent OpenMP clauses in a language-agnostic manner. OpenMP clauses can contain expressions and data objects and the exact representation of each depends on the source language of the compiled program. To deal with this, the templates depend on two type parameters: - IdType: type that represent object's identity (in a way that satisfied OpenMP requirements), and - ExprType: type that can represent numeric values, as well as data references (e.g. x.y[1].z[2]). In addition to that, implement code instantiating these templates from flang's AST. This patch only introduces the new classes, they are not yet used anywhere.
Temporarily rename old clause list to `clauses2`, old clause iterator to `ClauseIterator2`. Change `findUniqueClause` to iterate over `omp::Clause` objects, modify all handlers to operate on 'omp::clause::xyz` equivalents.
edc050c
to
fafbd98
Compare
…essor Rename `findRepeatableClause` to `findRepeatableClause2`, and make the new `findRepeatableClause` operate on new `omp::Clause` objects. Leave `Map` unchanged, because it will require more changes for it to work.
e32bfaa
to
655dce5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the clarification, LGTM.
flang/lib/Lower/OpenMP/Utils.cpp
Outdated
assert(sym && "Expected Symbol"); | ||
if (mlir::Value variable = converter.getSymbolAddress(*sym)) { | ||
operands.push_back(variable); | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: I know this follows the previous implementation, but I think it'd be better to collapse this into an else if
and get rid of one nesting level. Feel free to ignore if you disagree.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for explaining. LGTM
✅ With the latest revision this PR passed the C/C++ code formatter. |
Another instance of |
…essor
Rename
findRepeatableClause
tofindRepeatableClause2
, and make the newfindRepeatableClause
operate on newomp::Clause
objects.Leave
Map
unchanged, because it will require more changes for it to work.[Clause representation 3/6]