Skip to content

Commit 3e4d168

Browse files
committed
[Flang][MLIR] Insert descriptor implicit members into map operands and BlockArgs
1 parent bf2bdf3 commit 3e4d168

File tree

5 files changed

+110
-74
lines changed

5 files changed

+110
-74
lines changed

flang/lib/Optimizer/Transforms/OMPDescriptorMapInfoGen.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Pass/Pass.h"
1313
#include "mlir/Support/LLVM.h"
1414
#include "llvm/ADT/SmallPtrSet.h"
15+
#include <iterator>
1516

1617
namespace fir {
1718
#define GEN_PASS_DEF_OMPDESCRIPTORMAPINFOGENPASS
@@ -48,7 +49,8 @@ class OMPDescriptorMapInfoGenPass
4849
}
4950

5051
void genDescriptorMemberMaps(mlir::omp::MapInfoOp op,
51-
fir::FirOpBuilder &builder) {
52+
fir::FirOpBuilder &builder,
53+
mlir::Operation *target) {
5254
llvm::SmallVector<mlir::Value> descriptorBaseAddrMembers;
5355
mlir::Location loc = builder.getUnknownLoc();
5456
mlir::Value descriptor = op.getVarPtr();
@@ -92,8 +94,36 @@ class OMPDescriptorMapInfoGenPass
9294

9395
op.getVarPtrMutable().assign(descriptor);
9496
op.setVarType(fir::unwrapRefType(descriptor.getType()));
95-
op.getMembersMutable().assign(descriptorBaseAddrMembers);
97+
op.getMembersMutable().append(descriptorBaseAddrMembers);
9698
op.getBoundsMutable().assign(llvm::SmallVector<mlir::Value>{});
99+
100+
// could use a template to generalise to other TargetOps
101+
if (auto mapClauseOwner =
102+
llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(target)) {
103+
llvm::SmallVector<mlir::Value> newMapOps;
104+
for (size_t i = 0; i < mapClauseOwner.getMapOperands().size(); ++i) {
105+
if (mapClauseOwner.getMapOperands()[i] == op) {
106+
// Push new implicit maps generated for the descriptor.
107+
newMapOps.push_back(descriptorBaseAddrMembers[0]);
108+
109+
// for TargetOp's which have IsolatedFromAbove we must align the
110+
// new additional map operand with an appropriate BlockArgument,
111+
// as the printing and later processing currently requires a 1:1
112+
// mapping of BlockArgs to MapInfoOp's at the same placement in
113+
// each array (BlockArgs and MapOperands).
114+
if (auto targetOp = llvm::dyn_cast<mlir::omp::TargetOp>(target)) {
115+
targetOp.getRegion().insertArgument(
116+
i, descriptorBaseAddrMembers[0].getType(), loc);
117+
}
118+
119+
newMapOps.push_back(mapClauseOwner.getMapOperands()[i]);
120+
} else {
121+
newMapOps.push_back(mapClauseOwner.getMapOperands()[i]);
122+
}
123+
}
124+
125+
mapClauseOwner.getMapOperandsMutable().assign(newMapOps);
126+
}
97127
}
98128

