Skip to content

Commit b27db91

Browse files
committed
[OpenMP][MLIR] Descriptor explicit member map lowering changes
This is one of 3 PRs in a PR stack that aims to add support for explicit mapping of allocatable members in derived types. The primary changes in this PR are the OpenMPToLLVMIRTranslation.cpp changes, which are small and seek to alter the current member mapping to add an additional map insertion for pointers. Effectively, if the member is a pointer (currently indicated by having a varPtrPtr field), we add an additional map for the pointer and then alter the subsequent mapping of the member (the data) to utilise the member's base pointer rather than the parent's base pointer. This appears to be necessary in certain cases when mapping pointer data within record types to avoid segfaulting on device (due to incorrect data mapping). In general, this record type mapping may be simplifiable in the future. There are also additional tests which should help to showcase the effect of the changes above.
1 parent 5192cb7 commit b27db91

File tree

7 files changed

+197
-130
lines changed

7 files changed

+197
-130
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,7 +895,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
895895
TypeAttr:$var_type,
896896
Optional<OpenMP_PointerLikeType>:$var_ptr_ptr,
897897
Variadic<OpenMP_PointerLikeType>:$members,
898-
OptionalAttr<AnyIntElementsAttr>:$members_index,
898+
OptionalAttr<IndexListArrayAttr>:$members_index,
899899
Variadic<OpenMP_MapBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
900900
OptionalAttr<UI64Attr>:$map_type,
901901
OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,16 +1395,15 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
13951395
}
13961396

13971397
static ParseResult parseMembersIndex(OpAsmParser &parser,
1398-
DenseIntElementsAttr &membersIdx) {
1399-
SmallVector<APInt> values;
1400-
int64_t value;
1401-
int64_t shape[2] = {0, 0};
1402-
unsigned shapeTmp = 0;
1398+
ArrayAttr &membersIdx) {
1399+
SmallVector<Attribute> values, memberIdxs;
1400+
14031401
auto parseIndices = [&]() -> ParseResult {
1402+
int64_t value;
14041403
if (parser.parseInteger(value))
14051404
return failure();
1406-
shapeTmp++;
1407-
values.push_back(APInt(32, value, /*isSigned=*/true));
1405+
values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1406+
APInt(64, value, /*isSigned=*/false)));
14081407
return success();
14091408
};
14101409

@@ -1418,52 +1417,29 @@ static ParseResult parseMembersIndex(OpAsmParser &parser,
14181417
if (failed(parser.parseRSquare()))
14191418
return failure();
14201419

1421-
// Only set once, if any indices are not the same size
1422-
// we error out in the next check as that's unsupported
1423-
if (shape[1] == 0)
1424-
shape[1] = shapeTmp;
1425-
1426-
// Verify that the recently parsed list is equal to the
1427-
// first one we parsed, they must be equal lengths to
1428-
// keep the rectangular shape DenseIntElementsAttr
1429-
// requires
1430-
if (shapeTmp != shape[1])
1431-
return failure();
1432-
1433-
shapeTmp = 0;
1434-
shape[0]++;
1420+
memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1421+
values.clear();
14351422
} while (succeeded(parser.parseOptionalComma()));
14361423

1437-
if (!values.empty()) {
1438-
ShapedType valueType =
1439-
VectorType::get(shape, IntegerType::get(parser.getContext(), 32));
1440-
membersIdx = DenseIntElementsAttr::get(valueType, values);
1441-
}
1424+
if (!memberIdxs.empty())
1425+
membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
14421426

14431427
return success();
14441428
}
14451429

