@@ -111,6 +111,38 @@ static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
111
111
return funcOp;
112
112
}
113
113
114
+ static bool isUserOutsideSR (Operation *user, Operation *parentOp,
115
+ SingleRegion sr) {
116
+ while (user->getParentOp () != parentOp)
117
+ user = user->getParentOp ();
118
+ return sr.begin ->getBlock () != user->getBlock () ||
119
+ !(user->isBeforeInBlock (&*sr.end ) && sr.begin ->isBeforeInBlock (user));
120
+ }
121
+
122
+ static bool isTransitivelyUsedOutside (Value v, SingleRegion sr) {
123
+ Block *srBlock = sr.begin ->getBlock ();
124
+ Operation *parentOp = srBlock->getParentOp ();
125
+
126
+ for (auto &use : v.getUses ()) {
127
+ Operation *user = use.getOwner ();
128
+ if (isUserOutsideSR (user, parentOp, sr))
129
+ return true ;
130
+
131
+ // Results of nested users cannot be used outside of the SR
132
+ if (user->getBlock () != srBlock)
133
+ continue ;
134
+
135
+ // A non-safe to parallelize operation will be handled separately
136
+ if (!isSafeToParallelize (user))
137
+ continue ;
138
+
139
+ for (auto res : user->getResults ())
140
+ if (isTransitivelyUsedOutside (res, sr))
141
+ return true ;
142
+ }
143
+ return false ;
144
+ }
145
+
114
146
static void parallelizeRegion (Region &sourceRegion, Region &targetRegion,
115
147
IRMapping &rootMapping, Location loc) {
116
148
Operation *parentOp = sourceRegion.getParentOp ();
@@ -166,19 +198,11 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
166
198
// Prepare reloaded values for results of operations that cannot be
167
199
// safely parallelized and which are used after the region `sr`
168
200
for (auto res : op.getResults ()) {
169
- for (auto &use : res.getUses ()) {
170
- Operation *user = use.getOwner ();
171
- while (user->getParentOp () != parentOp)
172
- user = user->getParentOp ();
173
- // TODO we need to look at transitively used vals
174
- if (true || !(user->isBeforeInBlock (&*sr.end ) &&
175
- sr.begin ->isBeforeInBlock (user))) {
176
- auto alloc =
177
- mapReloadedValue (use.get (), allocaBuilder, singleBuilder,
178
- parallelBuilder, singleMapping);
179
- if (alloc)
180
- copyPrivate.push_back (alloc);
181
- }
201
+ if (isTransitivelyUsedOutside (res, sr)) {
202
+ auto alloc = mapReloadedValue (res, allocaBuilder, singleBuilder,
203
+ parallelBuilder, singleMapping);
204
+ if (alloc)
205
+ copyPrivate.push_back (alloc);
182
206
}
183
207
}
184
208
}
@@ -236,7 +260,6 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
236
260
omp::SingleOperands singleOperands;
237
261
if (isLast)
238
262
singleOperands.nowait = rootBuilder.getUnitAttr ();
239
- auto insPtAtSingle = rootBuilder.saveInsertionPoint ();
240
263
singleOperands.copyprivateVars =
241
264
moveToSingle (std::get<SingleRegion>(opOrSingle), allocaBuilder,
242
265
singleBuilder, parallelBuilder);
0 commit comments