Skip to content

Commit 8fde6f4

Browse files
committed
[Flang][OpenMP] Add lowering from PFT to new MapEntry and Bounds operations and tie them to relevant Target operations
This patch builds on top of a prior patch in review which adds a new map and bounds operation by modifying the OpenMP PFT lowering to support these operations and generate them from the PFT. A significant amount of the support for the Bounds operation is borrowed from OpenACC's own current implementation and lowering, just ported over to OpenMP. The patch also adds very preliminary/initial support for lowering to a new Capture attribute, which is stored on the new Map Operation, which helps the later lowering from OpenMP -> LLVM IR by indicating how a map argument should be handled. This capture type will influence how a map argument is accessed on device and passed by the host (different load/store handling etc.). It is reflective of a similar piece of information stored in the Clang AST which performs a similar role. As well as some minor adjustments to how the map type (map bitshift which dictates to the runtime how it should handle an argument) is generated to further support more use-cases for future patches that build on this work. Finally it adds the map entry operation creation and tying it to the relevant target operations as well as the addition of some new tests and alteration of previous tests to support the new changes. Depends on D158732 reviewers: kiranchandramohan, TIFitis, clementval, razvanlupusoru Differential Revision: https://reviews.llvm.org/D158734
1 parent 571df01 commit 8fde6f4

File tree

10 files changed

+868
-483
lines changed

10 files changed

+868
-483
lines changed

flang/include/flang/Lower/OpenMP.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ struct OmpClauseList;
3636

3737
namespace semantics {
3838
class Symbol;
39+
class SemanticsContext;
3940
} // namespace semantics
4041

4142
namespace lower {
@@ -51,8 +52,8 @@ struct Variable;
5152
void genOpenMPTerminator(fir::FirOpBuilder &, mlir::Operation *,
5253
mlir::Location);
5354

54-
void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &,
55-
const parser::OpenMPConstruct &);
55+
void genOpenMPConstruct(AbstractConverter &, semantics::SemanticsContext &,
56+
pft::Evaluation &, const parser::OpenMPConstruct &);
5657
void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
5758
const parser::OpenMPDeclarativeConstruct &);
5859
int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2320,7 +2320,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
23202320
void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
23212321
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
23222322
localSymbols.pushScope();
2323-
genOpenMPConstruct(*this, getEval(), omp);
2323+
genOpenMPConstruct(*this, bridge.getSemanticsContext(), getEval(), omp);
23242324

23252325
const Fortran::parser::OpenMPLoopConstruct *ompLoop =
23262326
std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u);

flang/lib/Lower/DirectivesCommon.h

Lines changed: 309 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,23 @@
1212
///
1313
/// A location to place directive utilities shared across multiple lowering
1414
/// files, e.g. utilities shared in OpenMP and OpenACC. The header file can
15-
/// be used for both declarations and templated/inline implementations.
15+
/// be used for both declarations and templated/inline implementations
1616
//===----------------------------------------------------------------------===//
1717

1818
#ifndef FORTRAN_LOWER_DIRECTIVES_COMMON_H
1919
#define FORTRAN_LOWER_DIRECTIVES_COMMON_H
2020

2121
#include "flang/Common/idioms.h"
22+
#include "flang/Evaluate/tools.h"
23+
#include "flang/Lower/AbstractConverter.h"
2224
#include "flang/Lower/Bridge.h"
2325
#include "flang/Lower/ConvertExpr.h"
2426
#include "flang/Lower/ConvertVariable.h"
2527
#include "flang/Lower/OpenACC.h"
2628
#include "flang/Lower/OpenMP.h"
2729
#include "flang/Lower/PFTBuilder.h"
2830
#include "flang/Lower/StatementContext.h"
31+
#include "flang/Lower/Support/Utils.h"
2932
#include "flang/Optimizer/Builder/BoxValue.h"
3033
#include "flang/Optimizer/Builder/FIRBuilder.h"
3134
#include "flang/Optimizer/Builder/Todo.h"
@@ -36,7 +39,9 @@
3639
#include "mlir/Dialect/OpenACC/OpenACC.h"
3740
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3841
#include "mlir/Dialect/SCF/IR/SCF.h"
42+
#include "mlir/IR/Value.h"
3943
#include "llvm/Frontend/OpenMP/OMPConstants.h"
44+
#include <list>
4045
#include <type_traits>
4146

4247
namespace Fortran {
@@ -611,6 +616,309 @@ void createEmptyRegionBlocks(
611616
}
612617
}
613618

