Skip to content

Commit 6bd0a22

Browse files
authored
[flang][OpenMP] Handle expressions in target ... do loop control (#107)
When emitting the ops required to compute the target loop's trip count, flang might emit ops outside the target regions that operands defined inside the region. This is fixed up by `HostClausesInsertionGuard` already. However, the current support only handles simple cases. If the loop control contains more elaborate expressions, the fix up logic does not handle it properly. This PR handles such cases.
1 parent 0ef3716 commit 6bd0a22

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "flang/Parser/parse-tree.h"
3434
#include "flang/Semantics/openmp-directive-sets.h"
3535
#include "flang/Semantics/tools.h"
36+
#include "mlir/Analysis/SliceAnalysis.h"
3637
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
3738
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3839
#include "mlir/Transforms/RegionUtils.h"
@@ -123,9 +124,100 @@ class HostClausesInsertionGuard {
123124
mlir::OpBuilder::InsertPoint ip;
124125
mlir::omp::TargetOp targetOp;
125126

127+
// Finds the list of op operands that escape the target op's region; that is:
128+
// the operands that are used outside the target op but defined inside it.
129+
void
130+
findEscapingOpOperands(llvm::DenseSet<mlir::OpOperand *> &escapingOperands) {
131+
if (!targetOp)
132+
return;
133+
134+
mlir::Region *targetParentRegion = targetOp->getParentRegion();
135+
assert(targetParentRegion != nullptr &&
136+
"Expected omp.target op to be nested in a parent region.");
137+
138+
// Walk the parent region in pre-order to make sure we visit `targetOp`
139+
// before its nested ops.
140+
targetParentRegion->walk<mlir::WalkOrder::PreOrder>(
141+
[&](mlir::Operation *op) {
142+
// Once we come across `targetOp`, we interrupt the walk since we
143+
// already visited all the ops that come before it in the region.
144+
if (op == targetOp)
145+
return mlir::WalkResult::interrupt();
146+
147+
for (mlir::OpOperand &operand : op->getOpOperands()) {
148+
mlir::Operation *operandDefiningOp = operand.get().getDefiningOp();
149+
150+
if (operandDefiningOp == nullptr)
151+
continue;
152+
153+
auto parentTargetOp =
154+
operandDefiningOp->getParentOfType<mlir::omp::TargetOp>();
155+
156+
if (parentTargetOp != targetOp)
157+
continue;
158+
159+
escapingOperands.insert(&operand);
160+
}
161+
162+
return mlir::WalkResult::advance();
163+
});
164+
}
165+
166+
// For an escaping operand, clone its use-def chain (i.e. its backward slice)
167+
// outside the target region.
168+
//
169+
// \return the last op in the chain (this is the op that defines the escaping
170+
// operand).
171+
mlir::Operation *
172+
cloneOperandSliceOutsideTargetOp(mlir::OpOperand *escapingOperand) {
173+
mlir::Operation *operandDefiningOp = escapingOperand->get().getDefiningOp();
174+
llvm::SetVector<mlir::Operation *> backwardSlice;
175+
mlir::BackwardSliceOptions sliceOptions;
176+
sliceOptions.inclusive = true;
177+
mlir::getBackwardSlice(operandDefiningOp, &backwardSlice, sliceOptions);
178+
179+
auto ip = builder.saveInsertionPoint();
180+
181+
mlir::IRMapping mapper;
182+
builder.setInsertionPoint(escapingOperand->getOwner());
183+
mlir::Operation *lastSliceOp;
184+
185+
for (auto *op : backwardSlice)
186+
lastSliceOp = builder.clone(*op, mapper);
187+
188+
builder.restoreInsertionPoint(ip);
189+
return lastSliceOp;
190+
}
191+
126192
/// Fixup any uses of target region block arguments that we have just created
127193
/// outside of the target region, and replace them by their host values.
128194
void fixupExtractedHostOps() {
195+
llvm::DenseSet<mlir::OpOperand *> escapingOperands;
196+
findEscapingOpOperands(escapingOperands);
197+
198+
for (mlir::OpOperand *operand : escapingOperands) {
199+
mlir::Operation *operandDefiningOp = operand->get().getDefiningOp();
200+
assert(operandDefiningOp != nullptr &&
201+
"Expected escaping operand to have a defining op (i.e. not to be "
202+
"a block argument)");
203+
mlir::Operation *lastSliceOp = cloneOperandSliceOutsideTargetOp(operand);
204+
205+
// Find the index of the operand in the list of results produced by its
206+
// defining op.
207+
unsigned operandResultIdx = 0;
208+
for (auto [idx, res] : llvm::enumerate(operandDefiningOp->getResults())) {
209+
if (res == operand->get()) {
210+
operandResultIdx = idx;
211+
break;
212+
}
213+
}
214+
215+
// Replace the escaping operand with the corresponding value from the
216+
// op that we cloned outside the target op.
217+
operand->getOwner()->setOperand(operand->getOperandNumber(),
218+
lastSliceOp->getResult(operandResultIdx));
219+
}
220+
129221
auto useOutsideTargetRegion = [](mlir::OpOperand &operand) {
130222
if (mlir::Operation *owner = operand.getOwner())
131223
return !owner->getParentOfType<mlir::omp::TargetOp>();
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
! Verifies that if expressions are used to compute a target parallel loop, that
2+
! no values escape the target region when flang emits the ops corresponding to
3+
! these expressions (for example the compute the trip count for the target region).
4+
5+
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
6+
7+
subroutine foo(upper_bound)
8+
implicit none
9+
integer :: upper_bound
10+
integer :: nodes(1 : upper_bound)
11+
integer :: i
12+
13+
!$omp target teams distribute parallel do
14+
do i = 1, ubound(nodes,1)
15+
nodes(i) = i
16+
end do
17+
!$omp end target teams distribute parallel do
18+
end subroutine
19+
20+
! CHECK: func.func @_QPfoo(%[[FUNC_ARG:.*]]: !fir.ref<i32> {fir.bindc_name = "upper_bound"}) {
21+
! CHECK: %[[UB_ALLOC:.*]] = fir.alloca i32
22+
! CHECK: fir.dummy_scope : !fir.dscope
23+
! CHECK: %[[UB_DECL:.*]]:2 = hlfir.declare %[[FUNC_ARG]] {{.*}} {uniq_name = "_QFfooEupper_bound"}
24+
25+
! CHECK: omp.map.info
26+
! CHECK: omp.map.info
27+
! CHECK: omp.map.info
28+
29+
! Verify that we load from the original/host allocation of the `upper_bound`
30+
! variable rather than the corresponding target region arg.
31+
32+
! CHECK: fir.load %[[UB_ALLOC]] : !fir.ref<i32>
33+
! CHECK: omp.target
34+
35+
! CHECK: }

0 commit comments

Comments
 (0)