Skip to content

Commit 50df0ff

Browse files
committed
[OpenMP][MLIR] Add new arguments to map_info to help support record type maps
This PR adds two new fields to omp.map_info, one BoolAttr and one I64ArrayAttr. The BoolAttr is named partial_map, and is a flag that indicates if the record type captured by the map_info operation is a partial map, or if it is mapped in its entirety, this currently helps the later lowering determine the type of map entries that need to be generated. The I64ArrayAttr named members_index is intended to track the placement of each member map_info operations (and by extension mapped member variable) placement in the parent record type. This may need to be extended to an N-D array for nested member mapping. Pull Request: #82851
1 parent 7fd6cb2 commit 50df0ff

File tree

3 files changed

+117
-7
lines changed

3 files changed

+117
-7
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,10 +1423,12 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
14231423
TypeAttr:$var_type,
14241424
Optional<OpenMP_PointerLikeType>:$var_ptr_ptr,
14251425
Variadic<OpenMP_PointerLikeType>:$members,
1426+
OptionalAttr<AnyIntElementsAttr>:$members_index,
14261427
Variadic<MapBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
14271428
OptionalAttr<UI64Attr>:$map_type,
14281429
OptionalAttr<VariableCaptureKindAttr>:$map_capture_type,
1429-
OptionalAttr<StrAttr>:$name);
1430+
OptionalAttr<StrAttr>:$name,
1431+
DefaultValuedAttr<BoolAttr, "false">:$partial_map);
14301432
let results = (outs OpenMP_PointerLikeType:$omp_ptr);
14311433

14321434
let description = [{
@@ -1462,10 +1464,14 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
14621464
- `var_type`: The type of the variable to copy.
14631465
- `var_ptr_ptr`: Used when the variable copied is a member of a class, structure
14641466
or derived type and refers to the originating struct.
1465-
- `members`: Used to indicate mapped child members for the current MapInfoOp,
1467+
- `members`: Used to indicate mapped child members for the current MapInfoOp,
14661468
represented as other MapInfoOp's, utilised in cases where a parent structure
14671469
type and members of the structure type are being mapped at the same time.
14681470
For example: map(to: parent, parent->member, parent->member2[:10])
1471+
- `members_index`: Used to indicate the ordering of members within the containing
1472+
parent (generally a record type such as a structure, class or derived type),
1473+
e.g. struct {int x, float y, double z}, x would be 0, y would be 1, and z
1474+
would be 2. This aids the mapping.
14691475
- `bounds`: Used when copying slices of array's, pointers or pointer members of
14701476
objects (e.g. derived types or classes), indicates the bounds to be copied
14711477
of the variable. When it's an array slice it is in rank order where rank 0
@@ -1476,6 +1482,8 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
14761482
- 'map_capture_type': Capture type for the variable e.g. this, byref, byvalue, byvla
14771483
this can affect how the variable is lowered.
14781484
- `name`: Holds the name of variable as specified in user clause (including bounds).
1485+
- `partial_map`: The record type being mapped will not be mapped in its entirety,
1486+
it may be used however, in a mapping to bind it's mapped components together.
14791487
}];
14801488

14811489
let assemblyFormat = [{
@@ -1484,7 +1492,7 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
14841492
`var_ptr_ptr` `(` $var_ptr_ptr `:` type($var_ptr_ptr) `)`
14851493
| `map_clauses` `(` custom<MapClause>($map_type) `)`
14861494
| `capture` `(` custom<CaptureType>($map_capture_type) `)`
1487-
| `members` `(` $members `:` type($members) `)`
1495+
| `members` `(` $members `:` custom<MembersIndex>($members_index) `:` type($members) `)`
14881496
| `bounds` `(` $bounds `)`
14891497
) `->` type($omp_ptr) attr-dict
14901498
}];

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,6 +994,79 @@ static void printMapClause(OpAsmPrinter &p, Operation *op,
994994
}
995995
}
996996