99129
// This pass executes on mlir::ModuleOp's finding omp::MapInfoOp's containing
@@ -108,7 +138,9 @@ class OMPDescriptorMapInfoGenPass
108138
if (fir::isTypeWithDescriptor(op.getVarType()) ||
109139
mlir::isa<fir::BoxAddrOp>(op.getVarPtr().getDefiningOp())) {
110140
builder.setInsertionPoint(op);
111-
genDescriptorMemberMaps(op, builder);
141+
// Currently a MapInfoOp argument can only show up on a single target
142+
// user so we can retrieve and use the first user.
143+
genDescriptorMemberMaps(op, builder, *op->getUsers().begin());
112144
}
113145
});
114146
}

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,8 @@ def MapInfoOp : OpenMP_Op<"map_info", [AttrSizedOperandSegments]> {
12321232
// 2.14.2 target data Construct
12331233
//===---------------------------------------------------------------------===//
12341234

1235-
def Target_DataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments]>{
1235+
def Target_DataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments,
1236+
MapClauseOwningOpInterface]>{
12361237
let summary = "target data construct";
12371238
let description = [{
12381239
Map variables to a device data environment for the extent of the region.
@@ -1289,7 +1290,8 @@ def Target_DataOp: OpenMP_Op<"target_data", [AttrSizedOperandSegments]>{
12891290
//===---------------------------------------------------------------------===//
12901291

12911292
def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
1292-
[AttrSizedOperandSegments]>{
1293+
[AttrSizedOperandSegments,
1294+
MapClauseOwningOpInterface]>{
12931295
let summary = "target enter data construct";
12941296
let description = [{
12951297
The target enter data directive specifies that variables are mapped to
@@ -1335,7 +1337,8 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
13351337
//===---------------------------------------------------------------------===//
13361338

13371339
def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
1338-
[AttrSizedOperandSegments]>{
1340+
[AttrSizedOperandSegments,
1341+
MapClauseOwningOpInterface]>{
13391342
let summary = "target exit data construct";
13401343
let description = [{
13411344
The target exit data directive specifies that variables are mapped to a
@@ -1381,7 +1384,8 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
13811384
//===---------------------------------------------------------------------===//
13821385

13831386
def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
1384-
[AttrSizedOperandSegments]>{
1387+
[AttrSizedOperandSegments,
1388+
MapClauseOwningOpInterface]>{
13851389
let summary = "target update data construct";
13861390
let description = [{
13871391
The target update directive makes the corresponding list items in the device
@@ -1413,13 +1417,13 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
14131417
let arguments = (ins Optional<I1>:$if_expr,
14141418
Optional<AnyInteger>:$device,
14151419
UnitAttr:$nowait,
1416-
Variadic<OpenMP_PointerLikeType>:$motion_operands);
1420+
Variadic<OpenMP_PointerLikeType>:$map_operands);
14171421

14181422
let assemblyFormat = [{
14191423
oilist(`if` `(` $if_expr `:` type($if_expr) `)`
14201424
| `device` `(` $device `:` type($device) `)`
14211425
| `nowait` $nowait
1422-
| `motion_entries` `(` $motion_operands `:` type($motion_operands) `)`
1426+
| `motion_entries` `(` $map_operands `:` type($map_operands) `)`
14231427
) attr-dict
14241428
}];
14251429

@@ -1430,7 +1434,8 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
14301434
// 2.14.5 target construct
14311435
//===----------------------------------------------------------------------===//
14321436

1433-
def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, OutlineableOpenMPOpInterface, AttrSizedOperandSegments]> {
1437+
def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface,
1438+
OutlineableOpenMPOpInterface, AttrSizedOperandSegments]> {
14341439
let summary = "target construct";
14351440
let description = [{
14361441
The target construct includes a region of code which is to be executed

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ include "mlir/IR/OpBase.td"
1818
def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
1919
let description = [{
2020
OpenMP operations whose region will be outlined will implement this
21-
interface. These operations will
21+
interface.
2222
}];
2323

2424
let cppNamespace = "::mlir::omp";
@@ -31,6 +31,28 @@ def OutlineableOpenMPOpInterface : OpInterface<"OutlineableOpenMPOpInterface"> {
3131
];
3232
}
3333

34+
def MapClauseOwningOpInterface : OpInterface<"MapClauseOwningOpInterface"> {
35+
let description = [{
36+
OpenMP operations which own a list of omp::MapInfoOp's implement this interface
37+
to allow generic access to deal with map operands to more easily manipulate
38+
this class of operations.
39+
}];
40+
41+
let cppNamespace = "::mlir::omp";
42+
43+
let methods = [
44+
InterfaceMethod<"Get map operands", "::mlir::OperandRange", "getMapOperands",
45+
(ins), [{
46+
return $_op.getMapOperands();
47+
}]>,
48+
InterfaceMethod<"Get mutable map operands", "::mlir::MutableOperandRange",
49+
"getMapOperandsMutable",
50+
(ins), [{
51+
return $_op.getMapOperandsMutable();
52+
}]>,
53+
];
54+
}
55+
3456
def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
3557
let description = [{
3658
OpenMP operations that support reduction clause have this interface.

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ LogicalResult ExitDataOp::verify() {
980980
}
981981

982982
LogicalResult UpdateDataOp::verify() {
983-
return verifyMapClause(*this, getMotionOperands());
983+
return verifyMapClause(*this, getMapOperands());
984984
}
985985

986986
LogicalResult TargetOp::verify() {

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 39 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
16401640
// value) more than neccessary.
16411641
struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
16421642
llvm::SmallVector<bool, 4> IsDeclareTarget;
1643+
llvm::SmallVector<bool, 4> IsAMember;
16431644
llvm::SmallVector<mlir::Operation *, 4> MapClause;
16441645
llvm::SmallVector<llvm::Value *, 4> OriginalValue;
16451646
// Stripped off array/pointer to get the underlying
@@ -1728,13 +1729,11 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
17281729
LLVM::ModuleTranslation &moduleTranslation,
17291730
DataLayout &dl,
17301731
llvm::IRBuilderBase &builder) {
1731-
auto mapFill = [&](mlir::Value mapValue) {
1732-
assert(mlir::isa<mlir::omp::MapInfoOp>(mapValue.getDefiningOp()) &&
1733-
"missing map info operation or incorrect map info operation type");
1732+
for (mlir::Value mapValue : mapOperands) {
17341733
if (auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
17351734
mapValue.getDefiningOp())) {
1736-
mapData.OriginalValue.push_back(moduleTranslation.lookupValue(
1737-
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr()));
1735+
mapData.OriginalValue.push_back(
1736+
moduleTranslation.lookupValue(mapOp.getVarPtr()));
17381737
mapData.Pointers.push_back(mapData.OriginalValue.back());
17391738

17401739
if (llvm::Value *refPtr =
@@ -1759,41 +1758,27 @@ void collectMapDataFromMapOperands(MapInfoData &mapData,
17591758
mapOp.getLoc(), *moduleTranslation.getOpenMPBuilder()));
17601759
mapData.DevicePointers.push_back(
17611760
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
1762-
}
1763-
};
17641761

1765-
// In the case of Fortran descriptors some members get added implicitly
1766-
// after the target region has been generated during CodeGen lowering
1767-
// which prevents them from being added trivially to the target region
1768-
// as map arguments, we must handle this case here by generating
1769-
// MapInfoData for them.
1770-
// TODO: Revisit this when implementing derived types explicit member
1771-
// mapping, we likely want to represent these identically to simplify
1772-
// the overall lowering
1773-
SmallVector<Value> mapMemberOperands;
1774-
for (size_t i = 0; i < mapOperands.size(); ++i) {
1775-
auto mapInfoOp =
1776-
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOperands[i].getDefiningOp());
1777-
for (auto members : mapInfoOp.getMembers()) {
1778-
if (!std::any_of(mapOperands.begin(), mapOperands.end(),
1779-
[&](auto mapOp) {
1780-
return mapOp.getDefiningOp() ==
1781-
members.getDefiningOp();
1782-
}) &&
1783-
!std::any_of(mapMemberOperands.begin(), mapMemberOperands.end(),
1784-
[&](auto mapOp) {
1785-
return mapOp.getDefiningOp() ==
1786-
members.getDefiningOp();
1787-
}))
1788-
mapMemberOperands.push_back(members);
1762+
// Check if this is a member mapping and correctly assign that it is, if
1763+
// it is a member of a larger object.
1764+
// TODO: Need better handling of members, and distinguishing of members
1765+
// that are implicitly allocated on device vs explicitly passed in as
1766+
// arguments.
1767+
// TODO: May require some further additions to support nested record
1768+
// types, i.e. member maps that can have member maps.
1769+
mapData.IsAMember.push_back(false);
1770+
for (mlir::Value mapValue : mapOperands) {
1771+
if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
1772+
mapValue.getDefiningOp())) {
1773+
for (auto member : map.getMembers()) {
1774+
if (member == mapOp) {
1775+
mapData.IsAMember.back() = true;
1776+
}
1777+
}
1778+
}
1779+
}
17891780
}
17901781
}
1791-
1792-
for (mlir::Value mapValue : mapOperands)
1793-
mapFill(mapValue);
1794-
1795-
for (mlir::Value mapValue : mapMemberOperands)
1796-
mapFill(mapValue);
17971782
}
17981783

17991784
static void processMapWithMembersOf(
@@ -1960,35 +1945,22 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
19601945
combinedInfo.Names.clear();
19611946
};
19621947

1963-
llvm::SmallVector<size_t, 4> primaryMapIdx;
1964-
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
1965-
primaryMapIdx.push_back(i);
1966-
}
1967-
1968-
// TODO: Handle nested MembersOf, currently only cares about the first level
1969-
// of nesting (all that was relevant for Fortran descriptors), but a slight
1970-
// refactoring of mapInfoData to hold nestings or membersOf may be a better
1971-
// approach to simplify things.
1972-
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
1973-
auto mapInfoOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
1974-
for (auto member : mapInfoOp.getMembers()) {
1975-
for (size_t j = 0; j < primaryMapIdx.size(); j++) {
1976-
if (member.getDefiningOp() == mapData.MapClause[primaryMapIdx[j]]) {
1977-
primaryMapIdx.erase(&primaryMapIdx[j]);
1978-
j--;
1979-
}
1980-
}
1981-
}
1982-
}
1983-
19841948
// We operate under the assumption that all vectors that are
19851949
// required in MapInfoData are of equal lengths (either filled with
19861950
// default constructed data or appropiate information) so we can
19871951
// utilise the size from any component of MapInfoData, if we can't
19881952
// something is missing from the initial MapInfoData construction.
1989-
for (unsigned long i : primaryMapIdx) {
1990-
auto mapInfoOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
1953+
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
1954+
// NOTE/TODO: We currently do not handle member mapping seperately from it's
1955+
// parent or explicit mapping of a parent and member in the same operation,
1956+
// this will need to change in the near future, for now we primarily handle
1957+
// descriptor mapping from fortran, generalised as mapping record types
1958+
// with implicit member maps. This lowering needs further generalisation to
1959+
// fully support fortran derived types, and C/C++ structures and classes.
1960+
if (mapData.IsAMember[i])
1961+
continue;
19911962

1963+
auto mapInfoOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
19921964
if (!mapInfoOp.getMembers().empty()) {
19931965
processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
19941966
combinedInfo, mapData, i, isTargetParams);
@@ -2136,7 +2108,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
21362108
deviceID = intAttr.getInt();
21372109

21382110
RTLFn = llvm::omp::OMPRTL___tgt_target_data_update_mapper;
2139-
mapOperands = updateDataOp.getMotionOperands();
2111+
mapOperands = updateDataOp.getMapOperands();
21402112
return success();
21412113
})
21422114
.Default([&](Operation *op) {
@@ -2633,7 +2605,12 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
26332605
llvm::SmallVector<llvm::Value *, 4> kernelInput;
26342606
for (size_t i = 0; i < mapOperands.size(); ++i) {
26352607
// declare target arguments are not passed to kernels as arguments
2636-
if (!mapData.IsDeclareTarget[i])
2608+
// TODO: We currently do not handle cases where a member is explicitly
2609+
// passed in as an argument, this will likley need to be handled in
2610+
// the near future, rather than using IsAMember, it may be better to
2611+
// test if the relevant BlockArg is used within the target region and
2612+
// then use that as a basis for exclusion in the kernel inputs.
2613+
if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
26372614
kernelInput.push_back(mapData.OriginalValue[i]);
26382615
}
26392616

0 commit comments

Comments
 (0)