14461430
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1447-
DenseIntElementsAttr membersIdx) {
1448-
llvm::ArrayRef<int64_t> shape = membersIdx.getShapedType().getShape();
1449-
assert(shape.size() <= 2);
1450-
1431+
ArrayAttr membersIdx) {
14511432
if (!membersIdx)
14521433
return;
14531434

1454-
for (int i = 0; i < shape[0]; ++i) {
1435+
llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
14551436
p << "[";
1456-
int rowOffset = i * shape[1];
1457-
for (int j = 0; j < shape[1]; ++j) {
1458-
p << membersIdx.getValues<int32_t>()[rowOffset + j];
1459-
if ((j + 1) < shape[1])
1460-
p << ",";
1461-
}
1437+
auto memberIdx = cast<ArrayAttr>(v);
1438+
llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1439+
p << cast<IntegerAttr>(v2).getInt();
1440+
});
14621441
p << "]";
1463-
1464-
if ((i + 1) < shape[0])
1465-
p << ", ";
1466-
}
1442+
});
14671443
}
14681444

14691445
static void printCaptureType(OpAsmPrinter &p, Operation *op,

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2530,46 +2530,36 @@ static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
25302530

25312531
static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
25322532
bool first) {
2533-
DenseIntElementsAttr indexAttr = mapInfo.getMembersIndexAttr();
2534-
2533+
ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
25352534
// Only 1 member has been mapped, we can return it.
25362535
if (indexAttr.size() == 1)
2537-
if (auto mapOp =
2538-
dyn_cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp()))
2539-
return mapOp;
2536+
return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
25402537

2541-
llvm::ArrayRef<int64_t> shape = indexAttr.getShapedType().getShape();
2542-
llvm::SmallVector<size_t> indices(shape[0]);
2538+
llvm::SmallVector<size_t> indices(indexAttr.size());
25432539
std::iota(indices.begin(), indices.end(), 0);
25442540

25452541
llvm::sort(indices.begin(), indices.end(),
25462542
[&](const size_t a, const size_t b) {
2547-
auto indexValues = indexAttr.getValues<int32_t>();
2548-
for (int i = 0; i < shape[1]; ++i) {
2549-
int aIndex = indexValues[a * shape[1] + i];
2550-
int bIndex = indexValues[b * shape[1] + i];
2543+
auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
2544+
auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
2545+
for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
2546+
int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
2547+
int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
25512548

25522549
if (aIndex == bIndex)
25532550
continue;
25542551

2555-
if (aIndex != -1 && bIndex == -1)
2556-
return false;
2557-
2558-
if (aIndex == -1 && bIndex != -1)
2559-
return true;
2560-
2561-
// A is earlier in the record type layout than B
25622552
if (aIndex < bIndex)
25632553
return first;
25642554

2565-
if (bIndex < aIndex)
2555+
if (aIndex > bIndex)
25662556
return !first;
25672557
}
25682558

2569-
// Iterated the entire list and couldn't make a decision, all
2570-
// elements were likely the same. Return false, since the sort
2571-
// comparator should return false for equal elements.
2572-
return false;
2559+
// Iterated up until the end of the smallest member and
2560+
// they were found to be equal up to that point, so select
2561+
// the member with the lowest index count, so the "parent"
2562+
return memberIndicesA.size() < memberIndicesB.size();
25732563
});
25742564

25752565
return llvm::cast<omp::MapInfoOp>(
@@ -2740,17 +2730,8 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
27402730
/*isSigned=*/false);
27412731
combinedInfo.Sizes.push_back(size);
27422732

2743-
// TODO: This will need to be expanded to include the whole host of logic for
2744-
// the map flags that Clang currently supports (e.g. it should take the map
2745-
// flag of the parent map flag, remove the OMP_MAP_TARGET_PARAM and do some
2746-
// further case specific flag modifications). For the moment, it handles what
2747-
// we support as expected.
2748-
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
2749-
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
2750-
27512733
llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
27522734
ompBuilder.getMemberOfFlag(combinedInfo.BasePointers.size() - 1);
2753-
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
27542735

