Skip to content

Commit 113f7a9

Browse files
committed
[OpenMP][MLIR] Descriptor explicit member map lowering changes
This is one of 3 PRs in a PR stack that aims to add support for explicit mapping of allocatable members in derived types. The primary changes in this PR are the OpenMPToLLVMIRTranslation.cpp changes, which are small and seek to alter the current member mapping to add an additional map insertion for pointers. Effectively, if the member is a pointer (currently indicated by having a varPtrPtr field) we add an additional map for the pointer and then alter the subsequent mapping of the member (the data) to utilise the member's base pointer rather than the parent's base pointer. This appears to be necessary in certain cases when mapping pointer data within record types to avoid segfaulting on device (due to incorrect data mapping). In general this record type mapping may be simplifiable in the future. There are also new tests which should help to showcase the effect of the changes above.
1 parent 76edf72 commit 113f7a9

File tree

7 files changed

+197
-130
lines changed

7 files changed

+197
-130
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
895895
TypeAttr:$var_type,
896896
Optional<OpenMP_PointerLikeType>:$var_ptr_ptr,
897897
Variadic<OpenMP_PointerLikeType>:$members,
898-
OptionalAttr<AnyIntElementsAttr>:$members_index,
898+
OptionalAttr<IndexListArrayAttr>:$members_index,
899899
Variadic<OpenMP_MapBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
900900
OptionalAttr<UI64Attr>:$map_type,
901901
OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,16 +1395,15 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
13951395
}
13961396

13971397
static ParseResult parseMembersIndex(OpAsmParser &parser,
1398-
DenseIntElementsAttr &membersIdx) {
1399-
SmallVector<APInt> values;
1400-
int64_t value;
1401-
int64_t shape[2] = {0, 0};
1402-
unsigned shapeTmp = 0;
1398+
ArrayAttr &membersIdx) {
1399+
SmallVector<Attribute> values, memberIdxs;
1400+
14031401
auto parseIndices = [&]() -> ParseResult {
1402+
int64_t value;
14041403
if (parser.parseInteger(value))
14051404
return failure();
1406-
shapeTmp++;
1407-
values.push_back(APInt(32, value, /*isSigned=*/true));
1405+
values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1406+
APInt(64, value, /*isSigned=*/false)));
14081407
return success();
14091408
};
14101409

@@ -1418,52 +1417,29 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
14181417
if (failed(parser.parseRSquare()))
14191418
return failure();
14201419

1421-
// Only set once, if any indices are not the same size
1422-
// we error out in the next check as that's unsupported
1423-
if (shape[1] == 0)
1424-
shape[1] = shapeTmp;
1425-
1426-
// Verify that the recently parsed list is equal to the
1427-
// first one we parsed, they must be equal lengths to
1428-
// keep the rectangular shape DenseIntElementsAttr
1429-
// requires
1430-
if (shapeTmp != shape[1])
1431-
return failure();
1432-
1433-
shapeTmp = 0;
1434-
shape[0]++;
1420+
memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1421+
values.clear();
14351422
} while (succeeded(parser.parseOptionalComma()));
14361423

1437-
if (!values.empty()) {
1438-
ShapedType valueType =
1439-
VectorType::get(shape, IntegerType::get(parser.getContext(), 32));
1440-
membersIdx = DenseIntElementsAttr::get(valueType, values);
1441-
}
1424+
if (!memberIdxs.empty())
1425+
membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
14421426

14431427
return success();
14441428
}
14451429

14461430
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1447-
DenseIntElementsAttr membersIdx) {
1448-
llvm::ArrayRef<int64_t> shape = membersIdx.getShapedType().getShape();
1449-
assert(shape.size() <= 2);
1450-
1431+
ArrayAttr membersIdx) {
14511432
if (!membersIdx)
14521433
return;
14531434

1454-
for (int i = 0; i < shape[0]; ++i) {
1435+
llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
14551436
p << "[";
1456-
int rowOffset = i * shape[1];
1457-
for (int j = 0; j < shape[1]; ++j) {
1458-
p << membersIdx.getValues<int32_t>()[rowOffset + j];
1459-
if ((j + 1) < shape[1])
1460-
p << ",";
1461-
}
1437+
auto memberIdx = cast<ArrayAttr>(v);
1438+
llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1439+
p << cast<IntegerAttr>(v2).getInt();
1440+
});
14621441
p << "]";
1463-
1464-
if ((i + 1) < shape[0])
1465-
p << ", ";
1466-
}
1442+
});
14671443
}
14681444

