Skip to content

Commit 4ef6587

Browse files
authored
[Clang][OpenMP] Fix mapping of structs to device (#75642)
Fix mapping of structs to device. The following example fails: ``` #include <stdio.h> #include <stdlib.h> struct Descriptor { int *datum; long int x; int xi; long int arr[1][30]; }; int main() { Descriptor dat = Descriptor(); dat.datum = (int *)malloc(sizeof(int)*10); dat.xi = 3; dat.arr[0][0] = 1; #pragma omp target enter data map(to: dat.datum[:10]) map(to: dat) #pragma omp target { dat.xi = 4; dat.datum[dat.arr[0][0]] = dat.xi; } #pragma omp target exit data map(from: dat) return 0; } ``` This is a rework of the previous attempt: #72410
1 parent 58a2c4e commit 4ef6587

File tree

3 files changed

+401
-33
lines changed

3 files changed

+401
-33
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 115 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6811,8 +6811,10 @@ class MappableExprsHandler {
68116811
OpenMPMapClauseKind MapType, ArrayRef<OpenMPMapModifierKind> MapModifiers,
68126812
ArrayRef<OpenMPMotionModifierKind> MotionModifiers,
68136813
OMPClauseMappableExprCommon::MappableExprComponentListRef Components,
6814-
MapCombinedInfoTy &CombinedInfo, StructRangeInfoTy &PartialStruct,
6815-
bool IsFirstComponentList, bool IsImplicit,
6814+
MapCombinedInfoTy &CombinedInfo,
6815+
MapCombinedInfoTy &StructBaseCombinedInfo,
6816+
StructRangeInfoTy &PartialStruct, bool IsFirstComponentList,
6817+
bool IsImplicit, bool GenerateAllInfoForClauses,
68166818
const ValueDecl *Mapper = nullptr, bool ForDeviceAddr = false,
68176819
const ValueDecl *BaseDecl = nullptr, const Expr *MapExpr = nullptr,
68186820
ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
@@ -7098,6 +7100,25 @@ class MappableExprsHandler {
70987100
bool IsNonContiguous = CombinedInfo.NonContigInfo.IsNonContiguous;
70997101
bool IsPrevMemberReference = false;
71007102

7103+
// We need to check if we will be encountering any MEs. If we do not
7104+
// encounter any ME expression it means we will be mapping the whole struct.
7105+
// In that case we need to skip adding an entry for the struct to the
7106+
// CombinedInfo list and instead add an entry to the StructBaseCombinedInfo
7107+
// list only when generating all info for clauses.
7108+
bool IsMappingWholeStruct = true;
7109+
if (!GenerateAllInfoForClauses) {
7110+
IsMappingWholeStruct = false;
7111+
} else {
7112+
for (auto TempI = I; TempI != CE; ++TempI) {
7113+
const MemberExpr *PossibleME =
7114+
dyn_cast<MemberExpr>(TempI->getAssociatedExpression());
7115+
if (PossibleME) {
7116+
IsMappingWholeStruct = false;
7117+
break;
7118+
}
7119+
}
7120+
}
7121+
71017122
for (; I != CE; ++I) {
71027123
// If the current component is member of a struct (parent struct) mark it.
71037124
if (!EncounteredME) {
@@ -7317,21 +7338,41 @@ class MappableExprsHandler {
73177338
break;
73187339
}
73197340
llvm::Value *Size = getExprTypeSize(I->getAssociatedExpression());
7341+
// Skip adding an entry in the CurInfo of this combined entry if the
7342+
// whole struct is currently being mapped. The struct needs to be added
7343+
// in the first position before any data internal to the struct is being
7344+
// mapped.
73207345
if (!IsMemberPointerOrAddr ||
73217346
(Next == CE && MapType != OMPC_MAP_unknown)) {
7322-
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
7323-
CombinedInfo.BasePointers.push_back(BP.getPointer());
7324-
CombinedInfo.DevicePtrDecls.push_back(nullptr);
7325-
CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
7326-
CombinedInfo.Pointers.push_back(LB.getPointer());
7327-
CombinedInfo.Sizes.push_back(
7328-
CGF.Builder.CreateIntCast(Size, CGF.Int64Ty, /*isSigned=*/true));
7329-
CombinedInfo.NonContigInfo.Dims.push_back(IsNonContiguous ? DimSize
7330-
: 1);
7347+
if (!IsMappingWholeStruct) {
7348+
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
7349+
CombinedInfo.BasePointers.push_back(BP.getPointer());
7350+
CombinedInfo.DevicePtrDecls.push_back(nullptr);
7351+
CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
7352+
CombinedInfo.Pointers.push_back(LB.getPointer());
7353+
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
7354+
Size, CGF.Int64Ty, /*isSigned=*/true));
7355+
CombinedInfo.NonContigInfo.Dims.push_back(IsNonContiguous ? DimSize
7356+
: 1);
7357+
} else {
7358+
StructBaseCombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
7359+
StructBaseCombinedInfo.BasePointers.push_back(BP.getPointer());
7360+
StructBaseCombinedInfo.DevicePtrDecls.push_back(nullptr);
7361+
StructBaseCombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
7362+
StructBaseCombinedInfo.Pointers.push_back(LB.getPointer());
7363+
StructBaseCombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
7364+
Size, CGF.Int64Ty, /*isSigned=*/true));
7365+
StructBaseCombinedInfo.NonContigInfo.Dims.push_back(
7366+
IsNonContiguous ? DimSize : 1);
7367+
}
73317368

73327369
// If Mapper is valid, the last component inherits the mapper.
73337370
bool HasMapper = Mapper && Next == CE;
7334-
CombinedInfo.Mappers.push_back(HasMapper ? Mapper : nullptr);
7371+
if (!IsMappingWholeStruct)
7372+
CombinedInfo.Mappers.push_back(HasMapper ? Mapper : nullptr);
7373+
else
7374+
StructBaseCombinedInfo.Mappers.push_back(HasMapper ? Mapper
7375+
: nullptr);
73357376

73367377
// We need to add a pointer flag for each map that comes from the
73377378
// same expression except for the first one. We also need to signal
@@ -7363,7 +7404,10 @@ class MappableExprsHandler {
73637404
}
73647405
}
73657406

7366-
CombinedInfo.Types.push_back(Flags);
7407+
if (!IsMappingWholeStruct)
7408+
CombinedInfo.Types.push_back(Flags);
7409+
else
7410+
StructBaseCombinedInfo.Types.push_back(Flags);
73677411
}
73687412