27552736
// This creates the initial MEMBER_OF mapping that consists of
27562737
// the parent/top level container (same as above effectively, except
@@ -2759,6 +2740,12 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
27592740
// only relevant if the structure in its totality is being mapped,
27602741
// otherwise the above suffices.
27612742
if (!parentClause.getPartialMap()) {
2743+
// TODO: This will need to be expanded to include the whole host of logic
2744+
// for the map flags that Clang currently supports (e.g. it should do some
2745+
// further case specific flag modifications). For the moment, it handles
2746+
// what we support as expected.
2747+
llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
2748+
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
27622749
combinedInfo.Types.emplace_back(mapFlag);
27632750
combinedInfo.DevicePointers.emplace_back(
27642751
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
@@ -2809,6 +2796,31 @@ static void processMapMembersWithParent(
28092796

28102797
assert(memberDataIdx >= 0 && "could not find mapped member of structure");
28112798

2799+
// If we're currently mapping a pointer to a block of data, we must
2800+
// initially map the pointer, and then attach/bind the data with a
2801+
// subsequent map to the pointer. This segment of code generates the
2802+
// pointer mapping, which can in certain cases be optimised out as Clang
2803+
// currently does in its lowering. However, for the moment we do not do so,
2804+
// in part as we currently have substantially less information on the data
2805+
// being mapped at this stage.
2806+
if (checkIfPointerMap(memberClause)) {
2807+
auto mapFlag = llvm::omp::OpenMPOffloadMappingFlags(
2808+
memberClause.getMapType().value());
2809+
mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
2810+
mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
2811+
ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
2812+
combinedInfo.Types.emplace_back(mapFlag);
2813+
combinedInfo.DevicePointers.emplace_back(
2814+
llvm::OpenMPIRBuilder::DeviceInfoTy::None);
2815+
combinedInfo.Names.emplace_back(
2816+
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2817+
combinedInfo.BasePointers.emplace_back(
2818+
mapData.BasePointers[mapDataIndex]);
2819+
combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
2820+
combinedInfo.Sizes.emplace_back(builder.getInt64(
2821+
moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
2822+
}
2823+
28122824
// Same MemberOfFlag to indicate its link with parent and other members
28132825
// of.
28142826
auto mapFlag =
@@ -2824,7 +2836,10 @@ static void processMapMembersWithParent(
28242836
mapData.DevicePointers[memberDataIdx]);
28252837
combinedInfo.Names.emplace_back(
28262838
LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
2827-
combinedInfo.BasePointers.emplace_back(mapData.BasePointers[mapDataIndex]);
2839+
uint64_t basePointerIndex =
2840+
checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
2841+
combinedInfo.BasePointers.emplace_back(
2842+
mapData.BasePointers[basePointerIndex]);
28282843
combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
28292844
combinedInfo.Sizes.emplace_back(mapData.Sizes[memberDataIdx]);
28302845
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2633,8 +2633,8 @@ func.func @omp_map_with_members(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm
26332633
// CHECK: %[[MAP4:.*]] = omp.map.info var_ptr(%[[ARG4]] : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
26342634
%mapv5 = omp.map.info var_ptr(%arg4 : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
26352635

2636-
// CHECK: %[[MAP5:.*]] = omp.map.info var_ptr(%[[ARG5]] : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%[[MAP3]], %[[MAP4]] : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2637-
%mapv6 = omp.map.info var_ptr(%arg5 : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%mapv4, %mapv5 : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2636+
// CHECK: %[[MAP5:.*]] = omp.map.info var_ptr(%[[ARG5]] : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%[[MAP3]], %[[MAP4]] : [1, 0], [1, 1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2637+
%mapv6 = omp.map.info var_ptr(%arg5 : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%mapv4, %mapv5 : [1, 0], [1, 1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
26382638

26392639
// CHECK: omp.target_exit_data map_entries(%[[MAP3]], %[[MAP4]], %[[MAP5]] : !llvm.ptr, !llvm.ptr, !llvm.ptr)
26402640
omp.target_exit_data map_entries(%mapv4, %mapv5, %mapv6 : !llvm.ptr, !llvm.ptr, !llvm.ptr){}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// This test checks the offload sizes, map types and base pointers and pointers
4+
// provided to the OpenMP kernel argument structure are correct when lowering
5+
// to LLVM-IR from MLIR when performing explicit member mapping of a record type
6+
// that includes records with pointer members in various locations of the record
7+
// types hierarchy.
8+
9+
module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
10+
llvm.func @omp_nested_derived_type_alloca_map(%arg0: !llvm.ptr) {
11+
%0 = llvm.mlir.constant(4 : index) : i64
12+
%1 = llvm.mlir.constant(1 : index) : i64
13+
%2 = llvm.mlir.constant(2 : index) : i64
14+
%3 = llvm.mlir.constant(0 : index) : i64
15+
%4 = llvm.mlir.constant(6 : index) : i64
16+
%5 = omp.map.bounds lower_bound(%3 : i64) upper_bound(%0 : i64) extent(%0 : i64) stride(%1 : i64) start_idx(%3 : i64) {stride_in_bytes = true}
17+
%6 = llvm.getelementptr %arg0[0, 6] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, struct<(ptr, i64, i32, i8, i8, i8, i8)>, array<10 x i32>, f32, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32, struct<(f32, array<10 x i32>, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32)>)>
18+
%7 = llvm.getelementptr %6[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, array<10 x i32>, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32)>
19+
%8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>
20+
%9 = omp.map.info var_ptr(%7 : !llvm.ptr, i32) var_ptr_ptr(%8 : !llvm.ptr) map_clauses(tofrom) capture(ByRef) bounds(%5) -> !llvm.ptr {name = ""}
21+
%10 = omp.map.info var_ptr(%7 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = "one_l%nest%array_k"}
22+
%11 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.struct<(f32, struct<(ptr, i64, i32, i8, i8, i8, i8)>, array<10 x i32>, f32, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32, struct<(f32, array<10 x i32>, struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)>, i32)>)>) map_clauses(tofrom) capture(ByRef) members(%10, %9 : [6,2], [6,2,0] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "one_l", partial_map = true}
23+
omp.target map_entries(%10 -> %arg1, %9 -> %arg2, %11 -> %arg3 : !llvm.ptr, !llvm.ptr, !llvm.ptr) {
24+
omp.terminator
25+
}
26+
llvm.return
27+
}
28+
}
29+
30+
// CHECK: @.offload_sizes{{.*}} = private unnamed_addr constant [4 x i64] [i64 0, i64 48, i64 8, i64 20]
31+
// CHECK: @.offload_maptypes{{.*}} = private unnamed_addr constant [4 x i64] [i64 32, i64 281474976710659, i64 281474976710659, i64 281474976710675]
32+
33+
// CHECK: define void @omp_nested_derived_type_alloca_map(ptr %[[ARG:.*]]) {
34+
35+
// CHECK: %[[NESTED_DTYPE_MEMBER_GEP:.*]] = getelementptr { float, { ptr, i64, i32, i8, i8, i8, i8 }, [10 x i32], float, { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i32, { float, [10 x i32], { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i32 } }, ptr %[[ARG]], i32 0, i32 6
36+
// CHECK: %[[NESTED_STRUCT_PTR_MEMBER_GEP:.*]] = getelementptr { float, [10 x i32], { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i32 }, ptr %[[NESTED_DTYPE_MEMBER_GEP]], i32 0, i32 2
37+
// CHECK: %[[NESTED_STRUCT_PTR_MEMBER_BADDR_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]], i32 0, i32 0
38+
// CHECK: %[[NESTED_STRUCT_PTR_MEMBER_BADDR_LOAD:.*]] = load ptr, ptr %[[NESTED_STRUCT_PTR_MEMBER_BADDR_GEP]], align 8
39+
// CHECK: %[[ARR_OFFSET:.*]] = getelementptr inbounds i32, ptr %[[NESTED_STRUCT_PTR_MEMBER_BADDR_LOAD]], i64 0
40+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_1:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]], i64 1
41+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_2:.*]] = ptrtoint ptr %[[DTYPE_SIZE_SEGMENT_CALC_1]] to i64
42+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_3:.*]] = ptrtoint ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]] to i64
43+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_4:.*]] = sub i64 %[[DTYPE_SIZE_SEGMENT_CALC_2]], %[[DTYPE_SIZE_SEGMENT_CALC_3]]
44+
// CHECK: %[[DTYPE_SIZE_SEGMENT_CALC_5:.*]] = sdiv exact i64 %[[DTYPE_SIZE_SEGMENT_CALC_4]], ptrtoint (ptr getelementptr (i8, ptr null, i32 1) to i64)
45+
46+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
47+
// CHECK: store ptr %[[ARG]], ptr %[[BASE_PTRS]], align 8
48+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 0
49+
// CHECK: store ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]], ptr %[[OFFLOAD_PTRS]], align 8
50+
// CHECK: %[[OFFLOAD_SIZES:.*]] = getelementptr inbounds [4 x i64], ptr %.offload_sizes, i32 0, i32 0
51+
// CHECK: store i64 %[[DTYPE_SIZE_SEGMENT_CALC_5]], ptr %[[OFFLOAD_SIZES]], align 8
52+
53+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 1
54+
// CHECK: store ptr %[[ARG]], ptr %[[BASE_PTRS]], align 8
55+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 1
56+
// CHECK: store ptr %[[NESTED_STRUCT_PTR_MEMBER_GEP]], ptr %[[OFFLOAD_PTRS]], align 8
57+
58+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 2
59+
// CHECK: store ptr %[[ARG]], ptr %[[BASE_PTRS]], align 8
60+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 2
61+
// CHECK: store ptr %[[NESTED_STRUCT_PTR_MEMBER_BADDR_GEP]], ptr %[[OFFLOAD_PTRS]], align 8
62+
63+
// CHECK: %[[BASE_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_baseptrs, i32 0, i32 3
64+
// CHECK: store ptr %[[NESTED_STRUCT_PTR_MEMBER_BADDR_GEP]], ptr %[[BASE_PTRS]], align 8
65+
// CHECK: %[[OFFLOAD_PTRS:.*]] = getelementptr inbounds [4 x ptr], ptr %.offload_ptrs, i32 0, i32 3
66+
// CHECK: store ptr %[[ARR_OFFSET]], ptr %[[OFFLOAD_PTRS]], align 8

