Skip to content

Commit 074e4d2

Browse files
committed
[mlir][OpenMP] Annotate private vars with map_idx when needed
This PR extends the MLIR representation for `omp.target` ops by adding a `map_idx` to `private` vars. This annotation stores the index of the map info operand corresponding to the private var. If the variable does not have a map operand, the `map_idx` attribute is either not present at all or its value is `-1`.
1 parent 3291372 commit 074e4d2

File tree

5 files changed

+124
-38
lines changed

5 files changed

+124
-38
lines changed

flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
//===- MapsForPrivatizedSymbols.cpp
2-
//-----------------------------------------===//
1+
//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54
// See https://llvm.org/LICENSE.txt for license information.
@@ -28,6 +27,7 @@
2827
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
2928
#include "flang/Optimizer/HLFIR/HLFIROps.h"
3029
#include "flang/Optimizer/OpenMP/Passes.h"
30+
3131
#include "mlir/Dialect/Func/IR/FuncOps.h"
3232
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3333
#include "mlir/IR/BuiltinAttributes.h"
@@ -124,6 +124,8 @@ class MapsForPrivatizedSymbolsPass
124124
if (targetOp.getPrivateVars().empty())
125125
return;
126126
OperandRange privVars = targetOp.getPrivateVars();
127+
llvm::SmallVector<int64_t> privVarMapIdx;
128+
127129
std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
128130
SmallVector<omp::MapInfoOp, 4> mapInfoOps;
129131
for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
@@ -133,17 +135,25 @@ class MapsForPrivatizedSymbolsPass
133135
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
134136
targetOp, privatizerName);
135137
if (!privatizerNeedsMap(privatizer)) {
138+
privVarMapIdx.push_back(-1);
136139
continue;
137140
}
141+
142+
privVarMapIdx.push_back(targetOp.getMapVars().size() +
143+
mapInfoOps.size());
144+
138145
builder.setInsertionPoint(targetOp);
139146
Location loc = targetOp.getLoc();
140147
omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
141148
mapInfoOps.push_back(mapInfoOp);
149+
142150
LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
143151
LLVM_DEBUG(mapInfoOp.dump());
144152
}
145153
if (!mapInfoOps.empty()) {
146154
mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
155+
targetOp.setPrivateMapsAttr(
156+
mlir::DenseI64ArrayAttr::get(targetOp.getContext(), privVarMapIdx));
147157
}
148158
});
149159
if (!mapInfoOpsForTarget.empty()) {

flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,12 @@ end subroutine target_allocatable
171171
! CHECK_SAME %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] :
172172
! CHECK-SAME !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.ref<!fir.boxchar<1>>)
173173
! CHECK-SAME: private(
174-
! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]],
174+
! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1],
175175
! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]],
176176
! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]],
177-
! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]],
177+
! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2],
178178
! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]],
179-
! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] :
179+
! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] :
180180
! CHECK-SAME: !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<f32>, !fir.ref<i64>, !fir.box<!fir.array<?xf32>>, !fir.ref<complex<f32>>, !fir.boxchar<1>) {
181181
! CHECK-NOT: fir.alloca
182182
! CHECK: hlfir.declare %[[ALLOC_ARG]]

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1225,8 +1225,18 @@ def TargetOp : OpenMP_Op<"target", traits = [
12251225
The optional `if_expr` parameter specifies a boolean result of a conditional
12261226
check. If this value is 1 or is not provided then the target region runs on
12271227
a device, if it is 0 then the target region is executed on the host device.
1228+
1229+
The `private_maps` attribute connects `private` operands to their corresponding
1230+
`map` operands. For `private` operands that require a map, the value of the
1231+
corresponding element in the attribute is the index of the `map` operand
1232+
(relative to other `map` operands not the whole operands of the operation). For
1233+
`private` opernads that do not require a map, this value is -1 (which is omitted
1234+
from the assembly foramt printing).
12281235
}] # clausesDescription;
12291236

1237+
let arguments = !con(clausesArgs,
1238+
(ins OptionalAttr<DenseI64ArrayAttr>:$private_maps));
1239+
12301240
let builders = [
12311241
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
12321242
];
@@ -1239,7 +1249,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
12391249
custom<InReductionMapPrivateRegion>(
12401250
$region, $in_reduction_vars, type($in_reduction_vars),
12411251
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
1242-
$private_vars, type($private_vars), $private_syms) attr-dict
1252+
$private_vars, type($private_vars), $private_syms, $private_maps)
1253+
attr-dict
12431254
}];
12441255

