@@ -252,25 +252,12 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
252
252
if (loop.lowerBound ().empty ())
253
253
return failure ();
254
254
255
- if (loop.getNumLoops () != 1 )
256
- return opInst.emitOpError (" collapsed loops not yet supported" );
257
-
258
255
// Static is the default.
259
256
omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static;
260
257
if (loop.schedule_val ().hasValue ())
261
258
schedule =
262
259
*omp::symbolizeClauseScheduleKind (loop.schedule_val ().getValue ());
263
260
264
- // Find the loop configuration.
265
- llvm::Value *lowerBound = moduleTranslation.lookupValue (loop.lowerBound ()[0 ]);
266
- llvm::Value *upperBound = moduleTranslation.lookupValue (loop.upperBound ()[0 ]);
267
- llvm::Value *step = moduleTranslation.lookupValue (loop.step ()[0 ]);
268
- llvm::Type *ivType = step->getType ();
269
- llvm::Value *chunk =
270
- loop.schedule_chunk_var ()
271
- ? moduleTranslation.lookupValue (loop.schedule_chunk_var ())
272
- : llvm::ConstantInt::get (ivType, 1 );
273
-
274
261
// Set up the source location value for OpenMP runtime.
275
262
llvm::DISubprogram *subprogram =
276
263
builder.GetInsertBlock ()->getParent ()->getSubprogram ();
@@ -279,22 +266,29 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
279
266
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder.saveIP (),
280
267
llvm::DebugLoc (diLoc));
281
268
282
- // Generator of the canonical loop body. Produces an SESE region of basic
283
- // blocks.
269
+ // Generator of the canonical loop body.
284
270
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
285
271
// relying on captured variables.
272
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
273
+ SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
286
274
LogicalResult bodyGenStatus = success ();
287
275
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
288
- llvm::IRBuilder<>::InsertPointGuard guard (builder);
289
-
290
276
// Make sure further conversions know about the induction variable.
291
- moduleTranslation.mapValue (loop.getRegion ().front ().getArgument (0 ), iv);
277
+ moduleTranslation.mapValue (
278
+ loop.getRegion ().front ().getArgument (loopInfos.size ()), iv);
279
+
280
+ // Capture the body insertion point for use in nested loops. BodyIP of the
281
+ // CanonicalLoopInfo always points to the beginning of the entry block of
282
+ // the body.
283
+ bodyInsertPoints.push_back (ip);
284
+
285
+ if (loopInfos.size () != loop.getNumLoops () - 1 )
286
+ return ;
292
287
288
+ // Convert the body of the loop.
293
289
llvm::BasicBlock *entryBlock = ip.getBlock ();
294
290
llvm::BasicBlock *exitBlock =
295
291
entryBlock->splitBasicBlock (ip.getPoint (), " omp.wsloop.exit" );
296
-
297
- // Convert the body of the loop.
298
292
convertOmpOpRegions (loop.region (), " omp.wsloop.region" , *entryBlock,
299
293
*exitBlock, builder, moduleTranslation, bodyGenStatus);
300
294
};
@@ -303,21 +297,49 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
303
297
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
304
298
// i.e. it has a positive step, uses signed integer semantics. Reconsider
305
299
// this code when WsLoop clearly supports more cases.
300
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
301
+ for (unsigned i = 0 , e = loop.getNumLoops (); i < e; ++i) {
302
+ llvm::Value *lowerBound =
303
+ moduleTranslation.lookupValue (loop.lowerBound ()[i]);
304
+ llvm::Value *upperBound =
305
+ moduleTranslation.lookupValue (loop.upperBound ()[i]);
306
+ llvm::Value *step = moduleTranslation.lookupValue (loop.step ()[i]);
307
+
308
+ // Make sure loop trip counts are emitted in the preheader of the outermost
309
+ // loop at the latest so that they are all available for the new collapsed
310
+ // loop that will be created below.
311
+ llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
312
+ llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP ;
313
+ if (i != 0 ) {
314
+ loc = llvm::OpenMPIRBuilder::LocationDescription (bodyInsertPoints.back (),
315
+ llvm::DebugLoc (diLoc));
316
+ computeIP = loopInfos.front ()->getPreheaderIP ();
317
+ }
318
+ loopInfos.push_back (ompBuilder->createCanonicalLoop (
319
+ loc, bodyGen, lowerBound, upperBound, step,
320
+ /* IsSigned=*/ true , loop.inclusive (), computeIP));
321
+
322
+ if (failed (bodyGenStatus))
323
+ return failure ();
324
+ }
325
+
326
+ // Collapse loops. Store the insertion point because LoopInfos may get
327
+ // invalidated.
328
+ llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front ()->getAfterIP ();
306
329
llvm::CanonicalLoopInfo *loopInfo =
307
- moduleTranslation.getOpenMPBuilder ()->createCanonicalLoop (
308
- ompLoc, bodyGen, lowerBound, upperBound, step, /* IsSigned=*/ true ,
309
- /* InclusiveStop=*/ loop.inclusive ());
310
- if (failed (bodyGenStatus))
311
- return failure ();
330
+ ompBuilder->collapseLoops (diLoc, loopInfos, {});
312
331
332
+ // Find the loop configuration.
333
+ llvm::Type *ivType = loopInfo->getIndVar ()->getType ();
334
+ llvm::Value *chunk =
335
+ loop.schedule_chunk_var ()
336
+ ? moduleTranslation.lookupValue (loop.schedule_chunk_var ())
337
+ : llvm::ConstantInt::get (ivType, 1 );
313
338
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
314
339
findAllocaInsertPoint (builder, moduleTranslation);
315
- llvm::OpenMPIRBuilder::InsertPointTy afterIP;
316
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
317
340
if (schedule == omp::ClauseScheduleKind::Static) {
318
- loopInfo = ompBuilder->createStaticWorkshareLoop (ompLoc, loopInfo, allocaIP,
319
- !loop.nowait (), chunk);
320
- afterIP = loopInfo->getAfterIP ();
341
+ ompBuilder->createStaticWorkshareLoop (ompLoc, loopInfo, allocaIP,
342
+ !loop.nowait (), chunk);
321
343
} else {
322
344
llvm::omp::OMPScheduleType schedType;
323
345
switch (schedule) {
@@ -338,11 +360,14 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
338
360
break ;
339
361
}
340
362
341
- afterIP = ompBuilder->createDynamicWorkshareLoop (
342
- ompLoc, loopInfo, allocaIP, schedType, !loop.nowait (), chunk);
363
+ ompBuilder->createDynamicWorkshareLoop (ompLoc, loopInfo, allocaIP,
364
+ schedType, !loop.nowait (), chunk);
343
365
}
344
366
345
- // Continue building IR after the loop.
367
+ // Continue building IR after the loop. Note that the LoopInfo returned by
368
+ // `collapseLoops` points inside the outermost loop and is intended for
369
+ // potential further loop transformations. Use the insertion point stored
370
+ // before collapsing loops instead.
346
371
builder.restoreIP (afterIP);
347
372
return success ();
348
373
}
0 commit comments