73697413
// If we have encountered a member expression so far, keep track of the
@@ -7954,8 +7998,10 @@ class MappableExprsHandler {
79547998

79557999
for (const auto &Data : Info) {
79568000
StructRangeInfoTy PartialStruct;
7957-
// Temporary generated information.
8001+
// Current struct information:
79588002
MapCombinedInfoTy CurInfo;
8003+
// Current struct base information:
8004+
MapCombinedInfoTy StructBaseCurInfo;
79598005
const Decl *D = Data.first;
79608006
const ValueDecl *VD = cast_or_null<ValueDecl>(D);
79618007
for (const auto &M : Data.second) {
@@ -7965,29 +8011,55 @@ class MappableExprsHandler {
79658011

79668012
// Remember the current base pointer index.
79678013
unsigned CurrentBasePointersIdx = CurInfo.BasePointers.size();
8014+
unsigned StructBasePointersIdx =
8015+
StructBaseCurInfo.BasePointers.size();
79688016
CurInfo.NonContigInfo.IsNonContiguous =
79698017
L.Components.back().isNonContiguous();
79708018
generateInfoForComponentList(
79718019
L.MapType, L.MapModifiers, L.MotionModifiers, L.Components,
7972-
CurInfo, PartialStruct, /*IsFirstComponentList=*/false,
7973-
L.IsImplicit, L.Mapper, L.ForDeviceAddr, VD, L.VarRef);
8020+
CurInfo, StructBaseCurInfo, PartialStruct,
8021+
/*IsFirstComponentList=*/false, L.IsImplicit,
8022+
/*GenerateAllInfoForClauses*/ true, L.Mapper, L.ForDeviceAddr, VD,
8023+
L.VarRef);
79748024

7975-
// If this entry relates with a device pointer, set the relevant
8025+
// If this entry relates to a device pointer, set the relevant
79768026
// declaration and add the 'return pointer' flag.
79778027
if (L.ReturnDevicePointer) {
7978-
assert(CurInfo.BasePointers.size() > CurrentBasePointersIdx &&
8028+
// Check whether a value was added to either CurInfo or
8029+
// StructBaseCurInfo and error if no value was added to either of
8030+
// them:
8031+
assert((CurrentBasePointersIdx < CurInfo.BasePointers.size() ||
8032+
StructBasePointersIdx <
8033+
StructBaseCurInfo.BasePointers.size()) &&
79798034
"Unexpected number of mapped base pointers.");
79808035

8036+
// Choose a base pointer index which is always valid:
79818037
const ValueDecl *RelevantVD =
79828038
L.Components.back().getAssociatedDeclaration();
79838039
assert(RelevantVD &&
79848040
"No relevant declaration related with device pointer??");
79858041

7986-
CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD;
7987-
CurInfo.DevicePointers[CurrentBasePointersIdx] =
7988-
L.ForDeviceAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer;
7989-
CurInfo.Types[CurrentBasePointersIdx] |=
7990-
OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
8042+
// If StructBaseCurInfo has been updated this iteration then work on
8043+
// the first new entry added to it i.e. make sure that when multiple
8044+
// values are added to any of the lists, the first value added is
8045+
// being modified by the assignments below (not the last value
8046+
// added).
8047+
if (StructBasePointersIdx < StructBaseCurInfo.BasePointers.size()) {
8048+
StructBaseCurInfo.DevicePtrDecls[StructBasePointersIdx] =
8049+
RelevantVD;
8050+
StructBaseCurInfo.DevicePointers[StructBasePointersIdx] =
8051+
L.ForDeviceAddr ? DeviceInfoTy::Address
8052+
: DeviceInfoTy::Pointer;
8053+
StructBaseCurInfo.Types[StructBasePointersIdx] |=
8054+
OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
8055+
} else {
8056+
CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD;
8057+
CurInfo.DevicePointers[CurrentBasePointersIdx] =
8058+
L.ForDeviceAddr ? DeviceInfoTy::Address
8059+
: DeviceInfoTy::Pointer;
8060+
CurInfo.Types[CurrentBasePointersIdx] |=
8061+
OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
8062+
}
79918063
}
79928064
}
79938065
}
@@ -8034,17 +8106,24 @@ class MappableExprsHandler {
80348106
CurInfo.Mappers.push_back(nullptr);
80358107
}
80368108
}
8109+
8110+
// Unify entries in one list making sure the struct mapping precedes the
8111+
// individual fields:
8112+
MapCombinedInfoTy UnionCurInfo;
8113+
UnionCurInfo.append(StructBaseCurInfo);
8114+
UnionCurInfo.append(CurInfo);
8115+
80378116
// If there is an entry in PartialStruct it means we have a struct with
80388117
// individual members mapped. Emit an extra combined entry.
80398118
if (PartialStruct.Base.isValid()) {
8040-
CurInfo.NonContigInfo.Dims.push_back(0);
8041-
emitCombinedEntry(CombinedInfo, CurInfo.Types, PartialStruct,
8119+
UnionCurInfo.NonContigInfo.Dims.push_back(0);
8120+
// Emit a combined entry:
8121+
emitCombinedEntry(CombinedInfo, UnionCurInfo.Types, PartialStruct,
80428122
/*IsMapThis*/ !VD, OMPBuilder, VD);
80438123
}
80448124

8045-
// We need to append the results of this capture to what we already
8046-
// have.
8047-
CombinedInfo.append(CurInfo);
8125+
// We need to append the results of this capture to what we already have.
8126+
CombinedInfo.append(UnionCurInfo);
80488127
}
80498128
// Append data for use_device_ptr clauses.
80508129
CombinedInfo.append(UseDeviceDataCombinedInfo);
@@ -8554,6 +8633,7 @@ class MappableExprsHandler {
85548633
// Associated with a capture, because the mapping flags depend on it.
85558634
// Go through all of the elements with the overlapped elements.
85568635
bool IsFirstComponentList = true;
8636+
MapCombinedInfoTy StructBaseCombinedInfo;
85578637
for (const auto &Pair : OverlappedData) {
85588638
const MapData &L = *Pair.getFirst();
85598639
OMPClauseMappableExprCommon::MappableExprComponentListRef Components;
@@ -8568,7 +8648,8 @@ class MappableExprsHandler {
85688648
OverlappedComponents = Pair.getSecond();
85698649
generateInfoForComponentList(
85708650
MapType, MapModifiers, std::nullopt, Components, CombinedInfo,
8571-
PartialStruct, IsFirstComponentList, IsImplicit, Mapper,
8651+
StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
8652+
IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
85728653
/*ForDeviceAddr=*/false, VD, VarRef, OverlappedComponents);
85738654
IsFirstComponentList = false;
85748655
}
@@ -8584,10 +8665,11 @@ class MappableExprsHandler {
85848665
L;
85858666
auto It = OverlappedData.find(&L);
85868667
if (It == OverlappedData.end())
8587-
generateInfoForComponentList(MapType, MapModifiers, std::nullopt,
8588-
Components, CombinedInfo, PartialStruct,
8589-
IsFirstComponentList, IsImplicit, Mapper,
8590-
/*ForDeviceAddr=*/false, VD, VarRef);
8668+
generateInfoForComponentList(
8669+
MapType, MapModifiers, std::nullopt, Components, CombinedInfo,
8670+
StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
8671+
IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
8672+
/*ForDeviceAddr=*/false, VD, VarRef);
85918673
IsFirstComponentList = false;
85928674
}
85938675
}

0 commit comments

Comments
 (0)