mlir/test/Target/LLVMIR/omptarget-nested-record-type-mapping-host.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ llvm.func @_QQmain() {
2121
%9 = llvm.getelementptr %4[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(f32, array<10 x i32>, struct<(f32, i32)>, i32)>
2222
%10 = omp.map.bounds lower_bound(%2 : i64) upper_bound(%1 : i64) extent(%0 : i64) stride(%2 : i64) start_idx(%2 : i64)
2323
%11 = omp.map.info var_ptr(%9 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%10) -> !llvm.ptr
24-
%12 = omp.map.info var_ptr(%4 : !llvm.ptr, !llvm.struct<(f32, array<10 x i32>, struct<(f32, i32)>, i32)>) map_clauses(tofrom) capture(ByRef) members(%6, %8, %11 : [3, -1], [2, 1], [1, -1] : !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr {partial_map = true}
24+
%12 = omp.map.info var_ptr(%4 : !llvm.ptr, !llvm.struct<(f32, array<10 x i32>, struct<(f32, i32)>, i32)>) map_clauses(tofrom) capture(ByRef) members(%6, %8, %11 : [3], [2, 1], [1] : !llvm.ptr, !llvm.ptr, !llvm.ptr) -> !llvm.ptr {partial_map = true}
2525
omp.target map_entries(%6 -> %arg0, %8 -> %arg1, %11 -> %arg2, %12 -> %arg3 : !llvm.ptr, !llvm.ptr, !llvm.ptr, !llvm.ptr) {
2626
omp.terminator
2727
}

0 commit comments

Comments
 (0)