12451256
let hasVerifier = 1;

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 73 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -487,9 +487,11 @@ struct PrivateParseArgs {
487487
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
488488
llvm::SmallVectorImpl<Type> &types;
489489
ArrayAttr &syms;
490+
DenseI64ArrayAttr *mapIndices;
490491
PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
491-
SmallVectorImpl<Type> &types, ArrayAttr &syms)
492-
: vars(vars), types(types), syms(syms) {}
492+
SmallVectorImpl<Type> &types, ArrayAttr &syms,
493+
DenseI64ArrayAttr *mapIndices = nullptr)
494+
: vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
493495
};
494496
struct ReductionParseArgs {
495497
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +519,10 @@ static ParseResult parseClauseWithRegionArgs(
517519
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
518520
SmallVectorImpl<Type> &types,
519521
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
520-
ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
522+
ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
523+
DenseBoolArrayAttr *byref = nullptr) {
521524
SmallVector<SymbolRefAttr> symbolVec;
525+
SmallVector<int64_t> mapIndicesVec;
522526
SmallVector<bool> isByRefVec;
523527
unsigned regionArgOffset = regionPrivateArgs.size();
524528

@@ -538,6 +542,16 @@ static ParseResult parseClauseWithRegionArgs(
538542
parser.parseArgument(regionPrivateArgs.emplace_back()))
539543
return failure();
540544

545+
if (mapIndices) {
546+
if (parser.parseOptionalLSquare().succeeded()) {
547+
if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
548+
parser.parseInteger(mapIndicesVec.emplace_back()) ||
549+
parser.parseRSquare())
550+
return failure();
551+
} else
552+
mapIndicesVec.push_back(-1);
553+
}
554+
541555
return success();
542556
}))
543557
return failure();
@@ -571,6 +585,10 @@ static ParseResult parseClauseWithRegionArgs(
571585
*symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
572586
}
573587

588+
if (!mapIndicesVec.empty())
589+
*mapIndices =
590+
mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
591+
574592
if (byref)
575593
*byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
576594

@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
595613
static ParseResult parseBlockArgClause(
596614
OpAsmParser &parser,
597615
llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
598-
StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
616+
StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
599617
if (succeeded(parser.parseOptionalKeyword(keyword))) {
600-
if (!reductionArgs)
618+
if (!privateArgs)
601619
return failure();
602620

603-
if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
604-
reductionArgs->types, entryBlockArgs,
605-
&reductionArgs->syms)))
621+
if (failed(parseClauseWithRegionArgs(
622+
parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
623+
&privateArgs->syms, privateArgs->mapIndices)))
606624
return failure();
607625
}
608626
return success();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
618636

619637
if (failed(parseClauseWithRegionArgs(
620638
parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
621-
&reductionArgs->syms, &reductionArgs->byref)))
639+
&reductionArgs->syms, /*mapIndices=*/nullptr,
640+
&reductionArgs->byref)))
622641
return failure();
623642
}
624643
return success();
@@ -674,12 +693,14 @@ static ParseResult parseInReductionMapPrivateRegion(
674693
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
675694
SmallVectorImpl<Type> &mapTypes,
676695
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
677-
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
696+
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
697+
DenseI64ArrayAttr &privateMaps) {
678698
AllRegionParseArgs args;
679699
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
680700
inReductionByref, inReductionSyms);
681701
args.mapArgs.emplace(mapVars, mapTypes);
682-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
702+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
703+
&privateMaps);
683704
return parseBlockArgRegion(parser, region, args);
684705
}
685706

@@ -776,8 +797,10 @@ struct PrivatePrintArgs {
776797
ValueRange vars;
777798
TypeRange types;
778799
ArrayAttr syms;
779-
PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
780-
: vars(vars), types(types), syms(syms) {}
800+
DenseI64ArrayAttr mapIndices;
801+
PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
802+
DenseI64ArrayAttr mapIndices)
803+
: vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
781804
};
782805
struct ReductionPrintArgs {
783806
ValueRange vars;
@@ -804,6 +827,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
804827
ValueRange argsSubrange,
805828
ValueRange operands, TypeRange types,
806829
ArrayAttr symbols = nullptr,
830+
DenseI64ArrayAttr mapIndices = nullptr,
807831
DenseBoolArrayAttr byref = nullptr) {
808832
if (argsSubrange.empty())
809833
return;
@@ -815,21 +839,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
815839
symbols = ArrayAttr::get(ctx, values);
816840
}
817841

842+
if (!mapIndices) {
843+
llvm::SmallVector<int64_t> values(operands.size(), -1);
844+
mapIndices = DenseI64ArrayAttr::get(ctx, values);
845+
}
846+
818847
if (!byref) {
819848
mlir::SmallVector<bool> values(operands.size(), false);
820849
byref = DenseBoolArrayAttr::get(ctx, values);
821850
}
822851

