Skip to content

Commit bf2bdf3

Browse files
committed
[Flang][MLIR] Move majority of OpenMP lowering of descriptor types into Opt Pass
1 parent 8e63299 commit bf2bdf3

File tree

7 files changed

+173
-95
lines changed

7 files changed

+173
-95
lines changed

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ std::unique_ptr<mlir::Pass>
7676
createAlgebraicSimplificationPass(const mlir::GreedyRewriteConfig &config);
7777
std::unique_ptr<mlir::Pass> createPolymorphicOpConversionPass();
7878

79+
std::unique_ptr<mlir::Pass> createOMPDescriptorMapInfoGenPass();
7980
std::unique_ptr<mlir::Pass> createOMPFunctionFilteringPass();
8081
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
8182
createOMPMarkDeclareTargetPass();

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,17 @@ def LoopVersioning : Pass<"loop-versioning", "mlir::func::FuncOp"> {
318318
let dependentDialects = [ "fir::FIROpsDialect" ];
319319
}
320320

321+
def OMPDescriptorMapInfoGenPass
322+
: Pass<"omp-descriptor-map-info-gen", "mlir::ModuleOp"> {
323+
let summary = "expands OpenMP MapInfo operations containing descriptors";
324+
let description = [{
325+
Expands MapInfo operations containing descriptor types into multiple MapInfo's for each pointer element in
326+
the descriptor that requires explicit individual mapping by the OpenMP runtime.
327+
}];
328+
let constructor = "::fir::createOMPDescriptorMapInfoGenPass()";
329+
let dependentDialects = ["mlir::omp::OpenMPDialect"];
330+
}
331+
321332
def OMPMarkDeclareTargetPass
322333
: Pass<"omp-mark-declare-target", "mlir::ModuleOp"> {
323334
let summary = "Marks all functions called by an OpenMP declare target function as declare target";

flang/include/flang/Tools/CLOptions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ inline void createHLFIRToFIRPassPipeline(
274274
/// rather than the host device.
275275
inline void createOpenMPFIRPassPipeline(
276276
mlir::PassManager &pm, bool isTargetDevice) {
277+
pm.addPass(fir::createOMPDescriptorMapInfoGenPass());
277278
pm.addPass(fir::createOMPMarkDeclareTargetPass());
278279
if (isTargetDevice)
279280
pm.addPass(fir::createOMPFunctionFilteringPass());

flang/lib/Lower/OpenMP.cpp

Lines changed: 34 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1740,54 +1740,6 @@ createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
17401740
return op;
17411741
}
17421742

1743-
static mlir::omp::MapInfoOp processDescriptorTypeMappings(
1744-
Fortran::semantics::SemanticsContext &semanticsContext,
1745-
Fortran::lower::StatementContext &stmtCtx,
1746-
Fortran::lower::AbstractConverter &converter, mlir::Location loc,
1747-
mlir::Value descriptorAddr, mlir::Value descDataBaseAddr,
1748-
mlir::ValueRange bounds, std::string asFortran,
1749-
llvm::omp::OpenMPOffloadMappingFlags mapCaptureType) {
1750-
llvm::SmallVector<mlir::Value> descriptorBaseAddrMembers;
1751-
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
1752-
1753-
mlir::Value descriptor = descriptorAddr;
1754-
1755-
// The fir::BoxOffsetOp only works with !fir.ref<!fir.box<...>> types, as
1756-
// allowing it to access non-reference box operations can cause some
1757-
// problematic SSA IR. However, in the case of assumed shape's the type
1758-
// is not a !fir.ref, in these cases to retrieve the appropriate
1759-
// !fir.ref<!fir.box<...>> to access the data we need to map we must
1760-
// perform an alloca and then store to it and retrieve the data from the new
1761-
// alloca.
1762-
if (mlir::isa<fir::BaseBoxType>(descriptorAddr.getType())) {
1763-
mlir::OpBuilder::InsertPoint insPt = firOpBuilder.saveInsertionPoint();
1764-
firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
1765-
descriptor =
1766-
firOpBuilder.create<fir::AllocaOp>(loc, descriptorAddr.getType());
1767-
firOpBuilder.restoreInsertionPoint(insPt);
1768-
firOpBuilder.create<fir::StoreOp>(loc, descriptorAddr, descriptor);
1769-
}
1770-
1771-
mlir::Value baseAddrAddr = firOpBuilder.create<fir::BoxOffsetOp>(
1772-
loc, descriptor, fir::BoxFieldAttr::base_addr);
1773-
1774-
descriptorBaseAddrMembers.push_back(createMapInfoOp(
1775-
firOpBuilder, loc, descDataBaseAddr, baseAddrAddr, asFortran, bounds, {},
1776-
static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1777-
mapCaptureType),
1778-
mlir::omp::VariableCaptureKind::ByRef, descDataBaseAddr.getType()));
1779-
1780-
// TODO: map the addendum segment of the descriptor, similarly to the above
1781-
// base address/data pointer member.
1782-
1783-
return createMapInfoOp(
1784-
firOpBuilder, loc, descriptor, mlir::Value{}, asFortran, {},
1785-
descriptorBaseAddrMembers,
1786-
static_cast<std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1787-
mapCaptureType),
1788-
mlir::omp::VariableCaptureKind::ByRef, descriptor.getType());
1789-
}
1790-
17911743
bool ClauseProcessor::processMap(
17921744
mlir::Location currentLocation, const llvm::omp::Directive &directive,
17931745
Fortran::semantics::SemanticsContext &semanticsContext,
@@ -1857,24 +1809,20 @@ bool ClauseProcessor::processMap(
18571809

18581810
auto origSymbol =
18591811
converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
1860-
mlir::Value mapOp, symAddr;
1861-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) {
1812+
mlir::Value symAddr = info.addr;
1813+
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
18621814
symAddr = origSymbol;
1863-
mapOp = processDescriptorTypeMappings(
1864-
semanticsContext, stmtCtx, converter, clauseLocation,
1865-
origSymbol, info.addr, bounds, asFortran.str(), mapTypeBits);
1866-
} else {
1867-
// Explicit map captures are captured ByRef by default,
1868-
// optimisation passes may alter this to ByCopy or other capture
1869-
// types to optimise
1870-
symAddr = info.addr;
1871-
mapOp = createMapInfoOp(
1872-
firOpBuilder, clauseLocation, info.addr, mlir::Value{},
1873-
asFortran.str(), bounds, {},
1874-
static_cast<std::underlying_type_t<
1875-
llvm::omp::OpenMPOffloadMappingFlags>>(mapTypeBits),
1876-
mlir::omp::VariableCaptureKind::ByRef, info.addr.getType());
1877-
}
1815+
1816+
// Explicit map captures are captured ByRef by default,
1817+
// optimisation passes may alter this to ByCopy or other capture
1818+
// types to optimise
1819+
mlir::Value mapOp = createMapInfoOp(
1820+
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
1821+
asFortran.str(), bounds, {},
1822+
static_cast<
1823+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1824+
mapTypeBits),
1825+
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
18781826

18791827
mapOperands.push_back(mapOp);
18801828
if (mapSymTypes)
@@ -1989,22 +1937,20 @@ bool ClauseProcessor::processMotionClauses(
19891937

19901938
auto origSymbol =
19911939
converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
1992-
mlir::Value mapOp;
1993-
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType())) {
1994-
mapOp = processDescriptorTypeMappings(
1995-
semanticsContext, stmtCtx, converter, clauseLocation,
1996-
origSymbol, info.addr, bounds, asFortran.str(), mapTypeBits);
1997-
} else {
1998-
// Explicit map captures are captured ByRef by default,
1999-
// optimisation passes may alter this to ByCopy or other capture
2000-
// types to optimise
2001-
mapOp = createMapInfoOp(
2002-
firOpBuilder, clauseLocation, info.addr, mlir::Value{},
2003-
asFortran.str(), bounds, {},
2004-
static_cast<std::underlying_type_t<
2005-
llvm::omp::OpenMPOffloadMappingFlags>>(mapTypeBits),
2006-
mlir::omp::VariableCaptureKind::ByRef, info.addr.getType());
2007-
}
1940+
mlir::Value symAddr = info.addr;
1941+
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
1942+
symAddr = origSymbol;
1943+
1944+
// Explicit map captures are captured ByRef by default,
1945+
// optimisation passes may alter this to ByCopy or other capture
1946+
// types to optimise
1947+
mlir::Value mapOp = createMapInfoOp(
1948+
firOpBuilder, clauseLocation, symAddr, mlir::Value{},
1949+
asFortran.str(), bounds, {},
1950+
static_cast<
1951+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
1952+
mapTypeBits),
1953+
mlir::omp::VariableCaptureKind::ByRef, symAddr.getType());
20081954