619+
inline mlir::Value
620+
getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
621+
fir::FirOpBuilder &builder,
622+
Fortran::lower::SymbolRef sym, mlir::Location loc) {
623+
mlir::Value symAddr = converter.getSymbolAddress(sym);
624+
// TODO: Might need revisiting to handle for non-shared clauses
625+
if (!symAddr) {
626+
if (const auto *details =
627+
sym->detailsIf<Fortran::semantics::HostAssocDetails>())
628+
symAddr = converter.getSymbolAddress(details->symbol());
629+
}
630+
631+
if (!symAddr)
632+
llvm::report_fatal_error("could not retrieve symbol address");
633+
634+
if (auto boxTy =
635+
fir::unwrapRefType(symAddr.getType()).dyn_cast<fir::BaseBoxType>()) {
636+
if (boxTy.getEleTy().isa<fir::RecordType>())
637+
TODO(loc, "derived type");
638+
639+
// Load the box when baseAddr is a `fir.ref<fir.box<T>>` or a
640+
// `fir.ref<fir.class<T>>` type.
641+
if (symAddr.getType().isa<fir::ReferenceType>())
642+
return builder.create<fir::LoadOp>(loc, symAddr);
643+
}
644+
return symAddr;
645+
}
646+
647+
/// Generate the bounds operation from the descriptor information.
648+
template <typename BoundsOp, typename BoundsType>
649+
llvm::SmallVector<mlir::Value>
650+
genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
651+
Fortran::lower::AbstractConverter &converter,
652+
fir::ExtendedValue dataExv, mlir::Value box) {
653+
llvm::SmallVector<mlir::Value> bounds;
654+
mlir::Type idxTy = builder.getIndexType();
655+
mlir::Type boundTy = builder.getType<BoundsType>();
656+
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
657+
assert(box.getType().isa<fir::BaseBoxType>() &&
658+
"expect fir.box or fir.class");
659+
for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
660+
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim);
661+
mlir::Value baseLb =
662+
fir::factory::readLowerBound(builder, loc, dataExv, dim, one);
663+
auto dimInfo =
664+
builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, d);
665+
mlir::Value lb = builder.createIntegerConstant(loc, idxTy, 0);
666+
mlir::Value ub =
667+
builder.create<mlir::arith::SubIOp>(loc, dimInfo.getExtent(), one);
668+
mlir::Value bound =
669+
builder.create<BoundsOp>(loc, boundTy, lb, ub, mlir::Value(),
670+
dimInfo.getByteStride(), true, baseLb);
671+
bounds.push_back(bound);
672+
}
673+
return bounds;
674+
}
675+
676+
/// Generate bounds operation for base array without any subscripts
677+
/// provided.
678+
template <typename BoundsOp, typename BoundsType>
679+
llvm::SmallVector<mlir::Value>
680+
genBaseBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
681+
Fortran::lower::AbstractConverter &converter,
682+
fir::ExtendedValue dataExv, mlir::Value baseAddr) {
683+
mlir::Type idxTy = builder.getIndexType();
684+
mlir::Type boundTy = builder.getType<BoundsType>();
685+
llvm::SmallVector<mlir::Value> bounds;
686+
687+
if (dataExv.rank() == 0)
688+
return bounds;
689+
690+
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
691+
for (std::size_t dim = 0; dim < dataExv.rank(); ++dim) {
692+
mlir::Value baseLb =
693+
fir::factory::readLowerBound(builder, loc, dataExv, dim, one);
694+
mlir::Value ext = fir::factory::readExtent(builder, loc, dataExv, dim);
695+
mlir::Value lb = builder.createIntegerConstant(loc, idxTy, 0);
696+
697+
// ub = extent - 1
698+
mlir::Value ub = builder.create<mlir::arith::SubIOp>(loc, ext, one);
699+
mlir::Value bound =
700+
builder.create<BoundsOp>(loc, boundTy, lb, ub, ext, one, false, baseLb);
701+
bounds.push_back(bound);
702+
}
703+
return bounds;
704+
}
705+
706+
/// Generate bounds operations for an array section when subscripts are
707+
/// provided.
708+
template <typename BoundsOp, typename BoundsType>
709+
llvm::SmallVector<mlir::Value>
710+
genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
711+
Fortran::lower::AbstractConverter &converter,
712+
Fortran::lower::StatementContext &stmtCtx,
713+
const std::list<Fortran::parser::SectionSubscript> &subscripts,
714+
std::stringstream &asFortran, fir::ExtendedValue &dataExv,
715+
mlir::Value baseAddr) {
716+
int dimension = 0;
717+
mlir::Type idxTy = builder.getIndexType();
718+
mlir::Type boundTy = builder.getType<BoundsType>();
719+
llvm::SmallVector<mlir::Value> bounds;
720+
721+
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
722+
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
723+
for (const auto &subscript : subscripts) {
724+
if (const auto *triplet{
725+
std::get_if<Fortran::parser::SubscriptTriplet>(&subscript.u)}) {
726+
if (dimension != 0)
727+
asFortran << ',';
728+
mlir::Value lbound, ubound, extent;
729+
std::optional<std::int64_t> lval, uval;
730+
mlir::Value baseLb =
731+
fir::factory::readLowerBound(builder, loc, dataExv, dimension, one);
732+
bool defaultLb = baseLb == one;
733+
mlir::Value stride = one;
734+
bool strideInBytes = false;
735+
736+
if (fir::unwrapRefType(baseAddr.getType()).isa<fir::BaseBoxType>()) {
737+
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dimension);
738+
auto dimInfo = builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy,
739+
baseAddr, d);
740+
stride = dimInfo.getByteStride();
741+
strideInBytes = true;
742+
}
743+
744+
const auto &lower{std::get<0>(triplet->t)};
745+
if (lower) {
746+
lval = Fortran::semantics::GetIntValue(lower);
747+
if (lval) {
748+
if (defaultLb) {
749+
lbound = builder.createIntegerConstant(loc, idxTy, *lval - 1);
750+
} else {
751+
mlir::Value lb = builder.createIntegerConstant(loc, idxTy, *lval);
752+
lbound = builder.create<mlir::arith::SubIOp>(loc, lb, baseLb);
753+
}
754+
asFortran << *lval;
755+
} else {
756+
const Fortran::lower::SomeExpr *lexpr =
757+
Fortran::semantics::GetExpr(*lower);
758+
mlir::Value lb =
759+
fir::getBase(converter.genExprValue(loc, *lexpr, stmtCtx));
760+
lb = builder.createConvert(loc, baseLb.getType(), lb);
761+
lbound = builder.create<mlir::arith::SubIOp>(loc, lb, baseLb);
762+
asFortran << lexpr->AsFortran();
763+
}
764+
} else {
765+
lbound = defaultLb ? zero : baseLb;
766+
}
767+
asFortran << ':';
768+
const auto &upper{std::get<1>(triplet->t)};
769+
if (upper) {
770+
uval = Fortran::semantics::GetIntValue(upper);
771+
if (uval) {
772+
if (defaultLb) {
773+
ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1);
774+
} else {
775+
mlir::Value ub = builder.createIntegerConstant(loc, idxTy, *uval);
776+
ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
777+
}
778+
asFortran << *uval;
779+
} else {
780+
const Fortran::lower::SomeExpr *uexpr =
781+
Fortran::semantics::GetExpr(*upper);
782+
mlir::Value ub =
783+
fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx));
784+
ub = builder.createConvert(loc, baseLb.getType(), ub);
785+
ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
786+
asFortran << uexpr->AsFortran();
787+
}
788+
}
789+
if (lower && upper) {
790+
if (lval && uval && *uval < *lval) {
791+
mlir::emitError(loc, "zero sized array section");
792+
break;
793+
} else if (std::get<2>(triplet->t)) {
794+
const auto &strideExpr{std::get<2>(triplet->t)};
795+
if (strideExpr) {
796+
mlir::emitError(loc, "stride cannot be specified on "
797+
"an OpenMP array section");
798+
break;
799+
}
800+
}
801+
}
802+
// ub = baseLb + extent - 1
803+
if (!ubound) {
804+
mlir::Value ext =
805+
fir::factory::readExtent(builder, loc, dataExv, dimension);
806+
mlir::Value lbExt =
807+
builder.create<mlir::arith::AddIOp>(loc, ext, baseLb);
808+
ubound = builder.create<mlir::arith::SubIOp>(loc, lbExt, one);
809+
}
810+
mlir::Value bound = builder.create<BoundsOp>(
811+
loc, boundTy, lbound, ubound, extent, stride, strideInBytes, baseLb);
812+
bounds.push_back(bound);
813+
++dimension;
814+
}
815+
}
816+
return bounds;
817+
}
818+
819+
template <typename ObjectType, typename BoundsOp, typename BoundsType>
820+
mlir::Value gatherDataOperandAddrAndBounds(
821+
Fortran::lower::AbstractConverter &converter, fir::FirOpBuilder &builder,
822+
Fortran::semantics::SemanticsContext &semanticsContext,
823+
Fortran::lower::StatementContext &stmtCtx, const ObjectType &object,
824+
mlir::Location operandLocation, std::stringstream &asFortran,
825+
llvm::SmallVector<mlir::Value> &bounds) {
826+
mlir::Value baseAddr;
827+
828+
std::visit(
829+
Fortran::common::visitors{
830+
[&](const Fortran::parser::Designator &designator) {
831+
if (auto expr{Fortran::semantics::AnalyzeExpr(semanticsContext,
832+
designator)}) {
833+
if ((*expr).Rank() > 0 &&
834+
Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
835+
designator)) {
836+
const auto *arrayElement =
837+
Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
838+
designator);
839+
const auto *dataRef =
840+
std::get_if<Fortran::parser::DataRef>(&designator.u);
841+
fir::ExtendedValue dataExv;
842+
if (Fortran::parser::Unwrap<
843+
Fortran::parser::StructureComponent>(
844+
arrayElement->base)) {
845+
auto exprBase = Fortran::semantics::AnalyzeExpr(
846+
semanticsContext, arrayElement->base);
847+
dataExv = converter.genExprAddr(operandLocation, *exprBase,
848+
stmtCtx);
849+
baseAddr = fir::getBase(dataExv);
850+
asFortran << (*exprBase).AsFortran();
851+
} else {
852+
const Fortran::parser::Name &name =
853+
Fortran::parser::GetLastName(*dataRef);
854+
baseAddr = getDataOperandBaseAddr(
855+
converter, builder, *name.symbol, operandLocation);
856+
dataExv = converter.getSymbolExtendedValue(*name.symbol);
857+
asFortran << name.ToString();
858+
}
859+
860+
if (!arrayElement->subscripts.empty()) {
861+
asFortran << '(';
862+
bounds = genBoundsOps<BoundsType, BoundsOp>(
863+
builder, operandLocation, converter, stmtCtx,
864+
arrayElement->subscripts, asFortran, dataExv, baseAddr);
865+
}
866+
asFortran << ')';
867+
} else if (Fortran::parser::Unwrap<
868+
Fortran::parser::StructureComponent>(designator)) {
869+
fir::ExtendedValue compExv =
870+
converter.genExprAddr(operandLocation, *expr, stmtCtx);
871+
baseAddr = fir::getBase(compExv);
872+
if (fir::unwrapRefType(baseAddr.getType())
873+
.isa<fir::SequenceType>())
874+
bounds = genBaseBoundsOps<BoundsType, BoundsOp>(
875+
builder, operandLocation, converter, compExv, baseAddr);
876+
asFortran << (*expr).AsFortran();
877+
878+
// If the component is an allocatable or pointer the result of
879+
// genExprAddr will be the result of a fir.box_addr operation.
880+
// Retrieve the box so we handle it like other descriptor.
881+
if (auto boxAddrOp = mlir::dyn_cast_or_null<fir::BoxAddrOp>(
882+
baseAddr.getDefiningOp())) {
883+
baseAddr = boxAddrOp.getVal();
884+
bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>(
885+
builder, operandLocation, converter, compExv, baseAddr);
886+
}
887+
} else {
888+
// Scalar or full array.
889+
if (const auto *dataRef{
890+
std::get_if<Fortran::parser::DataRef>(&designator.u)}) {
891+
const Fortran::parser::Name &name =
892+
Fortran::parser::GetLastName(*dataRef);
893+
fir::ExtendedValue dataExv =
894+
converter.getSymbolExtendedValue(*name.symbol);
895+
baseAddr = getDataOperandBaseAddr(
896+
converter, builder, *name.symbol, operandLocation);
897+
if (fir::unwrapRefType(baseAddr.getType())
898+
.isa<fir::BaseBoxType>())
899+
bounds = genBoundsOpsFromBox<BoundsType, BoundsOp>(
900+
builder, operandLocation, converter, dataExv, baseAddr);
901+
if (fir::unwrapRefType(baseAddr.getType())
902+
.isa<fir::SequenceType>())
903+
bounds = genBaseBoundsOps<BoundsType, BoundsOp>(
904+
builder, operandLocation, converter, dataExv, baseAddr);
905+
asFortran << name.ToString();
906+
} else { // Unsupported
907+
llvm::report_fatal_error(
908+
"Unsupported type of OpenACC operand");
909+
}
910+
}
911+
}
912+
},
913+
[&](const Fortran::parser::Name &name) {
914+
baseAddr = getDataOperandBaseAddr(converter, builder, *name.symbol,
915+
operandLocation);
916+
asFortran << name.ToString();
917+
}},
918+
object.u);
919+
return baseAddr;
920+
}
921+
614922
} // namespace lower
615923
} // namespace Fortran
616924

0 commit comments

Comments
 (0)