Skip to content

Commit c4c1030

Browse files
committed
[mlir] support collapsed loops in OpenMP-to-LLVM translation
Reviewed By: Meinersbur Differential Revision: https://reviews.llvm.org/D105706
1 parent 17e9732 commit c4c1030

File tree

2 files changed

+119
-33
lines changed

2 files changed

+119
-33
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 58 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -252,25 +252,12 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
252252
if (loop.lowerBound().empty())
253253
return failure();
254254

255-
if (loop.getNumLoops() != 1)
256-
return opInst.emitOpError("collapsed loops not yet supported");
257-
258255
// Static is the default.
259256
omp::ClauseScheduleKind schedule = omp::ClauseScheduleKind::Static;
260257
if (loop.schedule_val().hasValue())
261258
schedule =
262259
*omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue());
263260

264-
// Find the loop configuration.
265-
llvm::Value *lowerBound = moduleTranslation.lookupValue(loop.lowerBound()[0]);
266-
llvm::Value *upperBound = moduleTranslation.lookupValue(loop.upperBound()[0]);
267-
llvm::Value *step = moduleTranslation.lookupValue(loop.step()[0]);
268-
llvm::Type *ivType = step->getType();
269-
llvm::Value *chunk =
270-
loop.schedule_chunk_var()
271-
? moduleTranslation.lookupValue(loop.schedule_chunk_var())
272-
: llvm::ConstantInt::get(ivType, 1);
273-
274261
// Set up the source location value for OpenMP runtime.
275262
llvm::DISubprogram *subprogram =
276263
builder.GetInsertBlock()->getParent()->getSubprogram();
@@ -279,22 +266,29 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
279266
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
280267
llvm::DebugLoc(diLoc));
281268

282-
// Generator of the canonical loop body. Produces an SESE region of basic
283-
// blocks.
269+
// Generator of the canonical loop body.
284270
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
285271
// relying on captured variables.
272+
SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
273+
SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
286274
LogicalResult bodyGenStatus = success();
287275
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
288-
llvm::IRBuilder<>::InsertPointGuard guard(builder);
289-
290276
// Make sure further conversions know about the induction variable.
291-
moduleTranslation.mapValue(loop.getRegion().front().getArgument(0), iv);
277+
moduleTranslation.mapValue(
278+
loop.getRegion().front().getArgument(loopInfos.size()), iv);
279+
280+
// Capture the body insertion point for use in nested loops. BodyIP of the
281+
// CanonicalLoopInfo always points to the beginning of the entry block of
282+
// the body.
283+
bodyInsertPoints.push_back(ip);
284+
285+
if (loopInfos.size() != loop.getNumLoops() - 1)
286+
return;
292287

288+
// Convert the body of the loop.
293289
llvm::BasicBlock *entryBlock = ip.getBlock();
294290
llvm::BasicBlock *exitBlock =
295291
entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
296-
297-
// Convert the body of the loop.
298292
convertOmpOpRegions(loop.region(), "omp.wsloop.region", *entryBlock,
299293
*exitBlock, builder, moduleTranslation, bodyGenStatus);
300294
};
@@ -303,21 +297,49 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
303297
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
304298
// i.e. it has a positive step, uses signed integer semantics. Reconsider
305299
// this code when WsLoop clearly supports more cases.
300+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
301+
for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
302+
llvm::Value *lowerBound =
303+
moduleTranslation.lookupValue(loop.lowerBound()[i]);
304+
llvm::Value *upperBound =
305+
moduleTranslation.lookupValue(loop.upperBound()[i]);
306+
llvm::Value *step = moduleTranslation.lookupValue(loop.step()[i]);
307+
308+
// Make sure loop trip count are emitted in the preheader of the outermost
309+
// loop at the latest so that they are all available for the new collapsed
310+
// loop will be created below.
311+
llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
312+
llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
313+
if (i != 0) {
314+
loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
315+
llvm::DebugLoc(diLoc));
316+
computeIP = loopInfos.front()->getPreheaderIP();
317+
}
318+
loopInfos.push_back(ompBuilder->createCanonicalLoop(
319+
loc, bodyGen, lowerBound, upperBound, step,
320+
/*IsSigned=*/true, loop.inclusive(), computeIP));
321+
322+
if (failed(bodyGenStatus))
323+
return failure();
324+
}
325+
326+
// Collapse loops. Store the insertion point because LoopInfos may get
327+
// invalidated.
328+
llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
306329
llvm::CanonicalLoopInfo *loopInfo =
307-
moduleTranslation.getOpenMPBuilder()->createCanonicalLoop(
308-
ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
309-
/*InclusiveStop=*/loop.inclusive());
310-
if (failed(bodyGenStatus))
311-
return failure();
330+
ompBuilder->collapseLoops(diLoc, loopInfos, {});
312331

