Skip to content

Commit 35309db

Browse files
committed
[OpenMP][OMPIRBuilder] Migrate MapCombinedInfoTy from Clang to OpenMPIRBuilder
This patch migrates the MapCombinedInfoTy from Clang codegen to OpenMPIRBuilder. Differential Revision: https://reviews.llvm.org/D149666
1 parent 147a561 commit 35309db

File tree

2 files changed

+78
-60
lines changed

2 files changed

+78
-60
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 35 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -6831,67 +6831,30 @@ class MappableExprsHandler {
68316831
const Expr *getMapExpr() const { return MapExpr; }
68326832
};
68336833

6834-
/// Class that associates information with a base pointer to be passed to the
6835-
/// runtime library.
6836-
class BasePointerInfo {
6837-
/// The base pointer.
6838-
llvm::Value *Ptr = nullptr;
6839-
/// The base declaration that refers to this device pointer, or null if
6840-
/// there is none.
6841-
const ValueDecl *DevPtrDecl = nullptr;
6842-
6843-
public:
6844-
BasePointerInfo(llvm::Value *Ptr, const ValueDecl *DevPtrDecl = nullptr)
6845-
: Ptr(Ptr), DevPtrDecl(DevPtrDecl) {}
6846-
llvm::Value *operator*() const { return Ptr; }
6847-
const ValueDecl *getDevicePtrDecl() const { return DevPtrDecl; }
6848-
void setDevicePtrDecl(const ValueDecl *D) { DevPtrDecl = D; }
6849-
};
6850-
6834+
using MapBaseValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
6835+
using MapValuesArrayTy = llvm::OpenMPIRBuilder::MapValuesArrayTy;
6836+
using MapFlagsArrayTy = llvm::OpenMPIRBuilder::MapFlagsArrayTy;
6837+
using MapDimArrayTy = llvm::OpenMPIRBuilder::MapDimArrayTy;
6838+
using MapNonContiguousArrayTy =
6839+
llvm::OpenMPIRBuilder::MapNonContiguousArrayTy;
68516840
using MapExprsArrayTy = SmallVector<MappingExprInfo, 4>;
6852-
using MapBaseValuesArrayTy = SmallVector<BasePointerInfo, 4>;
6853-
using MapValuesArrayTy = SmallVector<llvm::Value *, 4>;
6854-
using MapFlagsArrayTy = SmallVector<OpenMPOffloadMappingFlags, 4>;
6855-
using MapMappersArrayTy = SmallVector<const ValueDecl *, 4>;
6856-
using MapDimArrayTy = SmallVector<uint64_t, 4>;
6857-
using MapNonContiguousArrayTy = SmallVector<MapValuesArrayTy, 4>;
6841+
using MapValueDeclsArrayTy = SmallVector<const ValueDecl *, 4>;
68586842

68596843
/// This structure contains combined information generated for mappable
68606844
/// clauses, including base pointers, pointers, sizes, map types, user-defined
68616845
/// mappers, and non-contiguous information.
6862-
struct MapCombinedInfoTy {
6863-
struct StructNonContiguousInfo {
6864-
bool IsNonContiguous = false;
6865-
MapDimArrayTy Dims;
6866-
MapNonContiguousArrayTy Offsets;
6867-
MapNonContiguousArrayTy Counts;
6868-
MapNonContiguousArrayTy Strides;
6869-
};
6846+
struct MapCombinedInfoTy : llvm::OpenMPIRBuilder::MapInfosTy {
68706847
MapExprsArrayTy Exprs;
6871-
MapBaseValuesArrayTy BasePointers;
6872-
MapValuesArrayTy Pointers;
6873-
MapValuesArrayTy Sizes;
6874-
MapFlagsArrayTy Types;
6875-
MapMappersArrayTy Mappers;
6876-
StructNonContiguousInfo NonContigInfo;
6848+
MapValueDeclsArrayTy Mappers;
6849+
MapValueDeclsArrayTy DevicePtrDecls;
68776850

68786851
/// Append arrays in \a CurInfo.
68796852
void append(MapCombinedInfoTy &CurInfo) {
68806853
Exprs.append(CurInfo.Exprs.begin(), CurInfo.Exprs.end());
6881-
BasePointers.append(CurInfo.BasePointers.begin(),
6882-
CurInfo.BasePointers.end());
6883-
Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end());
6884-
Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end());
6885-
Types.append(CurInfo.Types.begin(), CurInfo.Types.end());
6854+
DevicePtrDecls.append(CurInfo.DevicePtrDecls.begin(),
6855+
CurInfo.DevicePtrDecls.end());
68866856
Mappers.append(CurInfo.Mappers.begin(), CurInfo.Mappers.end());
6887-
NonContigInfo.Dims.append(CurInfo.NonContigInfo.Dims.begin(),
6888-
CurInfo.NonContigInfo.Dims.end());
6889-
NonContigInfo.Offsets.append(CurInfo.NonContigInfo.Offsets.begin(),
6890-
CurInfo.NonContigInfo.Offsets.end());
6891-
NonContigInfo.Counts.append(CurInfo.NonContigInfo.Counts.begin(),
6892-
CurInfo.NonContigInfo.Counts.end());
6893-
NonContigInfo.Strides.append(CurInfo.NonContigInfo.Strides.begin(),
6894-
CurInfo.NonContigInfo.Strides.end());
6857+
llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo);
68956858
}
68966859
};
68976860

