35
35
#include < mlir/IR/IRMapping.h>
36
36
#include < mlir/IR/OpDefinition.h>
37
37
#include < mlir/IR/PatternMatch.h>
38
+ #include < mlir/IR/Value.h>
38
39
#include < mlir/IR/Visitors.h>
39
40
#include < mlir/Interfaces/SideEffectInterfaces.h>
40
41
#include < mlir/Support/LLVM.h>
@@ -188,14 +189,19 @@ static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
188
189
if (isUserOutsideSR (user, parentOp, sr))
189
190
return true ;
190
191
191
- // Results of nested users cannot be used outside of the SR
192
+ // Now we know user is inside `sr`.
193
+
194
+ // Results of nested users cannot be used outside of `sr`.
192
195
if (user->getBlock () != srBlock)
193
196
continue ;
194
197
195
- // A non-safe to parallelize operation will be handled separately
198
+ // A non-safe-to-parallelize operation will be checked for uses outside
199
+ // separately.
196
200
if (!isSafeToParallelize (user))
197
201
continue ;
198
202
203
+ // For safe-to-parallelize operations, we need to check if there is a
204
+ // transitive use of `v` through them.
199
205
for (auto res : user->getResults ())
200
206
if (isTransitivelyUsedOutside (res, sr))
201
207
return true ;
@@ -242,7 +248,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
242
248
for (Operation &op : llvm::make_range (sr.begin , sr.end )) {
243
249
if (isSafeToParallelize (&op)) {
244
250
singleBuilder.clone (op, singleMapping);
245
- parallelBuilder.clone (op, rootMapping);
251
+ if (llvm::all_of (op.getOperands (), [&](Value opr) {
252
+ return rootMapping.contains (opr);
253
+ })) {
254
+ // Safe-to-parallelize operations which have all operands available in
255
+ // the root parallel block can be executed there.
256
+ parallelBuilder.clone (op, rootMapping);
257
+ } else {
258
+ // If any operand was not available, it means that there was no
259
+ // transitive use of a non-safe-to-parallelize operation outside `sr`.
260
+ // This means that there should be no transitive uses outside `sr` of
261
+ // `op`.
262
+ assert (llvm::all_of (op.getResults (), [&](Value v) {
263
+ return !isTransitivelyUsedOutside (v, sr);
264
+ }));
265
+ }
246
266
} else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
247
267
auto hoisted =
248
268
cast<fir::AllocaOp>(allocaBuilder.clone (*alloca, singleMapping));
@@ -252,7 +272,7 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
252
272
} else {
253
273
singleBuilder.clone (op, singleMapping);
254
274
// Prepare reloaded values for results of operations that cannot be
255
- // safely parallelized and which are used after the region `sr`
275
+ // safely parallelized and which are used after the region `sr`.
256
276
for (auto res : op.getResults ()) {
257
277
if (isTransitivelyUsedOutside (res, sr)) {
258
278
auto alloc = mapReloadedValue (res, allocaBuilder, singleBuilder,
0 commit comments