Skip to content

[Flang][OpenMP] DISTRIBUTE PARALLEL DO lowering #106207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 114 additions & 31 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,11 @@ struct OpWithBodyGenInfo {
return *this;
}

/// Builder-style setter for `genSkeletonOnly` (see the member's comment:
/// when `true`, nested evaluations and further leaf-construct dispatch are
/// skipped). Returns *this to allow chaining with the other setters.
OpWithBodyGenInfo &setGenSkeletonOnly(bool value) {
genSkeletonOnly = value;
return *this;
}

/// [inout] converter to use for the clauses.
lower::AbstractConverter &converter;
/// [in] Symbol table
Expand All @@ -538,6 +543,9 @@ struct OpWithBodyGenInfo {
/// [in] if provided, emits the op's region entry. Otherwise, an empty block
/// is created in the region.
GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
/// [in] if set to `true`, skip generating nested evaluations and dispatching
/// any further leaf constructs.
bool genSkeletonOnly = false;
};

/// Create the body (block) for an OpenMP Operation.
Expand Down Expand Up @@ -600,20 +608,22 @@ static void createBodyOfOp(mlir::Operation &op, const OpWithBodyGenInfo &info,
}
}

if (ConstructQueue::const_iterator next = std::next(item);
next != queue.end()) {
genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
info.loc, queue, next);
} else {
// genFIR(Evaluation&) tries to patch up unterminated blocks, causing
// a lot of complications for our approach if the terminator generation
// is delayed past this point. Insert a temporary terminator here, then
// delete it.
firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
firOpBuilder.setInsertionPointAfter(marker);
genNestedEvaluations(info.converter, info.eval);
temp->erase();
if (!info.genSkeletonOnly) {
if (ConstructQueue::const_iterator next = std::next(item);
next != queue.end()) {
genOMPDispatch(info.converter, info.symTable, info.semaCtx, info.eval,
info.loc, queue, next);
} else {
// genFIR(Evaluation&) tries to patch up unterminated blocks, causing
// a lot of complications for our approach if the terminator generation
// is delayed past this point. Insert a temporary terminator here, then
// delete it.
firOpBuilder.setInsertionPointToEnd(&op.getRegion(0).back());
auto *temp = lower::genOpenMPTerminator(firOpBuilder, &op, info.loc);
firOpBuilder.setInsertionPointAfter(marker);
genNestedEvaluations(info.converter, info.eval);
temp->erase();
}
}

// Get or create a unique exiting block from the given region, or
Expand Down Expand Up @@ -1445,7 +1455,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
const ConstructQueue &queue, ConstructQueue::const_iterator item,
mlir::omp::ParallelOperands &clauseOps,
llvm::ArrayRef<const semantics::Symbol *> reductionSyms,
llvm::ArrayRef<mlir::Type> reductionTypes) {
llvm::ArrayRef<mlir::Type> reductionTypes,
DataSharingProcessor *dsp, bool isComposite = false) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

auto reductionCallback = [&](mlir::Operation *op) {
Expand All @@ -1457,17 +1468,17 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_parallel)
.setClauses(&item->clauses)
.setGenRegionEntryCb(reductionCallback);

if (!enableDelayedPrivatization)
return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item,
clauseOps);

DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
lower::omp::isLastItemInQueue(item, queue),
/*useDelayedPrivatization=*/true, &symTable);
dsp.processStep1(&clauseOps);
.setGenRegionEntryCb(reductionCallback)
.setGenSkeletonOnly(isComposite);

if (!enableDelayedPrivatization) {
auto parallelOp =
genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
parallelOp.setComposite(isComposite);
return parallelOp;
}

assert(dsp && "expected valid DataSharingProcessor");
auto genRegionEntryCB = [&](mlir::Operation *op) {
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);

Expand All @@ -1491,8 +1502,8 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
allRegionArgLocs);

llvm::SmallVector<const semantics::Symbol *> allSymbols(reductionSyms);
allSymbols.append(dsp.getDelayedPrivSymbols().begin(),
dsp.getDelayedPrivSymbols().end());
allSymbols.append(dsp->getDelayedPrivSymbols().begin(),
dsp->getDelayedPrivSymbols().end());

unsigned argIdx = 0;
for (const semantics::Symbol *arg : allSymbols) {
Expand All @@ -1519,8 +1530,11 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return allSymbols;
};

genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(&dsp);
return genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
genInfo.setGenRegionEntryCb(genRegionEntryCB).setDataSharingProcessor(dsp);
auto parallelOp =
genOpWithBody<mlir::omp::ParallelOp>(genInfo, queue, item, clauseOps);
parallelOp.setComposite(isComposite);
return parallelOp;
}

/// This breaks the normal prototype of the gen*Op functions: adding the
Expand Down Expand Up @@ -2005,8 +2019,16 @@ static void genStandaloneParallel(lower::AbstractConverter &converter,
genParallelClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
reductionTypes, reductionSyms);

std::optional<DataSharingProcessor> dsp;
if (enableDelayedPrivatization) {
dsp.emplace(converter, semaCtx, item->clauses, eval,
lower::omp::isLastItemInQueue(item, queue),
/*useDelayedPrivatization=*/true, &symTable);
dsp->processStep1(&clauseOps);
}
genParallelOp(converter, symTable, semaCtx, eval, loc, queue, item, clauseOps,
reductionSyms, reductionTypes);
reductionSyms, reductionTypes,
enableDelayedPrivatization ? &dsp.value() : nullptr);
}

static void genStandaloneSimd(lower::AbstractConverter &converter,
Expand Down Expand Up @@ -2058,8 +2080,69 @@ static void genCompositeDistributeParallelDo(
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
mlir::Location loc, const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
lower::StatementContext stmtCtx;

assert(std::distance(item, queue.end()) == 3 && "Invalid leaf constructs");
TODO(loc, "Composite DISTRIBUTE PARALLEL DO");
ConstructQueue::const_iterator distributeItem = item;
ConstructQueue::const_iterator parallelItem = std::next(distributeItem);
ConstructQueue::const_iterator doItem = std::next(parallelItem);

// Create parent omp.parallel first.
mlir::omp::ParallelOperands parallelClauseOps;
llvm::SmallVector<const semantics::Symbol *> parallelReductionSyms;
llvm::SmallVector<mlir::Type> parallelReductionTypes;
genParallelClauses(converter, semaCtx, stmtCtx, parallelItem->clauses, loc,
parallelClauseOps, parallelReductionTypes,
parallelReductionSyms);

DataSharingProcessor dsp(converter, semaCtx, doItem->clauses, eval,
/*shouldCollectPreDeterminedSymbols=*/true,
/*useDelayedPrivatization=*/true, &symTable);
dsp.processStep1(&parallelClauseOps);

genParallelOp(converter, symTable, semaCtx, eval, loc, queue, parallelItem,
parallelClauseOps, parallelReductionSyms,
parallelReductionTypes, &dsp, /*isComposite=*/true);

// Clause processing.
mlir::omp::DistributeOperands distributeClauseOps;
genDistributeClauses(converter, semaCtx, stmtCtx, distributeItem->clauses,
loc, distributeClauseOps);

mlir::omp::WsloopOperands wsloopClauseOps;
llvm::SmallVector<const semantics::Symbol *> wsloopReductionSyms;
llvm::SmallVector<mlir::Type> wsloopReductionTypes;
genWsloopClauses(converter, semaCtx, stmtCtx, doItem->clauses, loc,
wsloopClauseOps, wsloopReductionTypes, wsloopReductionSyms);

mlir::omp::LoopNestOperands loopNestClauseOps;
llvm::SmallVector<const semantics::Symbol *> iv;
genLoopNestClauses(converter, semaCtx, eval, doItem->clauses, loc,
loopNestClauseOps, iv);

// Operation creation.
// TODO: Populate entry block arguments with private variables.
auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
distributeOp.setComposite(/*val=*/true);

// TODO: Add private variables to entry block arguments.
auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
converter, loc, wsloopClauseOps, wsloopReductionTypes);
wsloopOp.setComposite(/*val=*/true);

// Construct wrapper entry block list and associated symbols. It is important
// that the symbol order and the block argument order match, so that the
// symbol-value bindings created are correct.
auto &wrapperSyms = wsloopReductionSyms;

auto wrapperArgs = llvm::to_vector(
llvm::concat<mlir::BlockArgument>(distributeOp.getRegion().getArguments(),
wsloopOp.getRegion().getArguments()));

genLoopNestOp(converter, symTable, semaCtx, eval, loc, queue, doItem,
loopNestClauseOps, iv, wrapperSyms, wrapperArgs,
llvm::omp::Directive::OMPD_distribute_parallel_do, dsp);
}