14691445
static void printCaptureType(OpAsmPrinter &p, Operation *op,

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,46 +2386,36 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
23862386

23872387
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
23882388
bool first) {
2389-
DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();
2390-
2389+
ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
23912390
// Only 1 member has been mapped, we can return it.
23922391
if (indexAttr.size() == 1)
2393-
if (auto mapOp =
2394-
dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()))
2395-
return mapOp;
2392+
return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
23962393

2397-
llvm::ArrayRef<int64_t> shape = indexAttr.getShapedType().getShape();
2398-
llvm::SmallVector<size_t> indices(shape[0]);
2394+
llvm::SmallVector<size_t> indices(indexAttr.size());
23992395
std::iota(indices.begin(), indices.end(), 0);
24002396

24012397
llvm::sort(indices.begin(), indices.end(),
24022398
[&](const size_t a, const size_t b) {
2403-
auto indexValues = indexAttr.getValues<int32_t>();
2404-
for (int i = 0; i < shape[1]; ++i) {
2405-
int aIndex = indexValues[a * shape[1] + i];
2406-
int bIndex = indexValues[b * shape[1] + i];
2399+
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
2400+
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
2401+
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
2402+
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
2403+
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
24072404

24082405
if (aIndex == bIndex)
24092406
continue;
24102407

2411-
if (aIndex != -1 && bIndex == -1)
2412-
return false;
2413-
2414-
if (aIndex == -1 && bIndex != -1)
2415-
return true;
2416-
2417-
// A is earlier in the record type layout than B
24182408
if (aIndex < bIndex)
24192409
return first;
24202410

2421-
if (bIndex < aIndex)
2411+
if (aIndex > bIndex)
24222412
return !first;
24232413
}
24242414

2425-
// Iterated the entire list and couldn't make a decision, all
2426-
// elements were likely the same. Return false, since the sort
2427-
// comparator should return false for equal elements.
2428-
return false;
2415+
// Iterated up until the end of the smallest member and
2416+
// they were found to be equal up to that point, so select
2417+
// the member with the lowest index count, so the "parent"
2418+
return memberIndicesA.size() < memberIndicesB.size();
24292419
});
24302420

24312421
return llvm::cast<omp::MapInfoOp>(
@@ -2596,17 +2586,8 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
25962586
/*isSigned=*/false);
25972587
combinedInfo.Sizes.push_back(size);
25982588

2599-
// TODO: This will need to be expanded to include the whole host of logic for
2600-
// the map flags that Clang currently supports (e.g. it should take the map
2601-
// flag of the parent map flag, remove the OMP_MAP_TARGET_PARAM and do some
2602-
// further case specific flag modifications). For the moment, it handles what
2603-
// we support as expected.
2604-
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2605-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2606-
26072589
llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
26082590
ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
2609-
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
26102591

