Skip to content

Commit c5b5369

Browse files
committed
Fix bug and add better clarification comments
1 parent 4c207b5 commit c5b5369

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <mlir/IR/IRMapping.h>
3636
#include <mlir/IR/OpDefinition.h>
3737
#include <mlir/IR/PatternMatch.h>
38+
#include <mlir/IR/Value.h>
3839
#include <mlir/IR/Visitors.h>
3940
#include <mlir/Interfaces/SideEffectInterfaces.h>
4041
#include <mlir/Support/LLVM.h>
@@ -188,14 +189,19 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
188189
if (isUserOutsideSR(user, parentOp, sr))
189190
return true;
190191

191-
// Results of nested users cannot be used outside of the SR
192+
// Now we know user is inside `sr`.
193+
194+
// Results of nested users cannot be used outside of `sr`.
192195
if (user->getBlock() != srBlock)
193196
continue;
194197

195-
// A non-safe to parallelize operation will be handled separately
198+
// An operation that is not safe to parallelize will be checked for uses
199+
// outside `sr` separately.
196200
if (!isSafeToParallelize(user))
197201
continue;
198202

203+
// For safe to parallelize operations, we need to check if there is a
204+
// transitive use of `v` through them.
199205
for (auto res : user->getResults())
200206
if (isTransitivelyUsedOutside(res, sr))
201207
return true;
@@ -242,7 +248,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
242248
for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
243249
if (isSafeToParallelize(&op)) {
244250
singleBuilder.clone(op, singleMapping);
245-
parallelBuilder.clone(op, rootMapping);
251+
if (llvm::all_of(op.getOperands(), [&](Value opr) {
252+
return rootMapping.contains(opr);
253+
})) {
254+
// Safe to parallelize operations which have all operands available in
255+
// the root parallel block can be executed there.
256+
parallelBuilder.clone(op, rootMapping);
257+
} else {
258+
// If any operand is not available in the root parallel block, `op` depends
258+
// on a non-safe-to-parallelize operation whose result has no transitive use
259+
// outside `sr` (so it was never reloaded). This means that `op` itself
260+
// should have no transitive uses outside `sr`.
262+
assert(llvm::all_of(op.getResults(), [&](Value v) {
263+
return !isTransitivelyUsedOutside(v, sr);
264+
}));
265+
}
246266
} else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
247267
auto hoisted =
248268
cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
@@ -252,7 +272,7 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
252272
} else {
253273
singleBuilder.clone(op, singleMapping);
254274
// Prepare reloaded values for results of operations that cannot be
255-
// safely parallelized and which are used after the region `sr`
275+
// safely parallelized and which are used after the region `sr`.
256276
for (auto res : op.getResults()) {
257277
if (isTransitivelyUsedOutside(res, sr)) {
258278
auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: fir-opt --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
2+
3+
// Check that the safe-to-parallelize `fir.declare` op is not parallelized,
4+
// because its operand %alloc is not reloaded outside the omp.single.
5+
6+
func.func @foo() {
7+
%c0 = arith.constant 0 : index
8+
omp.workshare {
9+
%alloc = fir.allocmem !fir.array<?xf32>, %c0 {bindc_name = ".tmp.forall", uniq_name = ""}
10+
%shape = fir.shape %c0 : (index) -> !fir.shape<1>
11+
%declare = fir.declare %alloc(%shape) {uniq_name = ".tmp.forall"} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.heap<!fir.array<?xf32>>
12+
fir.freemem %alloc : !fir.heap<!fir.array<?xf32>>
13+
omp.terminator
14+
}
15+
return
16+
}

0 commit comments

Comments
 (0)