static void genCompositeDistributeParallelDoSimd(
Expand Down
79 changes: 79 additions & 0 deletions flang/test/Lower/OpenMP/distribute-parallel-do.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
! This test checks lowering of OpenMP DISTRIBUTE PARALLEL DO composite
! constructs.

! RUN: bbc -fopenmp -emit-hlfir %s -o - | FileCheck %s
! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - | FileCheck %s

! CHECK-LABEL: func.func @_QPdistribute_parallel_do_num_threads(
subroutine distribute_parallel_do_num_threads()
! The composite construct must be lowered inside an enclosing TEAMS region.
!$omp teams

! Verify that NUM_THREADS lands on the omp.parallel wrapper, and that the
! composite is lowered as omp.parallel > omp.distribute > omp.wsloop
! > omp.loop_nest, with the two loop wrappers directly nested.
! CHECK: omp.parallel num_threads({{.*}}) private({{.*}}) {
! CHECK: omp.distribute {
! CHECK-NEXT: omp.wsloop {
! CHECK-NEXT: omp.loop_nest
!$omp distribute parallel do num_threads(10)
do index_ = 1, 10
end do
!$omp end distribute parallel do

!$omp end teams
end subroutine distribute_parallel_do_num_threads

! CHECK-LABEL: func.func @_QPdistribute_parallel_do_dist_schedule(
subroutine distribute_parallel_do_dist_schedule()
! The composite construct must be lowered inside an enclosing TEAMS region.
!$omp teams

! Verify that DIST_SCHEDULE(STATIC, 4) is lowered onto the omp.distribute
! wrapper (static kind plus chunk size), not onto omp.parallel or omp.wsloop.
! CHECK: omp.parallel private({{.*}}) {
! CHECK: omp.distribute dist_schedule_static dist_schedule_chunk_size({{.*}}) {
! CHECK-NEXT: omp.wsloop {
! CHECK-NEXT: omp.loop_nest
!$omp distribute parallel do dist_schedule(static, 4)
do index_ = 1, 10
end do
!$omp end distribute parallel do

!$omp end teams
end subroutine distribute_parallel_do_dist_schedule

! CHECK-LABEL: func.func @_QPdistribute_parallel_do_ordered(
subroutine distribute_parallel_do_ordered()
! The composite construct must be lowered inside an enclosing TEAMS region.
!$omp teams

! Verify that the ORDERED(1) clause is attached to the omp.wsloop wrapper
! (it belongs to the worksharing-loop part of the composite construct).
! CHECK: omp.parallel private({{.*}}) {
! CHECK: omp.distribute {
! CHECK-NEXT: omp.wsloop ordered(1) {
! CHECK-NEXT: omp.loop_nest
!$omp distribute parallel do ordered(1)
do index_ = 1, 10
end do
!$omp end distribute parallel do

!$omp end teams
end subroutine distribute_parallel_do_ordered

! CHECK-LABEL: func.func @_QPdistribute_parallel_do_private(
subroutine distribute_parallel_do_private()
! Capture the host declarations of the loop index and of `x` so the private
! clause arguments below can be matched against them.
! CHECK: %[[INDEX_ALLOC:.*]] = fir.alloca i32
! CHECK: %[[INDEX:.*]]:2 = hlfir.declare %[[INDEX_ALLOC]]
! CHECK: %[[X_ALLOC:.*]] = fir.alloca i64
! CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_ALLOC]]
integer(8) :: x

! CHECK: omp.teams {
!$omp teams

! Verify delayed privatization: both the explicitly PRIVATE `x` and the
! implicitly private loop index become entry-block arguments of
! omp.parallel's private(...) clause, and each is re-declared inside the
! region before the loop wrappers.
! CHECK: omp.parallel private(@{{.*}} %[[X]]#0 -> %[[X_ARG:.*]] : !fir.ref<i64>,
! CHECK-SAME: @{{.*}} %[[INDEX]]#0 -> %[[INDEX_ARG:.*]] : !fir.ref<i32>) {
! CHECK: %[[X_PRIV:.*]]:2 = hlfir.declare %[[X_ARG]]
! CHECK: %[[INDEX_PRIV:.*]]:2 = hlfir.declare %[[INDEX_ARG]]
! CHECK: omp.distribute {
! CHECK-NEXT: omp.wsloop {
! CHECK-NEXT: omp.loop_nest
!$omp distribute parallel do private(x)
do index_ = 1, 10
end do
!$omp end distribute parallel do

!$omp end teams
end subroutine distribute_parallel_do_private
Loading
Loading