Skip to content

Commit 9f34645

Browse files
committed
Address reviewer comments.
1 parent 547b339 commit 9f34645

File tree

1 file changed

+46
-50
lines changed

1 file changed

+46
-50
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 46 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -2101,6 +2101,7 @@ getRefPtrIfDeclareTarget(mlir::Value value,
21012101
return nullptr;
21022102
}
21032103

2104+
namespace {
21042105
// A small helper structure to contain data gathered
21052106
// for map lowering and coalese it into one area and
21062107
// avoiding extra computations such as searches in the
@@ -2129,6 +2130,7 @@ struct MapInfoData : llvm::OpenMPIRBuilder::MapInfosTy {
21292130
llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
21302131
}
21312132
};
2133+
} // namespace
21322134

21332135
uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
21342136
if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
@@ -2195,16 +2197,15 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
21952197
return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
21962198
}
21972199

2198-
void collectMapDataFromMapOperands(
2199-
MapInfoData &mapData, llvm::SmallVectorImpl<Value> &mapVars,
2200+
static void collectMapDataFromMapOperands(
2201+
MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
22002202
LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
2201-
llvm::IRBuilderBase &builder,
2202-
const llvm::ArrayRef<Value> &useDevPtrOperands = {},
2203-
const llvm::ArrayRef<Value> &useDevAddrOperands = {}) {
2203+
llvm::IRBuilderBase &builder, const ArrayRef<Value> &useDevPtrOperands = {},
2204+
const ArrayRef<Value> &useDevAddrOperands = {}) {
22042205
// Process MapOperands
2205-
for (mlir::Value mapValue : mapVars) {
2206-
auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
2207-
mlir::Value offloadPtr =
2206+
for (Value mapValue : mapVars) {
2207+
auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
2208+
Value offloadPtr =
22082209
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
22092210
mapData.OriginalValue.push_back(moduleTranslation.lookupValue(offloadPtr));
22102211
mapData.Pointers.push_back(mapData.OriginalValue.back());
@@ -2240,9 +2241,9 @@ void collectMapDataFromMapOperands(
22402241
// TODO: May require some further additions to support nested record
22412242
// types, i.e. member maps that can have member maps.
22422243
mapData.IsAMember.push_back(false);
2243-
for (mlir::Value mapValue : mapVars) {
2244-
if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2245-
mapValue.getDefiningOp())) {
2244+
for (Value mapValue : mapVars) {
2245+
if (auto map =
2246+
dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp())) {
22462247
for (auto member : map.getMembers()) {
22472248
if (member == mapOp) {
22482249
mapData.IsAMember.back() = true;
@@ -2271,9 +2272,9 @@ void collectMapDataFromMapOperands(
22712272
// Process useDevPtr(Addr)Operands
22722273
auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
22732274
llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
2274-
for (mlir::Value mapValue : useDevOperands) {
2275-
auto mapOp = mlir::cast<mlir::omp::MapInfoOp>(mapValue.getDefiningOp());
2276-
mlir::Value offloadPtr =
2275+
for (Value mapValue : useDevOperands) {
2276+
auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
2277+
Value offloadPtr =
22772278
mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
22782279
llvm::Value *origValue = moduleTranslation.lookupValue(offloadPtr);
22792280

@@ -2302,9 +2303,9 @@ void collectMapDataFromMapOperands(
23022303
// TODO: May require some further additions to support nested record
23032304
// types, i.e. member maps that can have member maps.
23042305
mapData.IsAMember.push_back(false);
2305-
for (mlir::Value mapValue : useDevOperands)
2306-
if (auto map = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(
2307-
mapValue.getDefiningOp()))
2306+
for (Value mapValue : useDevOperands)
2307+
if (auto map =
2308+
dyn_cast_if_present<omp::MapInfoOp>(mapValue.getDefiningOp()))
23082309
for (auto member : map.getMembers())
23092310
if (member == mapOp)
23102311
mapData.IsAMember.back() = true;
@@ -2316,22 +2317,21 @@ void collectMapDataFromMapOperands(
23162317
addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
23172318
}
23182319

2319-
static int getMapDataMemberIdx(MapInfoData &mapData,
2320-
mlir::omp::MapInfoOp memberOp) {
2320+
static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
23212321
auto *res = llvm::find(mapData.MapClause, memberOp);
23222322
assert(res != mapData.MapClause.end() &&
23232323
"MapInfoOp for member not found in MapData, cannot return index");
23242324
return std::distance(mapData.MapClause.begin(), res);
23252325
}
23262326

2327-
static mlir::omp::MapInfoOp
2328-
getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
2329-
mlir::DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();
2327+
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
2328+
bool first) {
2329+
DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();
23302330

23312331
// Only 1 member has been mapped, we can return it.
23322332
if (indexAttr.size() == 1)
2333-
if (auto mapOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(
2334-
mapInfo.getMembers()[0].getDefiningOp()))
2333+
if (auto mapOp =
2334+
dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()))
23352335
return mapOp;
23362336

23372337
llvm::ArrayRef<int64_t> shape = indexAttr.getShapedType().getShape();
@@ -2368,7 +2368,7 @@ getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
23682368
return false;
23692369
});
23702370

2371-
return llvm::cast<mlir::omp::MapInfoOp>(
2371+
return llvm::cast<omp::MapInfoOp>(
23722372
mapInfo.getMembers()[indices.front()].getDefiningOp());
23732373
}
23742374

@@ -2394,7 +2394,7 @@ getFirstOrLastMappedMemberPtr(mlir::omp::MapInfoOp mapInfo, bool first) {
23942394
std::vector<llvm::Value *>
23952395
calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
23962396
llvm::IRBuilderBase &builder, bool isArrayTy,
2397-
mlir::OperandRange bounds) {
2397+
OperandRange bounds) {
23982398
std::vector<llvm::Value *> idx;
23992399
// There's no bounds to calculate an offset from, we can safely
24002400
// ignore and return no indices.
@@ -2408,7 +2408,7 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
24082408
if (isArrayTy) {
24092409
idx.push_back(builder.getInt64(0));
24102410
for (int i = bounds.size() - 1; i >= 0; --i) {
2411-
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2411+
if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
24122412
bounds[i].getDefiningOp())) {
24132413
idx.push_back(moduleTranslation.lookupValue(boundOp.getLowerBound()));
24142414
}
@@ -2434,7 +2434,7 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
24342434
// (extent/size of current) 100 for 1000 for each index increment
24352435
std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(1)};
24362436
for (size_t i = 1; i < bounds.size(); ++i) {
2437-
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2437+
if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
24382438
bounds[i].getDefiningOp())) {
24392439
dimensionIndexSizeOffset.push_back(builder.CreateMul(
24402440
moduleTranslation.lookupValue(boundOp.getExtent()),
@@ -2447,7 +2447,7 @@ calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
24472447
// have calculated in the previous and accumulate the results to get
24482448
// our final resulting offset.
24492449
for (int i = bounds.size() - 1; i >= 0; --i) {
2450-
if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
2450+
if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
24512451
bounds[i].getDefiningOp())) {
24522452
if (idx.empty())
24532453
idx.emplace_back(builder.CreateMul(
@@ -2504,7 +2504,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
25042504
// data by the descriptor (which itself, is a structure containing
25052505
// runtime information on the dynamically allocated data).
25062506
auto parentClause =
2507-
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2507+
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
25082508

25092509
llvm::Value *lowAddr, *highAddr;
25102510
if (!parentClause.getPartialMap()) {
@@ -2516,8 +2516,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
25162516
builder.getPtrTy());
25172517
combinedInfo.Pointers.emplace_back(mapData.Pointers[mapDataIndex]);
25182518
} else {
2519-
auto mapOp =
2520-
mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2519+
auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
25212520
int firstMemberIdx = getMapDataMemberIdx(
25222521
mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
25232522
lowAddr = builder.CreatePointerCast(mapData.Pointers[firstMemberIdx],
@@ -2575,7 +2574,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
25752574
// There may be a better way to verify this, but unfortunately with
25762575
// opaque pointers we lose the ability to easily check if something is
25772576
// a pointer whilst maintaining access to the underlying type.
2578-
static bool checkIfPointerMap(mlir::omp::MapInfoOp mapOp) {
2577+
static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
25792578
// If we have a varPtrPtr field assigned then the underlying type is a pointer
25802579
if (mapOp.getVarPtrPtr())
25812580
return true;
@@ -2597,11 +2596,11 @@ static void processMapMembersWithParent(
25972596
uint64_t mapDataIndex, llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
25982597

25992598
auto parentClause =
2600-
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2599+
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
26012600

26022601
for (auto mappedMembers : parentClause.getMembers()) {
26032602
auto memberClause =
2604-
llvm::cast<mlir::omp::MapInfoOp>(mappedMembers.getDefiningOp());
2603+
llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
26052604
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
26062605

26072606
assert(memberDataIdx >= 0 && "could not find mapped member of structure");
@@ -2635,8 +2634,7 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
26352634
// OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
26362635
// marked with OMP_MAP_PTR_AND_OBJ instead.
26372636
auto mapFlag = mapData.Types[mapDataIdx];
2638-
auto mapInfoOp =
2639-
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
2637+
auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
26402638

26412639
bool isPtrTy = checkIfPointerMap(mapInfoOp);
26422640
if (isPtrTy)
@@ -2646,7 +2644,7 @@ processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
26462644
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
26472645

26482646
if (mapInfoOp.getMapCaptureType().value() ==
2649-
mlir::omp::VariableCaptureKind::ByCopy &&
2647+
omp::VariableCaptureKind::ByCopy &&
26502648
!isPtrTy)
26512649
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
26522650

@@ -2672,13 +2670,13 @@ static void processMapWithMembersOf(
26722670
llvm::OpenMPIRBuilder::MapInfosTy &combinedInfo, MapInfoData &mapData,
26732671
uint64_t mapDataIndex, bool isTargetParams) {
26742672
auto parentClause =
2675-
llvm::cast<mlir::omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
2673+
llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
26762674

26772675
// If we have a partial map (no parent referenced in the map clauses of the
26782676
// directive, only members) and only a single member, we do not need to bind
26792677
// the map of the member to the parent, we can pass the member separately.
26802678
if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
2681-
auto memberClause = llvm::cast<mlir::omp::MapInfoOp>(
2679+
auto memberClause = llvm::cast<omp::MapInfoOp>(
26822680
parentClause.getMembers()[0].getDefiningOp());
26832681
int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
26842682
// Note: Clang treats arrays with explicit bounds that fall into this
@@ -2715,11 +2713,9 @@ createAlteredByCaptureMap(MapInfoData &mapData,
27152713
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
27162714
// if it's declare target, skip it, it's handled separately.
27172715
if (!mapData.IsDeclareTarget[i]) {
2718-
auto mapOp =
2719-
mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
2720-
mlir::omp::VariableCaptureKind captureKind =
2721-
mapOp.getMapCaptureType().value_or(
2722-
mlir::omp::VariableCaptureKind::ByRef);
2716+
auto mapOp = dyn_cast_if_present<omp::MapInfoOp>(mapData.MapClause[i]);
2717+
omp::VariableCaptureKind captureKind =
2718+
mapOp.getMapCaptureType().value_or(omp::VariableCaptureKind::ByRef);
27232719
bool isPtrTy = checkIfPointerMap(mapOp);
27242720

27252721
// Currently handles array sectioning lowerbound case, but more
@@ -2730,7 +2726,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
27302726
// function mimics some of the logic from Clang that we require for
27312727
// kernel argument passing from host -> device.
27322728
switch (captureKind) {
2733-
case mlir::omp::VariableCaptureKind::ByRef: {
2729+
case omp::VariableCaptureKind::ByRef: {
27342730
llvm::Value *newV = mapData.Pointers[i];
27352731
std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
27362732
moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
@@ -2743,7 +2739,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
27432739
"array_offset");
27442740
mapData.Pointers[i] = newV;
27452741
} break;
2746-
case mlir::omp::VariableCaptureKind::ByCopy: {
2742+
case omp::VariableCaptureKind::ByCopy: {
27472743
llvm::Type *type = mapData.BaseType[i];
27482744
llvm::Value *newV;
27492745
if (mapData.Pointers[i]->getType()->isPointerTy())
@@ -2765,8 +2761,8 @@ createAlteredByCaptureMap(MapInfoData &mapData,
27652761
mapData.Pointers[i] = newV;
27662762
mapData.BasePointers[i] = newV;
27672763
} break;
2768-
case mlir::omp::VariableCaptureKind::This:
2769-
case mlir::omp::VariableCaptureKind::VLAType:
2764+
case omp::VariableCaptureKind::This:
2765+
case omp::VariableCaptureKind::VLAType:
27702766
mapData.MapClause[i]->emitOpError("Unhandled capture kind");
27712767
break;
27722768
}
@@ -2807,7 +2803,7 @@ static void genMapInfos(llvm::IRBuilderBase &builder,
28072803
if (mapData.IsAMember[i])
28082804
continue;
28092805

2810-
auto mapInfoOp = mlir::dyn_cast<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
2806+
auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
28112807
if (!mapInfoOp.getMembers().empty()) {
28122808
processMapWithMembersOf(moduleTranslation, builder, *ompBuilder, dl,
28132809
combinedInfo, mapData, i, isTargetParams);

0 commit comments

Comments
 (0)