@@ -7638,6 +7601,7 @@ class MappableExprsHandler {
76387601
assert(Size && "Failed to determine structure size");
76397602
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
76407603
CombinedInfo.BasePointers.push_back(BP.getPointer());
7604+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
76417605
CombinedInfo.Pointers.push_back(LB.getPointer());
76427606
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
76437607
Size, CGF.Int64Ty, /*isSigned=*/true));
@@ -7649,6 +7613,7 @@ class MappableExprsHandler {
76497613
}
76507614
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
76517615
CombinedInfo.BasePointers.push_back(BP.getPointer());
7616+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
76527617
CombinedInfo.Pointers.push_back(LB.getPointer());
76537618
Size = CGF.Builder.CreatePtrDiff(
76547619
CGF.Int8Ty, CGF.Builder.CreateConstGEP(HB, 1).getPointer(),
@@ -7666,6 +7631,7 @@ class MappableExprsHandler {
76667631
(Next == CE && MapType != OMPC_MAP_unknown)) {
76677632
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
76687633
CombinedInfo.BasePointers.push_back(BP.getPointer());
7634+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
76697635
CombinedInfo.Pointers.push_back(LB.getPointer());
76707636
CombinedInfo.Sizes.push_back(
76717637
CGF.Builder.CreateIntCast(Size, CGF.Int64Ty, /*isSigned=*/true));
@@ -8168,7 +8134,8 @@ class MappableExprsHandler {
81688134
[&UseDeviceDataCombinedInfo](const ValueDecl *VD, llvm::Value *Ptr,
81698135
CodeGenFunction &CGF) {
81708136
UseDeviceDataCombinedInfo.Exprs.push_back(VD);
8171-
UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr, VD);
8137+
UseDeviceDataCombinedInfo.BasePointers.emplace_back(Ptr);
8138+
UseDeviceDataCombinedInfo.DevicePtrDecls.emplace_back(VD);
81728139
UseDeviceDataCombinedInfo.Pointers.push_back(Ptr);
81738140
UseDeviceDataCombinedInfo.Sizes.push_back(
81748141
llvm::Constant::getNullValue(CGF.Int64Ty));
@@ -8337,8 +8304,7 @@ class MappableExprsHandler {
83378304
assert(RelevantVD &&
83388305
"No relevant declaration related with device pointer??");
83398306

8340-
CurInfo.BasePointers[CurrentBasePointersIdx].setDevicePtrDecl(
8341-
RelevantVD);
8307+
CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD;
83428308
CurInfo.Types[CurrentBasePointersIdx] |=
83438309
OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
83448310
}
@@ -8377,7 +8343,8 @@ class MappableExprsHandler {
83778343
OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
83788344
}
83798345
CurInfo.Exprs.push_back(L.VD);
8380-
CurInfo.BasePointers.emplace_back(BasePtr, L.VD);
8346+
CurInfo.BasePointers.emplace_back(BasePtr);
8347+
CurInfo.DevicePtrDecls.emplace_back(L.VD);
83818348
CurInfo.Pointers.push_back(Ptr);
83828349
CurInfo.Sizes.push_back(
83838350
llvm::Constant::getNullValue(this->CGF.Int64Ty));
@@ -8472,6 +8439,7 @@ class MappableExprsHandler {
84728439
CombinedInfo.Exprs.push_back(VD);
84738440
// Base is the base of the struct
84748441
CombinedInfo.BasePointers.push_back(PartialStruct.Base.getPointer());
8442+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
84758443
// Pointer is the address of the lowest element
84768444
llvm::Value *LB = LBAddr.getPointer();
84778445
const CXXMethodDecl *MD =
@@ -8593,6 +8561,7 @@ class MappableExprsHandler {
85938561
VDLVal.getPointer(CGF));
85948562
CombinedInfo.Exprs.push_back(VD);
85958563
CombinedInfo.BasePointers.push_back(ThisLVal.getPointer(CGF));
8564+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
85968565
CombinedInfo.Pointers.push_back(ThisLValVal.getPointer(CGF));
85978566
CombinedInfo.Sizes.push_back(
85988567
CGF.Builder.CreateIntCast(CGF.getTypeSize(CGF.getContext().VoidPtrTy),
@@ -8619,6 +8588,7 @@ class MappableExprsHandler {
86198588
VDLVal.getPointer(CGF));
86208589
CombinedInfo.Exprs.push_back(VD);
86218590
CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
8591+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
86228592
CombinedInfo.Pointers.push_back(VarLValVal.getPointer(CGF));
86238593
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
86248594
CGF.getTypeSize(
@@ -8630,6 +8600,7 @@ class MappableExprsHandler {
86308600
VDLVal.getPointer(CGF));
86318601
CombinedInfo.Exprs.push_back(VD);
86328602
CombinedInfo.BasePointers.push_back(VarLVal.getPointer(CGF));
8603+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
86338604
CombinedInfo.Pointers.push_back(VarRVal.getScalarVal());
86348605
CombinedInfo.Sizes.push_back(llvm::ConstantInt::get(CGF.Int64Ty, 0));
86358606
}
@@ -8654,7 +8625,7 @@ class MappableExprsHandler {
86548625
OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF |
86558626
OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT))
86568627
continue;
8657-
llvm::Value *BasePtr = LambdaPointers.lookup(*BasePointers[I]);
8628+
llvm::Value *BasePtr = LambdaPointers.lookup(BasePointers[I]);
86588629
assert(BasePtr && "Unable to find base lambda address.");
86598630
int TgtIdx = -1;
86608631
for (unsigned J = I; J > 0; --J) {
@@ -8696,7 +8667,8 @@ class MappableExprsHandler {
86968667
// pass its value.
86978668
if (VD && (DevPointersMap.count(VD) || HasDevAddrsMap.count(VD))) {
86988669
CombinedInfo.Exprs.push_back(VD);
8699-
CombinedInfo.BasePointers.emplace_back(Arg, VD);
8670+
CombinedInfo.BasePointers.emplace_back(Arg);
8671+
CombinedInfo.DevicePtrDecls.emplace_back(VD);
87008672
CombinedInfo.Pointers.push_back(Arg);
87018673
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
87028674
CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty,
@@ -8938,6 +8910,7 @@ class MappableExprsHandler {
89388910
if (CI.capturesThis()) {
89398911
CombinedInfo.Exprs.push_back(nullptr);
89408912
CombinedInfo.BasePointers.push_back(CV);
8913+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
89418914
CombinedInfo.Pointers.push_back(CV);
89428915
const auto *PtrTy = cast<PointerType>(RI.getType().getTypePtr());
89438916
CombinedInfo.Sizes.push_back(
@@ -8950,6 +8923,7 @@ class MappableExprsHandler {
89508923
const VarDecl *VD = CI.getCapturedVar();
89518924
CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
89528925
CombinedInfo.BasePointers.push_back(CV);
8926+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
89538927
CombinedInfo.Pointers.push_back(CV);
89548928
if (!RI.getType()->isAnyPointerType()) {
89558929
// We have to signal to the runtime captures passed by value that are
@@ -8981,6 +8955,7 @@ class MappableExprsHandler {
89818955
auto I = FirstPrivateDecls.find(VD);
89828956
CombinedInfo.Exprs.push_back(VD->getCanonicalDecl());
89838957
CombinedInfo.BasePointers.push_back(CV);
8958+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
89848959
if (I != FirstPrivateDecls.end() && ElementType->isAnyPointerType()) {
89858960
Address PtrAddr = CGF.EmitLoadOfReference(CGF.MakeAddrLValue(
89868961
CV, ElementType, CGF.getContext().getDeclAlign(VD),
@@ -9266,7 +9241,7 @@ static void emitOffloadingArrays(
92669241
}
92679242

92689243
for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
9269-
llvm::Value *BPVal = *CombinedInfo.BasePointers[I];
9244+
llvm::Value *BPVal = CombinedInfo.BasePointers[I];
92709245
llvm::Value *BP = CGF.Builder.CreateConstInBoundsGEP2_32(
92719246
llvm::ArrayType::get(CGM.VoidPtrTy, Info.NumberOfPtrs),
92729247
Info.RTArgs.BasePointersArray, 0, I);
@@ -9277,8 +9252,7 @@ static void emitOffloadingArrays(
92779252
CGF.Builder.CreateStore(BPVal, BPAddr);
92789253

92799254
if (Info.requiresDevicePointerInfo())
9280-
if (const ValueDecl *DevVD =
9281-
CombinedInfo.BasePointers[I].getDevicePtrDecl())
9255+
if (const ValueDecl *DevVD = CombinedInfo.DevicePtrDecls[I])
92829256
Info.CaptureDeviceAddrMap.try_emplace(DevVD, BPAddr);
92839257

92849258
llvm::Value *PVal = CombinedInfo.Pointers[I];
@@ -9592,7 +9566,7 @@ void CGOpenMPRuntime::emitUserDefinedMapper(const OMPDeclareMapperDecl *D,
95929566
// Fill up the runtime mapper handle for all components.
95939567
for (unsigned I = 0; I < Info.BasePointers.size(); ++I) {
95949568
llvm::Value *CurBaseArg = MapperCGF.Builder.CreateBitCast(
9595-
*Info.BasePointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy));
9569+
Info.BasePointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy));
95969570
llvm::Value *CurBeginArg = MapperCGF.Builder.CreateBitCast(
95979571
Info.Pointers[I], CGM.getTypes().ConvertTypeForMem(C.VoidPtrTy));
95989572
llvm::Value *CurSizeArg = Info.Sizes[I];
@@ -10028,6 +10002,7 @@ void CGOpenMPRuntime::emitTargetCall(
1002810002
if (CI->capturesVariableArrayType()) {
1002910003
CurInfo.Exprs.push_back(nullptr);
1003010004
CurInfo.BasePointers.push_back(*CV);
10005+
CurInfo.DevicePtrDecls.push_back(nullptr);
1003110006
CurInfo.Pointers.push_back(*CV);
1003210007
CurInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
1003310008
CGF.getTypeSize(RI->getType()), CGF.Int64Ty, /*isSigned=*/true));

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,6 +1445,49 @@ class OpenMPIRBuilder {
14451445
bool separateBeginEndCalls() { return SeparateBeginEndCalls; }
14461446
};
14471447

1448+
using MapValuesArrayTy = SmallVector<Value *, 4>;
1449+
using MapFlagsArrayTy = SmallVector<omp::OpenMPOffloadMappingFlags, 4>;
1450+
using MapNamesArrayTy = SmallVector<Constant *, 4>;
1451+
using MapDimArrayTy = SmallVector<uint64_t, 4>;
1452+
using MapNonContiguousArrayTy = SmallVector<MapValuesArrayTy, 4>;
1453+
1454+
/// This structure contains combined information generated for mappable
1455+
/// clauses, including base pointers, pointers, sizes, map types, user-defined
1456+
/// mappers, and non-contiguous information.
1457+
struct MapInfosTy {
1458+
struct StructNonContiguousInfo {
1459+
bool IsNonContiguous = false;
1460+
MapDimArrayTy Dims;
1461+
MapNonContiguousArrayTy Offsets;
1462+
MapNonContiguousArrayTy Counts;
1463+
MapNonContiguousArrayTy Strides;
1464+
};
1465+
MapValuesArrayTy BasePointers;
1466+
MapValuesArrayTy Pointers;
1467+
MapValuesArrayTy Sizes;
1468+
MapFlagsArrayTy Types;
1469+
MapNamesArrayTy Names;
1470+
StructNonContiguousInfo NonContigInfo;
1471+
1472+
/// Append arrays in \a CurInfo.
1473+
void append(MapInfosTy &CurInfo) {
1474+
BasePointers.append(CurInfo.BasePointers.begin(),
1475+
CurInfo.BasePointers.end());
1476+
Pointers.append(CurInfo.Pointers.begin(), CurInfo.Pointers.end());
1477+
Sizes.append(CurInfo.Sizes.begin(), CurInfo.Sizes.end());
1478+
Types.append(CurInfo.Types.begin(), CurInfo.Types.end());
1479+
Names.append(CurInfo.Names.begin(), CurInfo.Names.end());
1480+
NonContigInfo.Dims.append(CurInfo.NonContigInfo.Dims.begin(),
1481+
CurInfo.NonContigInfo.Dims.end());
1482+
NonContigInfo.Offsets.append(CurInfo.NonContigInfo.Offsets.begin(),
1483+
CurInfo.NonContigInfo.Offsets.end());
1484+
NonContigInfo.Counts.append(CurInfo.NonContigInfo.Counts.begin(),
1485+
CurInfo.NonContigInfo.Counts.end());
1486+
NonContigInfo.Strides.append(CurInfo.NonContigInfo.Strides.begin(),
1487+
CurInfo.NonContigInfo.Strides.end());
1488+
}
1489+
};
1490+
14481491
/// Emit the arguments to be passed to the runtime library based on the
14491492
/// arrays of base pointers, pointers, sizes, map types, and mappers. If
14501493
/// ForEndCall, emit map types to be passed for the end of the region instead

0 commit comments

Comments
 (0)