332+
// Find the loop configuration.
333+
llvm::Type *ivType = loopInfo->getIndVar()->getType();
334+
llvm::Value *chunk =
335+
loop.schedule_chunk_var()
336+
? moduleTranslation.lookupValue(loop.schedule_chunk_var())
337+
: llvm::ConstantInt::get(ivType, 1);
313338
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
314339
findAllocaInsertPoint(builder, moduleTranslation);
315-
llvm::OpenMPIRBuilder::InsertPointTy afterIP;
316-
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
317340
if (schedule == omp::ClauseScheduleKind::Static) {
318-
loopInfo = ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
319-
!loop.nowait(), chunk);
320-
afterIP = loopInfo->getAfterIP();
341+
ompBuilder->createStaticWorkshareLoop(ompLoc, loopInfo, allocaIP,
342+
!loop.nowait(), chunk);
321343
} else {
322344
llvm::omp::OMPScheduleType schedType;
323345
switch (schedule) {
@@ -338,11 +360,14 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
338360
break;
339361
}
340362

341-
afterIP = ompBuilder->createDynamicWorkshareLoop(
342-
ompLoc, loopInfo, allocaIP, schedType, !loop.nowait(), chunk);
363+
ompBuilder->createDynamicWorkshareLoop(ompLoc, loopInfo, allocaIP,
364+
schedType, !loop.nowait(), chunk);
343365
}
344366

345-
// Continue building IR after the loop.
367+
// Continue building IR after the loop. Note that the LoopInfo returned by
368+
// `collapseLoops` points inside the outermost loop and is intended for
369+
// potential further loop transformations. Use the insertion point stored
370+
// before collapsing loops instead.
346371
builder.restoreIP(afterIP);
347372
return success();
348373
}

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,8 @@ llvm.func @test_omp_wsloop_guided(%lb : i64, %ub : i64, %step : i64) -> () {
467467
llvm.return
468468
}
469469

470+
// -----
471+
470472
// CHECK-LABEL: @omp_critical
471473
llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
472474
// CHECK: call void @__kmpc_critical_with_hint({{.*}}critical_user_.var{{.*}}, i32 0)
@@ -488,6 +490,65 @@ llvm.func @omp_critical(%x : !llvm.ptr<i32>, %xval : i32) -> () {
488490
omp.terminator
489491
}
490492
// CHECK: call void @__kmpc_end_critical({{.*}}critical_user_mutex.var{{.*}})
493+
llvm.return
494+
}
495+
496+
// -----
491497

498+
// Check that the loop bounds are emitted in the correct location in case of
499+
// collapse. This only checks the overall shape of the IR, detailed checking
500+
// is done by the OpenMPIRBuilder.
501+
502+
// CHECK-LABEL: @collapse_wsloop
503+
// CHECK: i32* noalias %[[TIDADDR:[0-9A-Za-z.]*]]
504+
// CHECK: load i32, i32* %[[TIDADDR]]
505+
// CHECK: store
506+
// CHECK: load
507+
// CHECK: %[[LB0:.*]] = load i32
508+
// CHECK: %[[UB0:.*]] = load i32
509+
// CHECK: %[[STEP0:.*]] = load i32
510+
// CHECK: %[[LB1:.*]] = load i32
511+
// CHECK: %[[UB1:.*]] = load i32
512+
// CHECK: %[[STEP1:.*]] = load i32
513+
// CHECK: %[[LB2:.*]] = load i32
514+
// CHECK: %[[UB2:.*]] = load i32
515+
// CHECK: %[[STEP2:.*]] = load i32
516+
llvm.func @collapse_wsloop(
517+
%0: i32, %1: i32, %2: i32,
518+
%3: i32, %4: i32, %5: i32,
519+
%6: i32, %7: i32, %8: i32,
520+
%20: !llvm.ptr<i32>) {
521+
omp.parallel {
522+
// CHECK: icmp slt i32 %[[LB0]], 0
523+
// CHECK-COUNT-4: select
524+
// CHECK: %[[TRIPCOUNT0:.*]] = select
525+
// CHECK: br label %[[PREHEADER:.*]]
526+
//
527+
// CHECK: [[PREHEADER]]:
528+
// CHECK: icmp slt i32 %[[LB1]], 0
529+
// CHECK-COUNT-4: select
530+
// CHECK: %[[TRIPCOUNT1:.*]] = select
531+
// CHECK: icmp slt i32 %[[LB2]], 0
532+
// CHECK-COUNT-4: select
533+
// CHECK: %[[TRIPCOUNT2:.*]] = select
534+
// CHECK: %[[PROD:.*]] = mul nuw i32 %[[TRIPCOUNT0]], %[[TRIPCOUNT1]]
535+
// CHECK: %[[TOTAL:.*]] = mul nuw i32 %[[PROD]], %[[TRIPCOUNT2]]
536+
// CHECK: br label %[[COLLAPSED_PREHEADER:.*]]
537+
//
538+
// CHECK: [[COLLAPSED_PREHEADER]]:
539+
// CHECK: store i32 0, i32*
540+
// CHECK: %[[TOTAL_SUB_1:.*]] = sub i32 %[[TOTAL]], 1
541+
// CHECK: store i32 %[[TOTAL_SUB_1]], i32*
542+
// CHECK: call void @__kmpc_for_static_init_4u
543+
omp.wsloop (%arg0, %arg1, %arg2) : i32 = (%0, %1, %2) to (%3, %4, %5) step (%6, %7, %8) collapse(3) {
544+
%31 = llvm.load %20 : !llvm.ptr<i32>
545+
%32 = llvm.add %31, %arg0 : i32
546+
%33 = llvm.add %32, %arg1 : i32
547+
%34 = llvm.add %33, %arg2 : i32
548+
llvm.store %34, %20 : !llvm.ptr<i32>
549+
omp.yield
550+
}
551+
omp.terminator
552+
}
492553
llvm.return
493554
}

0 commit comments

Comments
 (0)