20091955
mapOperands.push_back(mapOp);
20101956
}
@@ -2819,20 +2765,13 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
28192765
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
28202766
}
28212767

2822-
mlir::Value mapOp;
2823-
if (fir::isTypeWithDescriptor(baseOp.getType())) {
2824-
mapOp = processDescriptorTypeMappings(
2825-
semanticsContext, stmtCtx, converter, baseOp.getLoc(), baseOp,
2826-
info.addr, bounds, name.str(), mapFlag);
2827-
} else {
2828-
mapOp = createMapInfoOp(
2829-
converter.getFirOpBuilder(), baseOp.getLoc(), baseOp,
2830-
mlir::Value{}, name.str(), bounds, {},
2831-
static_cast<
2832-
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2833-
mapFlag),
2834-
captureKind, baseOp.getType());
2835-
}
2768+
mlir::Value mapOp = createMapInfoOp(
2769+
converter.getFirOpBuilder(), baseOp.getLoc(), baseOp, mlir::Value{},
2770+
name.str(), bounds, {},
2771+
static_cast<
2772+
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
2773+
mapFlag),
2774+
captureKind, baseOp.getType());
28362775

28372776
mapOperands.push_back(mapOp);
28382777
mapSymTypes.push_back(baseOp.getType());

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_flang_library(FIRTransforms
1717
AddDebugFoundation.cpp
1818
PolymorphicOpConversion.cpp
1919
LoopVersioning.cpp
20+
OMPDescriptorMapInfoGen.cpp
2021
OMPFunctionFiltering.cpp
2122
OMPMarkDeclareTarget.cpp
2223
VScaleAttr.cpp
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
#include "flang/Optimizer/Builder/FIRBuilder.h"
2+
#include "flang/Optimizer/Dialect/FIRType.h"
3+
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
4+
#include "flang/Optimizer/Transforms/Passes.h"
5+
#include "mlir/Dialect/Func/IR/FuncOps.h"
6+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
7+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
8+
#include "mlir/IR/BuiltinDialect.h"
9+
#include "mlir/IR/BuiltinOps.h"
10+
#include "mlir/IR/Operation.h"
11+
#include "mlir/IR/SymbolTable.h"
12+
#include "mlir/Pass/Pass.h"
13+
#include "mlir/Support/LLVM.h"
14+
#include "llvm/ADT/SmallPtrSet.h"
15+
16+
namespace fir {
17+
#define GEN_PASS_DEF_OMPDESCRIPTORMAPINFOGENPASS
18+
#include "flang/Optimizer/Transforms/Passes.h.inc"
19+
} // namespace fir
20+
21+
namespace {
22+
class OMPDescriptorMapInfoGenPass
23+
: public fir::impl::OMPDescriptorMapInfoGenPassBase<
24+
OMPDescriptorMapInfoGenPass> {
25+
26+
mlir::omp::MapInfoOp
27+
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
28+
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
29+
mlir::SmallVector<mlir::Value> bounds,
30+
mlir::SmallVector<mlir::Value> members, uint64_t mapType,
31+
mlir::omp::VariableCaptureKind mapCaptureType,
32+
mlir::Type retTy, bool isVal = false) {
33+
if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
34+
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
35+
retTy = baseAddr.getType();
36+
}
37+
38+
mlir::TypeAttr varType = mlir::TypeAttr::get(
39+
llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
40+
41+
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
42+
loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
43+
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
44+
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
45+
builder.getStringAttr(name));
46+
47+
return op;
48+
}
49+
50+
void genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
51+
fir::FirOpBuilder &builder) {
52+
llvm::SmallVector<mlir::Value> descriptorBaseAddrMembers;
53+
mlir::Location loc = builder.getUnknownLoc();
54+
mlir::Value descriptor = op.getVarPtr();
55+
56+
// If we enter this function, but the mapped type itself is not the
57+
// descriptor, then it's likely the address of the descriptor so we
58+
// must retrieve the descriptor SSA.
59+
if (!fir::isTypeWithDescriptor(op.getVarType())) {
60+
if (auto addrOp = mlir::dyn_cast_if_present<fir::BoxAddrOp>(
61+
op.getVarPtr().getDefiningOp())) {
62+
descriptor = addrOp.getVal();
63+
}
64+
}
65+
66+
// The fir::BoxOffsetOp only works with !fir.ref<!fir.box<...>> types, as
67+
// allowing it to access non-reference box operations can cause some
68+
// problematic SSA IR. However, in the case of assumed shape's the type
69+
// is not a !fir.ref, in these cases to retrieve the appropriate
70+
// !fir.ref<!fir.box<...>> to access the data we need to map we must
71+
// perform an alloca and then store to it and retrieve the data from the new
72+
// alloca.
73+
if (mlir::isa<fir::BaseBoxType>(descriptor.getType())) {
74+
mlir::OpBuilder::InsertPoint insPt = builder.saveInsertionPoint();
75+
builder.setInsertionPointToStart(builder.getAllocaBlock());
76+
auto alloca = builder.create<fir::AllocaOp>(loc, descriptor.getType());
77+
builder.restoreInsertionPoint(insPt);
78+
builder.create<fir::StoreOp>(loc, descriptor, alloca);
79+
descriptor = alloca;
80+
}
81+
82+
mlir::Value baseAddrAddr = builder.create<fir::BoxOffsetOp>(
83+
loc, descriptor, fir::BoxFieldAttr::base_addr);
84+
85+
descriptorBaseAddrMembers.push_back(createMapInfoOp(
86+
builder, loc, baseAddrAddr, {}, "", op.getBounds(), {},
87+
op.getMapType().value(), mlir::omp::VariableCaptureKind::ByRef,
88+
fir::unwrapRefType(baseAddrAddr.getType())));
89+
90+
// TODO: map the addendum segment of the descriptor, similarly to the above
91+
// base address/data pointer member.
92+
93+
op.getVarPtrMutable().assign(descriptor);
94+
op.setVarType(fir::unwrapRefType(descriptor.getType()));
95+
op.getMembersMutable().assign(descriptorBaseAddrMembers);
96+
op.getBoundsMutable().assign(llvm::SmallVector<mlir::Value>{});
97+
}
98+
99+
// This pass executes on mlir::ModuleOp's finding omp::MapInfoOp's containing
100+
// descriptor based types (allocatables, pointers, assumed shape etc.) and
101+
// expanding them into multiple omp::MapInfoOp's for each pointer member
102+
// contained within the descriptor.
103+
void runOnOperation() override {
104+
fir::KindMapping kindMap = fir::getKindMapping(getOperation());
105+
fir::FirOpBuilder builder{getOperation(), std::move(kindMap)};
106+
107+
getOperation()->walk([&](mlir::omp::MapInfoOp op) {
108+
if (fir::isTypeWithDescriptor(op.getVarType()) ||
109+
mlir::isa<fir::BoxAddrOp>(op.getVarPtr().getDefiningOp())) {
110+
builder.setInsertionPoint(op);
111+
genDescriptorMemberMaps(op, builder);
112+
}
113+
});
114+
}
115+
};
116+
117+
} // namespace
118+
119+
namespace fir {
120+
std::unique_ptr<mlir::Pass> createOMPDescriptorMapInfoGenPass() {
121+
return std::make_unique<OMPDescriptorMapInfoGenPass>();
122+
}
123+
} // namespace fir

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,6 +1584,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
15841584
if (failed(translator.convertFunctions()))
15851585
return nullptr;
15861586

1587+
// translator.llvmModule->dump();
1588+
15871589
if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))
15881590
return nullptr;
15891591

0 commit comments

Comments
 (0)