997+
static ParseResult parseMembersIndex(OpAsmParser &parser,
998+
DenseIntElementsAttr &membersIdx) {
999+
SmallVector<APInt> values;
1000+
int64_t value;
1001+
int64_t shape[2] = {0, 0};
1002+
unsigned shapeTmp = 0;
1003+
auto parseIndices = [&]() -> ParseResult {
1004+
if (parser.parseInteger(value))
1005+
return failure();
1006+
shapeTmp++;
1007+
values.push_back(APInt(32, value));
1008+
return success();
1009+
};
1010+
1011+
do {
1012+
if (failed(parser.parseLSquare()))
1013+
return failure();
1014+
1015+
if (parser.parseCommaSeparatedList(parseIndices))
1016+
return failure();
1017+
1018+
if (failed(parser.parseRSquare()))
1019+
return failure();
1020+
1021+
// Only set once, if any indices are not the same size
1022+
// we error out in the next check as that's unsupported
1023+
if (shape[1] == 0)
1024+
shape[1] = shapeTmp;
1025+
1026+
// Verify that the recently parsed list is equal to the
1027+
// first one we parsed, they must be equal lengths to
1028+
// keep the rectangular shape DenseIntElementsAttr
1029+
// requires
1030+
if (shapeTmp != shape[1])
1031+
return failure();
1032+
1033+
shapeTmp = 0;
1034+
shape[0]++;
1035+
} while (succeeded(parser.parseOptionalComma()));
1036+
1037+
if (!values.empty()) {
1038+
ShapedType valueType =
1039+
VectorType::get(shape, IntegerType::get(parser.getContext(), 32));
1040+
membersIdx = DenseIntElementsAttr::get(valueType, values);
1041+
}
1042+
1043+
return success();
1044+
}
1045+
1046+
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1047+
DenseIntElementsAttr membersIdx) {
1048+
llvm::ArrayRef<int64_t> shape = membersIdx.getShapedType().getShape();
1049+
assert(shape.size() <= 2);
1050+
1051+
if (!membersIdx)
1052+
return;
1053+
1054+
for (int i = 0; i < shape[0]; ++i) {
1055+
p << "[";
1056+
int rowOffset = i * shape[1];
1057+
for (int j = 0; j < shape[1]; ++j) {
1058+
p << membersIdx.getValues<
1059+
int32_t>()[rowOffset + j];
1060+
if ((j + 1) < shape[1])
1061+
p << ",";
1062+
}
1063+
p << "]";
1064+
1065+
if ((i + 1) < shape[0])
1066+
p << ", ";
1067+
}
1068+
}
1069+
9971070
static ParseResult
9981071
parseMapEntries(OpAsmParser &parser,
9991072
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapOperands,

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,8 +2306,6 @@ func.func @omp_requires_multiple() -> ()
23062306
return
23072307
}
23082308

