Skip to content

[Clang][OpenMP] Fix mapping of structs to device #75642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 115 additions & 33 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6811,8 +6811,10 @@ class MappableExprsHandler {
OpenMPMapClauseKind MapType, ArrayRef<OpenMPMapModifierKind> MapModifiers,
ArrayRef<OpenMPMotionModifierKind> MotionModifiers,
OMPClauseMappableExprCommon::MappableExprComponentListRef Components,
MapCombinedInfoTy &CombinedInfo, StructRangeInfoTy &PartialStruct,
bool IsFirstComponentList, bool IsImplicit,
MapCombinedInfoTy &CombinedInfo,
MapCombinedInfoTy &StructBaseCombinedInfo,
StructRangeInfoTy &PartialStruct, bool IsFirstComponentList,
bool IsImplicit, bool GenerateAllInfoForClauses,
const ValueDecl *Mapper = nullptr, bool ForDeviceAddr = false,
const ValueDecl *BaseDecl = nullptr, const Expr *MapExpr = nullptr,
ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
Expand Down Expand Up @@ -7098,6 +7100,25 @@ class MappableExprsHandler {
bool IsNonContiguous = CombinedInfo.NonContigInfo.IsNonContiguous;
bool IsPrevMemberReference = false;

// Scan the remaining components for a MemberExpr (ME). If none is found,
// the whole struct is being mapped. In that case, and only when generating
// all info for clauses, the struct's entry is added to the
// StructBaseCombinedInfo list instead of the CombinedInfo list.
bool IsMappingWholeStruct = true;
if (!GenerateAllInfoForClauses) {
IsMappingWholeStruct = false;
} else {
for (auto TempI = I; TempI != CE; ++TempI) {
const MemberExpr *PossibleME =
dyn_cast<MemberExpr>(TempI->getAssociatedExpression());
if (PossibleME) {
IsMappingWholeStruct = false;
break;
}
}
}

for (; I != CE; ++I) {
// If the current component is member of a struct (parent struct) mark it.
if (!EncounteredME) {
Expand Down Expand Up @@ -7317,21 +7338,41 @@ class MappableExprsHandler {
break;
}
llvm::Value *Size = getExprTypeSize(I->getAssociatedExpression());
// If the whole struct is being mapped, do not add this entry to
// CombinedInfo; add it to StructBaseCombinedInfo instead. The struct's
// entry must come first, before entries for any data internal to the
// struct.
if (!IsMemberPointerOrAddr ||
(Next == CE && MapType != OMPC_MAP_unknown)) {
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
CombinedInfo.BasePointers.push_back(BP.getPointer());
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
CombinedInfo.Pointers.push_back(LB.getPointer());
CombinedInfo.Sizes.push_back(
CGF.Builder.CreateIntCast(Size, CGF.Int64Ty, /*isSigned=*/true));
CombinedInfo.NonContigInfo.Dims.push_back(IsNonContiguous ? DimSize
: 1);
if (!IsMappingWholeStruct) {
CombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
CombinedInfo.BasePointers.push_back(BP.getPointer());
CombinedInfo.DevicePtrDecls.push_back(nullptr);
CombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
CombinedInfo.Pointers.push_back(LB.getPointer());
CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
Size, CGF.Int64Ty, /*isSigned=*/true));
CombinedInfo.NonContigInfo.Dims.push_back(IsNonContiguous ? DimSize
: 1);
} else {
StructBaseCombinedInfo.Exprs.emplace_back(MapDecl, MapExpr);
StructBaseCombinedInfo.BasePointers.push_back(BP.getPointer());
StructBaseCombinedInfo.DevicePtrDecls.push_back(nullptr);
StructBaseCombinedInfo.DevicePointers.push_back(DeviceInfoTy::None);
StructBaseCombinedInfo.Pointers.push_back(LB.getPointer());
StructBaseCombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
Size, CGF.Int64Ty, /*isSigned=*/true));
StructBaseCombinedInfo.NonContigInfo.Dims.push_back(
IsNonContiguous ? DimSize : 1);
}

// If Mapper is valid, the last component inherits the mapper.
bool HasMapper = Mapper && Next == CE;
CombinedInfo.Mappers.push_back(HasMapper ? Mapper : nullptr);
if (!IsMappingWholeStruct)
CombinedInfo.Mappers.push_back(HasMapper ? Mapper : nullptr);
else
StructBaseCombinedInfo.Mappers.push_back(HasMapper ? Mapper
: nullptr);

// We need to add a pointer flag for each map that comes from the
// same expression except for the first one. We also need to signal
Expand Down Expand Up @@ -7363,7 +7404,10 @@ class MappableExprsHandler {
}
}

CombinedInfo.Types.push_back(Flags);
if (!IsMappingWholeStruct)
CombinedInfo.Types.push_back(Flags);
else
StructBaseCombinedInfo.Types.push_back(Flags);
}