823-
llvm::interleaveComma(
824-
llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
825-
[&p](auto t) {
826-
auto [op, arg, sym, isByRef] = t;
827-
if (isByRef)
828-
p << "byref ";
829-
if (sym)
830-
p << sym << " ";
831-
p << op << " -> " << arg;
832-
});
852+
llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
853+
mapIndices.asArrayRef(),
854+
byref.asArrayRef()),
855+
p, [&p](auto t) {
856+
auto [op, arg, sym, map, isByRef] = t;
857+
if (isByRef)
858+
p << "byref ";
859+
if (sym)
860+
p << sym << " ";
861+
862+
p << op << " -> " << arg;
863+
864+
if (map != -1)
865+
p << " [map_idx=" << map << "]";
866+
});
833867
p << " : ";
834868
llvm::interleaveComma(types, p);
835869
p << ") ";
@@ -849,7 +883,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
849883
if (privateArgs)
850884
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
851885
privateArgs->vars, privateArgs->types,
852-
privateArgs->syms);
886+
privateArgs->syms, privateArgs->mapIndices);
853887
}
854888

855889
static void
@@ -859,7 +893,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
859893
if (reductionArgs)
860894
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
861895
reductionArgs->vars, reductionArgs->types,
862-
reductionArgs->syms, reductionArgs->byref);
896+
reductionArgs->syms, /*mapIndices=*/nullptr,
897+
reductionArgs->byref);
863898
}
864899

865900
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -891,12 +926,13 @@ static void printInReductionMapPrivateRegion(
891926
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
892927
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
893928
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
894-
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
929+
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
930+
DenseI64ArrayAttr privateMaps) {
895931
AllRegionPrintArgs args;
896932
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
897933
inReductionByref, inReductionSyms);
898934
args.mapArgs.emplace(mapVars, mapTypes);
899-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
935+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
900936
printBlockArgRegion(p, op, region, args);
901937
}
902938

@@ -908,7 +944,8 @@ static void printInReductionPrivateRegion(
908944
AllRegionPrintArgs args;
909945
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
910946
inReductionByref, inReductionSyms);
911-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
947+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
948+
/*mapIndices=*/nullptr);
912949
printBlockArgRegion(p, op, region, args);
913950
}
914951

@@ -921,7 +958,8 @@ static void printInReductionPrivateReductionRegion(
921958
AllRegionPrintArgs args;
922959
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
923960
inReductionByref, inReductionSyms);
924-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
961+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
962+
/*mapIndices=*/nullptr);
925963
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
926964
reductionSyms);
927965
printBlockArgRegion(p, op, region, args);
@@ -931,7 +969,8 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
931969
ValueRange privateVars, TypeRange privateTypes,
932970
ArrayAttr privateSyms) {
933971
AllRegionPrintArgs args;
934-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
972+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
973+
/*mapIndices=*/nullptr);
935974
printBlockArgRegion(p, op, region, args);
936975
}
937976

@@ -941,7 +980,8 @@ static void printPrivateReductionRegion(
941980
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
942981
ArrayAttr reductionSyms) {
943982
AllRegionPrintArgs args;
944-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
983+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
984+
/*mapIndices=*/nullptr);
945985
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
946986
reductionSyms);
947987
printBlockArgRegion(p, op, region, args);
@@ -1656,7 +1696,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
16561696
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
16571697
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
16581698
clauses.mapVars, clauses.nowait, clauses.privateVars,
1659-
makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
1699+
makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
1700+
/*private_maps=*/nullptr);
16601701
}
16611702

16621703
LogicalResult TargetOp::verify() {

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2750,6 +2750,30 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
27502750
return
27512751
}
27522752

2753+
// CHECK-LABEL: omp_target_private_with_map_idx
2754+
func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
2755+
%mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
2756+
%mapv2 = omp.map.info var_ptr(%map2 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
2757+
2758+
// CHECK: omp.target
2759+
2760+
// CHECK-SAME: map_entries(
2761+
// CHECK-SAME: %{{[^[:space:]]+}} -> %[[MAP1_ARG:[^[:space:]]+]],
2762+
// CHECK-SAME: %{{[^[:space:]]+}} -> %[[MAP2_ARG:[^[:space:]]+]]
2763+
// CHECK-SAME: : memref<?xi32>, memref<?xi32>
2764+
// CHECK-SAME: )
2765+
2766+
// CHECK-SAME: private(
2767+
// CHECK-SAME: @x.privatizer %{{[^[:space:]]+}} -> %[[PRIV_ARG:[^[:space:]]+]] [map_idx=1]
2768+
// CHECK-SAME: : !llvm.ptr
2769+
// CHECK-SAME: )
2770+
omp.target map_entries(%mapv1 -> %arg0, %mapv2 -> %arg1 : memref<?xi32>, memref<?xi32>) private(@x.privatizer %priv_var -> %priv_arg [map_idx=1] : !llvm.ptr) {
2771+
omp.terminator
2772+
}
2773+
2774+
return
2775+
}
2776+
27532777
// CHECK-LABEL: omp_loop
27542778
func.func @omp_loop(%lb : index, %ub : index, %step : index) {
27552779
// CHECK: omp.loop {

0 commit comments

Comments
 (0)