26112592
// This creates the initial MEMBER_OF mapping that consists of
26122593
// the parent/top level container (same as above effectively, except
@@ -2615,6 +2596,12 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
26152596
// only relevant if the structure in its totality is being mapped,
26162597
// otherwise the above suffices.
26172598
if (!parentClause.getPartialMap()) {
2599+
// TODO: This will need to be expanded to include the whole host of logic
2600+
// for the map flags that Clang currently supports (e.g. it should do some
2601+
// further case specific flag modifications). For the moment, it handles
2602+
// what we support as expected.
2603+
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
2604+
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
26182605
combinedInfo.Types.emplace_back(mapFlag);
26192606
combinedInfo.DevicePointers.emplace_back(
26202607
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
@@ -2665,6 +2652,31 @@ static void processMapMembersWithParent(
26652652

26662653
assert(memberDataIdx >= 0 && "could not find mapped member of structure");
26672654

2655+
// If we're currently mapping a pointer to a block of data, we must
2656+
// initially map the pointer, and then attach/bind the data with a
2657+
// subsequent map to the pointer. This segment of code generates the
2658+
// pointer mapping, which can in certain cases be optimised out as Clang
2659+
// currently does in its lowering. However, for the moment we do not do so,
2660+
// in part as we currently have substantially less information on the data
2661+
// being mapped at this stage.
2662+
if (checkIfPointerMap(memberClause)) {
2663+
auto mapFlag = llvm::omp::OpenMPOffloadMappingFlags(
2664+
memberClause.getMapType().value());
2665+
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2666+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
2667+
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
2668+
combinedInfo.Types.emplace_back(mapFlag);
2669+
combinedInfo.DevicePointers.emplace_back(
2670+
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2671+
combinedInfo.Names.emplace_back(
2672+
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2673+
combinedInfo.BasePointers.emplace_back(
2674+
mapData.BasePointers[mapDataIndex]);
2675+
combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
2676+
combinedInfo.Sizes.emplace_back(builder.getInt64(
2677+
moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
2678+
}
2679+
26682680
// Same MemberOfFlag to indicate its link with parent and other members
26692681
// of.
26702682
auto mapFlag =
@@ -2680,7 +2692,10 @@ static void processMapMembersWithParent(
26802692
mapData.DevicePointers[memberDataIdx]);
26812693
combinedInfo.Names.emplace_back(
26822694
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2683-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2695+
uint64_t basePointerIndex =
2696+
checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
2697+
combinedInfo.BasePointers.emplace_back(
2698+
mapData.BasePointers[basePointerIndex]);
26842699
combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
26852700
combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
26862701
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2633,8 +2633,8 @@ func.func @omp_map_with_members(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm
26332633
// CHECK: %[[MAP4:.*]] = omp.map.info var_ptr(%[[ARG4]] : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
26342634
%mapv5 = omp.map.info var_ptr(%arg4 : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
26352635

2636-
// CHECK: %[[MAP5:.*]] = omp.map.info var_ptr(%[[ARG5]] : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%[[MAP3]], %[[MAP4]] : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2637-
%mapv6 = omp.map.info var_ptr(%arg5 : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%mapv4, %mapv5 : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2636+
// CHECK: %[[MAP5:.*]] = omp.map.info var_ptr(%[[ARG5]] : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%[[MAP3]], %[[MAP4]] : [1, 0], [1, 1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2637+
%mapv6 = omp.map.info var_ptr(%arg5 : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%mapv4, %mapv5 : [1, 0], [1, 1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
26382638

26392639
// CHECK: omp.target_exit_data map_entries(%[[MAP3]], %[[MAP4]], %[[MAP5]] : !llvm.ptr, !llvm.ptr, !llvm.ptr)
26402640
omp.target_exit_data map_entries(%mapv4, %mapv5, %mapv6 : !llvm.ptr, !llvm.ptr, !llvm.ptr){}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// This test checks the offload sizes, map types and base pointers and pointers
4+
// provided to the OpenMP kernel argument structure are correct when lowering
5+
// to LLVM-IR from MLIR when performing explicit member mapping of a record type
6+
// that includes records with pointer members in various locations of the record
7+
// types hierarchy.
8+
9+
module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
10+
llvm.func @omp_nested_derived_type_alloca_map(%arg0: !llvm.ptr) {
11+
%0 = llvm.mlir.constant(4 : index) : i64
12+
%1 = llvm.mlir.constant(1 : index) : i64
13+
%2 = llvm.mlir.constant(2 : index) : i64
14+
%3 = llvm.mlir.constant(0 : index) : i64
15+
%4 = llvm.mlir.constant(6 : index) : i64
16+
%5 = omp.map.bounds lower_bound(%3 : i64) upper_bound(%0 : i64) extent(%0 : i64) stride(%1 : i64) start_idx(%3 : i64) {stride_in_bytes = true}
17+
%6 = llvm.getelementptr %arg0[0, 6] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, struct<(ptr, i64, i32, i8, i8, i8, i8)>, array<10 x i32>, f32, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32, struct<(f32, array<10 x i32>, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32)>)>
18+
%7 = llvm.getelementptr %6[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, array<10 x i32>, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32)>
19+
%8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
20+
%9 = omp.map.info var_ptr(%7 : !llvm.ptr, i32) var_ptr_ptr(%8 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) bounds(%5) -> !llvm.ptr {name = ""}
21+
%10 = omp.map.info var_ptr(%7 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "one_l%nest%array_k"}
22+
%11 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.struct<(f32, struct<(ptr, i64, i32, i8, i8, i8, i8)>, array<10 x i32>, f32, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32, struct<(f32, array<10 x i32>, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32)>)>) map_clauses(tofrom) capture(ByRef) members(%10, %9 : [6,2], [6,2,0] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "one_l", partial_map = true}
23+
omp.target map_entries(%10 -> %arg1, %9 -> %arg2, %11 -> %arg3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
24+
omp.terminator
25+
}
26+
llvm.return
27+
}
28+
}
29+
30+
// CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [4 x i64] [i64 0, i64 48, i64 8, i64 20]
31+
// CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710675]
32+
33+
// CHECK: define void @omp_nested_derived_type_alloca_map(ptr %[[ARG:.*]]) {
34+
35+
// CHECK: %[[NESTED_DTYPE_MEMBER_GEP:.*]] = getelementptr { float, { ptr, i64, i32, i8, i8, i8, i8 }, [10 x i32], float, { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i32, { float, [10 x i32], { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i32 } }, ptr %[[ARG]], i32 0, i32 6
36+
// CHECK: %[[NESTED_STRUCT_PTR_MEMBER_GEP:.*]] = getelementptr { float, [10 x i32], { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i32 }, ptr %[[NESTED_DTYPE_MEMBER_GEP]], i32 0, i32 2
37+
// CHECK: %[[NESTED_STRUCT_PTR_MEMBER_BADDR_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]], i32 0, i32 0
38+
// CHECK: %[[NESTED_STRUCT_PTR_MEMBER_BADDR_LOAD:.*]] = load ptr, ptr %[[NESTED_STRUCT_PTR_MEMBER_BADDR_GEP]], align 8
39+
// CHECK: %[[ARR_OFFSET:.*]] = getelementptr inbounds i32, ptr %[[NESTED_STRUCT_PTR_MEMBER_BADDR_LOAD]], i64 0
40+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_1:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]], i64 1
41+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_2:.*]] = ptrtoint ptr %[[DTYPE_SIZE_SEGMENT_CALC_1]] to i64
42+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_3:.*]] = ptrtoint ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]] to i64
43+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_4:.*]] = sub i64 %[[DTYPE_SIZE_SEGMENT_CALC_2]], %[[DTYPE_SIZE_SEGMENT_CALC_3]]
44+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_5:.*]] = sdiv exact i64 %[[DTYPE_SIZE_SEGMENT_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
45+
46+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
47+
// CHECK: store ptr %[[ARG]], ptr %[[BASE_PTRS]], align 8
48+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 0
49+
// CHECK: store ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]], ptr %[[OFFLOAD_PTRS]], align 8
50+
// CHECK: %[[OFFLOAD_SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 0
51+
// CHECK: store i64 %[[DTYPE_SIZE_SEGMENT_CALC_5]], ptr %[[OFFLOAD_SIZES]], align 8
52+
53+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
54+
// CHECK: store ptr %[[ARG]], ptr %[[BASE_PTRS]], align 8
55+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 1
56+
// CHECK: store ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]], ptr %[[OFFLOAD_PTRS]], align 8
57+
58+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
59+
// CHECK: store ptr %[[ARG]], ptr %[[BASE_PTRS]], align 8
60+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 2
61+
// CHECK: store ptr %[[NESTED_STRUCT_PTR_MEMBER_BADDR_GEP]], ptr %[[OFFLOAD_PTRS]], align 8
62+
63+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 3
64+
// CHECK: store ptr %[[NESTED_STRUCT_PTR_MEMBER_BADDR_GEP]], ptr %[[BASE_PTRS]], align 8
65+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 3
66+
// CHECK: store ptr %[[ARR_OFFSET]], ptr %[[OFFLOAD_PTRS]], align 8

mlir/test/Target/LLVMIR/omptarget-nested-record-type-mapping-host.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ llvm.func @_QQmain() {
2121
%9 = llvm.getelementptr %4[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, array<10 x i32>, struct<(f32, i32)>, i32)>
2222
%10 = omp.map.bounds lower_bound(%2 : i64) upper_bound(%1 : i64) extent(%0 : i64) stride(%2 : i64) start_idx(%2 : i64)
2323
%11 = omp.map.info var_ptr(%9 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%10) -> !llvm.ptr
24-
%12 = omp.map.info var_ptr(%4 : !llvm.ptr, !llvm.struct<(f32, array<10 x i32>, struct<(f32, i32)>, i32)>) map_clauses(tofrom) capture(ByRef) members(%6, %8, %11 : [3, -1], [2, 1], [1, -1] : !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr {partial_map = true}
24+
%12 = omp.map.info var_ptr(%4 : !llvm.ptr, !llvm.struct<(f32, array<10 x i32>, struct<(f32, i32)>, i32)>) map_clauses(tofrom) capture(ByRef) members(%6, %8, %11 : [3], [2, 1], [1] : !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr {partial_map = true}
2525
omp.target map_entries(%6 -> %arg0, %8 -> %arg1, %11 -> %arg2, %12 -> %arg3 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) {
2626
omp.terminator
2727
}

0 commit comments

Comments
 (0)