// If we have encountered a member expression so far, keep track of the
Expand Down Expand Up @@ -7954,8 +7998,10 @@ class MappableExprsHandler {

for (const auto &Data : Info) {
StructRangeInfoTy PartialStruct;
// Temporary generated information.
// Current struct information:
MapCombinedInfoTy CurInfo;
// Current struct base information:
MapCombinedInfoTy StructBaseCurInfo;
const Decl *D = Data.first;
const ValueDecl *VD = cast_or_null<ValueDecl>(D);
for (const auto &M : Data.second) {
Expand All @@ -7965,29 +8011,55 @@ class MappableExprsHandler {

// Remember the current base pointer index.
unsigned CurrentBasePointersIdx = CurInfo.BasePointers.size();
unsigned StructBasePointersIdx =
StructBaseCurInfo.BasePointers.size();
CurInfo.NonContigInfo.IsNonContiguous =
L.Components.back().isNonContiguous();
generateInfoForComponentList(
L.MapType, L.MapModifiers, L.MotionModifiers, L.Components,
CurInfo, PartialStruct, /*IsFirstComponentList=*/false,
L.IsImplicit, L.Mapper, L.ForDeviceAddr, VD, L.VarRef);
CurInfo, StructBaseCurInfo, PartialStruct,
/*IsFirstComponentList=*/false, L.IsImplicit,
/*GenerateAllInfoForClauses*/ true, L.Mapper, L.ForDeviceAddr, VD,
L.VarRef);

// If this entry relates with a device pointer, set the relevant
// If this entry relates to a device pointer, set the relevant
// declaration and add the 'return pointer' flag.
if (L.ReturnDevicePointer) {
assert(CurInfo.BasePointers.size() > CurrentBasePointersIdx &&
// Check whether a value was added to either CurInfo or
// StructBaseCurInfo and error if no value was added to either of
// them:
assert((CurrentBasePointersIdx < CurInfo.BasePointers.size() ||
StructBasePointersIdx <
StructBaseCurInfo.BasePointers.size()) &&
"Unexpected number of mapped base pointers.");

// Choose a base pointer index which is always valid:
const ValueDecl *RelevantVD =
L.Components.back().getAssociatedDeclaration();
assert(RelevantVD &&
"No relevant declaration related with device pointer??");

CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD;
CurInfo.DevicePointers[CurrentBasePointersIdx] =
L.ForDeviceAddr ? DeviceInfoTy::Address : DeviceInfoTy::Pointer;
CurInfo.Types[CurrentBasePointersIdx] |=
OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
// If StructBaseCurInfo has been updated this iteration then work on
// the first new entry added to it i.e. make sure that when multiple
// values are added to any of the lists, the first value added is
// being modified by the assignments below (not the last value
// added).
if (StructBasePointersIdx < StructBaseCurInfo.BasePointers.size()) {
StructBaseCurInfo.DevicePtrDecls[StructBasePointersIdx] =
RelevantVD;
StructBaseCurInfo.DevicePointers[StructBasePointersIdx] =
L.ForDeviceAddr ? DeviceInfoTy::Address
: DeviceInfoTy::Pointer;
StructBaseCurInfo.Types[StructBasePointersIdx] |=
OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
} else {
CurInfo.DevicePtrDecls[CurrentBasePointersIdx] = RelevantVD;
CurInfo.DevicePointers[CurrentBasePointersIdx] =
L.ForDeviceAddr ? DeviceInfoTy::Address
: DeviceInfoTy::Pointer;
CurInfo.Types[CurrentBasePointersIdx] |=
OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
}
}
}
}
Expand Down Expand Up @@ -8034,17 +8106,24 @@ class MappableExprsHandler {
CurInfo.Mappers.push_back(nullptr);
}
}

// Unify entries in one list making sure the struct mapping precedes the
// individual fields:
MapCombinedInfoTy UnionCurInfo;
UnionCurInfo.append(StructBaseCurInfo);
UnionCurInfo.append(CurInfo);

// If there is an entry in PartialStruct it means we have a struct with
// individual members mapped. Emit an extra combined entry.
if (PartialStruct.Base.isValid()) {
CurInfo.NonContigInfo.Dims.push_back(0);
emitCombinedEntry(CombinedInfo, CurInfo.Types, PartialStruct,
UnionCurInfo.NonContigInfo.Dims.push_back(0);
// Emit a combined entry:
emitCombinedEntry(CombinedInfo, UnionCurInfo.Types, PartialStruct,
/*IsMapThis*/ !VD, OMPBuilder, VD);
}

// We need to append the results of this capture to what we already
// have.
CombinedInfo.append(CurInfo);
// We need to append the results of this capture to what we already have.
CombinedInfo.append(UnionCurInfo);
}
// Append data for use_device_ptr clauses.
CombinedInfo.append(UseDeviceDataCombinedInfo);
Expand Down Expand Up @@ -8554,6 +8633,7 @@ class MappableExprsHandler {
// Associated with a capture, because the mapping flags depend on it.
// Go through all of the elements with the overlapped elements.
bool IsFirstComponentList = true;
MapCombinedInfoTy StructBaseCombinedInfo;
for (const auto &Pair : OverlappedData) {
const MapData &L = *Pair.getFirst();
OMPClauseMappableExprCommon::MappableExprComponentListRef Components;
Expand All @@ -8568,7 +8648,8 @@ class MappableExprsHandler {
OverlappedComponents = Pair.getSecond();
generateInfoForComponentList(
MapType, MapModifiers, std::nullopt, Components, CombinedInfo,
PartialStruct, IsFirstComponentList, IsImplicit, Mapper,
StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
/*ForDeviceAddr=*/false, VD, VarRef, OverlappedComponents);
IsFirstComponentList = false;
}
Expand All @@ -8584,10 +8665,11 @@ class MappableExprsHandler {
L;
auto It = OverlappedData.find(&L);
if (It == OverlappedData.end())
generateInfoForComponentList(MapType, MapModifiers, std::nullopt,
Components, CombinedInfo, PartialStruct,
IsFirstComponentList, IsImplicit, Mapper,
/*ForDeviceAddr=*/false, VD, VarRef);
generateInfoForComponentList(
MapType, MapModifiers, std::nullopt, Components, CombinedInfo,
StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
/*ForDeviceAddr=*/false, VD, VarRef);
IsFirstComponentList = false;
}
}
Expand Down
Loading