2309-
// -----
2310-
23112309
// CHECK-LABEL: @opaque_pointers_atomic_rwu
23122310
// CHECK-SAME: (%[[v:.*]]: !llvm.ptr, %[[x:.*]]: !llvm.ptr)
23132311
func.func @opaque_pointers_atomic_rwu(%v: !llvm.ptr, %x: !llvm.ptr) {
@@ -2417,8 +2415,8 @@ func.func @omp_target_update_data (%if_cond : i1, %device : si32, %map1: memref<
24172415
func.func @omp_targets_is_allocatable(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> () {
24182416
// CHECK: %[[MAP0:.*]] = omp.map.info var_ptr(%[[ARG0]] : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
24192417
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(tofrom) capture(ByRef) -> !llvm.ptr {name = ""}
2420-
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP0]] : !llvm.ptr) -> !llvm.ptr {name = ""}
2421-
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%mapv1 : !llvm.ptr) -> !llvm.ptr {name = ""}
2418+
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%[[MAP0]] : [0] : !llvm.ptr) -> !llvm.ptr {name = ""}
2419+
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>) map_clauses(tofrom) capture(ByRef) members(%mapv1 : [0] : !llvm.ptr) -> !llvm.ptr {name = ""}
24222420
// CHECK: omp.target map_entries(%[[MAP0]] -> {{.*}}, %[[MAP1]] -> {{.*}} : !llvm.ptr, !llvm.ptr)
24232421
omp.target map_entries(%mapv1 -> %arg2, %mapv2 -> %arg3 : !llvm.ptr, !llvm.ptr) {
24242422
^bb0(%arg2: !llvm.ptr, %arg3 : !llvm.ptr):
@@ -2473,6 +2471,37 @@ func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memre
24732471
}
24742472
// CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>) depend(taskdependin -> [[ARG2]] : memref<?xi32>)
24752473
omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
2474+
2475+
return
2476+
}
2477+
2478+
// CHECK-LABEL: omp_map_with_members
2479+
// CHECK-SAME: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr, %[[ARG3:.*]]: !llvm.ptr, %[[ARG4:.*]]: !llvm.ptr, %[[ARG5:.*]]: !llvm.ptr)
2480+
func.func @omp_map_with_members(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: !llvm.ptr, %arg3: !llvm.ptr, %arg4: !llvm.ptr, %arg5: !llvm.ptr) -> () {
2481+
// CHECK: %[[MAP0:.*]] = omp.map.info var_ptr(%[[ARG0]] : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
2482+
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
2483+
2484+
// CHECK: %[[MAP1:.*]] = omp.map.info var_ptr(%[[ARG1]] : !llvm.ptr, f32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
2485+
%mapv2 = omp.map.info var_ptr(%arg1 : !llvm.ptr, f32) map_clauses(to) capture(ByRef) -> !llvm.ptr {name = ""}
2486+
2487+
// CHECK: %[[MAP2:.*]] = omp.map.info var_ptr(%[[ARG2]] : !llvm.ptr, !llvm.struct<(i32, f32)>) map_clauses(to) capture(ByRef) members(%[[MAP0]], %[[MAP1]] : [0], [1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2488+
%mapv3 = omp.map.info var_ptr(%arg2 : !llvm.ptr, !llvm.struct<(i32, f32)>) map_clauses(to) capture(ByRef) members(%mapv1, %mapv2 : [0], [1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2489+
2490+
// CHECK: omp.target_enter_data map_entries(%[[MAP0]], %[[MAP1]], %[[MAP2]] : !llvm.ptr, !llvm.ptr, !llvm.ptr)
2491+
omp.target_enter_data map_entries(%mapv1, %mapv2, %mapv3 : !llvm.ptr, !llvm.ptr, !llvm.ptr){}
2492+
2493+
// CHECK: %[[MAP3:.*]] = omp.map.info var_ptr(%[[ARG3]] : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
2494+
%mapv4 = omp.map.info var_ptr(%arg3 : !llvm.ptr, i32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
2495+
2496+
// CHECK: %[[MAP4:.*]] = omp.map.info var_ptr(%[[ARG4]] : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
2497+
%mapv5 = omp.map.info var_ptr(%arg4 : !llvm.ptr, f32) map_clauses(from) capture(ByRef) -> !llvm.ptr {name = ""}
2498+
2499+
// CHECK: %[[MAP5:.*]] = omp.map.info var_ptr(%[[ARG5]] : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%[[MAP3]], %[[MAP4]] : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2500+
%mapv6 = omp.map.info var_ptr(%arg5 : !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>) map_clauses(from) capture(ByRef) members(%mapv4, %mapv5 : [1,0], [1,1] : !llvm.ptr, !llvm.ptr) -> !llvm.ptr {name = "", partial_map = true}
2501+
2502+
// CHECK: omp.target_exit_data map_entries(%[[MAP3]], %[[MAP4]], %[[MAP5]] : !llvm.ptr, !llvm.ptr, !llvm.ptr)
2503+
omp.target_exit_data map_entries(%mapv4, %mapv5, %mapv6 : !llvm.ptr, !llvm.ptr, !llvm.ptr){}
2504+
24762505
return
24772506
}
24782507

0 commit comments

Comments
 (0)