Skip to content

[flang][OpenMP] Rewrite omp.loop to semantically equivalent ops #115443

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions flang/include/flang/Common/OpenMP-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
//===-- include/flang/Common/OpenMP-utils.h --------------------*- C++ -*-====//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef FORTRAN_COMMON_OPENMP_UTILS_H_
#define FORTRAN_COMMON_OPENMP_UTILS_H_

#include "flang/Semantics/symbol.h"

#include "mlir/IR/Builders.h"
#include "mlir/IR/Value.h"

#include "llvm/ADT/ArrayRef.h"

namespace Fortran::common::openmp {
/// Structure holding the information needed to create and bind entry block
/// arguments associated to a single clause.
struct EntryBlockArgsEntry {
llvm::ArrayRef<const Fortran::semantics::Symbol *> syms;
llvm::ArrayRef<mlir::Value> vars;

bool isValid() const {
// This check allows specifying a smaller number of symbols than values
// because in some case cases a single symbol generates multiple block
// arguments.
return syms.size() <= vars.size();
}
};

/// Structure holding the information needed to create and bind entry block
/// arguments associated to all clauses that can define them.
struct EntryBlockArgs {
EntryBlockArgsEntry inReduction;
EntryBlockArgsEntry map;
EntryBlockArgsEntry priv;
EntryBlockArgsEntry reduction;
EntryBlockArgsEntry taskReduction;
EntryBlockArgsEntry useDeviceAddr;
EntryBlockArgsEntry useDevicePtr;

bool isValid() const {
return inReduction.isValid() && map.isValid() && priv.isValid() &&
reduction.isValid() && taskReduction.isValid() &&
useDeviceAddr.isValid() && useDevicePtr.isValid();
}

auto getSyms() const {
return llvm::concat<const Fortran::semantics::Symbol *const>(
inReduction.syms, map.syms, priv.syms, reduction.syms,
taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
}

auto getVars() const {
return llvm::concat<const mlir::Value>(inReduction.vars, map.vars,
priv.vars, reduction.vars, taskReduction.vars, useDeviceAddr.vars,
useDevicePtr.vars);
}
};

mlir::Block *genEntryBlock(
mlir::OpBuilder &builder, const EntryBlockArgs &args, mlir::Region &region);
} // namespace Fortran::common::openmp

#endif // FORTRAN_COMMON_OPENMP_UTILS_H_
20 changes: 20 additions & 0 deletions flang/include/flang/Optimizer/OpenMP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,29 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
];
}


// Needs to be scheduled on Module as we create functions in it
def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
let summary = "Lower workshare construct";
}

def GenericLoopConversionPass
: Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
let summary = "Converts OpenMP generic `omp.loop` to semantically "
"equivalent OpenMP ops";
let description = [{
Rewrites `omp.loop` ops to their semantically equivalent nest of ops. The
rewrite depends on the nesting/combination structure of the `loop` op
within its surrounding context as well as its `bind` clause value.

We assume for now that all `omp.loop` ops will occur inside `FuncOp`'s. This
will most likely remain the case in the future; even if, for example, we
need a loop in copying data for a `firstprivate` variable, this loop will
be nested in a constructor, an overloaded operator, or a runtime function.
}];
let dependentDialects = [
"mlir::omp::OpenMPDialect"
];
}

#endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
4 changes: 4 additions & 0 deletions flang/lib/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ add_flang_library(FortranCommon
default-kinds.cpp
idioms.cpp
LangOptions.cpp
OpenMP-utils.cpp
Version.cpp
${version_inc}

LINK_COMPONENTS
Support

LINK_LIBS
MLIRIR
)
47 changes: 47 additions & 0 deletions flang/lib/Common/OpenMP-utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//===-- include/flang/Common/OpenMP-utils.cpp ------------------*- C++ -*-====//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "flang/Common/OpenMP-utils.h"

#include "mlir/IR/OpDefinition.h"

namespace Fortran::common::openmp {
mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args,
mlir::Region &region) {
assert(args.isValid() && "invalid args");
assert(region.empty() && "non-empty region");

llvm::SmallVector<mlir::Type> types;
llvm::SmallVector<mlir::Location> locs;
unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
args.priv.vars.size() + args.reduction.vars.size() +
args.taskReduction.vars.size() + args.useDeviceAddr.vars.size() +
args.useDevicePtr.vars.size();
types.reserve(numVars);
locs.reserve(numVars);

auto extractTypeLoc = [&types, &locs](llvm::ArrayRef<mlir::Value> vals) {
llvm::transform(vals, std::back_inserter(types),
[](mlir::Value v) { return v.getType(); });
llvm::transform(vals, std::back_inserter(locs),
[](mlir::Value v) { return v.getLoc(); });
};

// Populate block arguments in clause name alphabetical order to match
// expected order by the BlockArgOpenMPOpInterface.
extractTypeLoc(args.inReduction.vars);
extractTypeLoc(args.map.vars);
extractTypeLoc(args.priv.vars);
extractTypeLoc(args.reduction.vars);
extractTypeLoc(args.taskReduction.vars);
extractTypeLoc(args.useDeviceAddr.vars);
extractTypeLoc(args.useDevicePtr.vars);

return builder.createBlock(&region, {}, types, locs);
}
} // namespace Fortran::common::openmp
106 changes: 9 additions & 97 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/OpenMP-utils.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
Expand All @@ -41,57 +42,12 @@
#include "llvm/Frontend/OpenMP/OMPConstants.h"

using namespace Fortran::lower::omp;
using namespace Fortran::common::openmp;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: For consistency, it looks like Fortran::common::omp would be a better name for the namespace in the end.


//===----------------------------------------------------------------------===//
// Code generation helper functions
//===----------------------------------------------------------------------===//

namespace {
/// Structure holding the information needed to create and bind entry block
/// arguments associated to a single clause.
struct EntryBlockArgsEntry {
llvm::ArrayRef<const semantics::Symbol *> syms;
llvm::ArrayRef<mlir::Value> vars;

bool isValid() const {
// This check allows specifying a smaller number of symbols than values
// because in some case cases a single symbol generates multiple block
// arguments.
return syms.size() <= vars.size();
}
};

/// Structure holding the information needed to create and bind entry block
/// arguments associated to all clauses that can define them.
struct EntryBlockArgs {
EntryBlockArgsEntry inReduction;
EntryBlockArgsEntry map;
EntryBlockArgsEntry priv;
EntryBlockArgsEntry reduction;
EntryBlockArgsEntry taskReduction;
EntryBlockArgsEntry useDeviceAddr;
EntryBlockArgsEntry useDevicePtr;

bool isValid() const {
return inReduction.isValid() && map.isValid() && priv.isValid() &&
reduction.isValid() && taskReduction.isValid() &&
useDeviceAddr.isValid() && useDevicePtr.isValid();
}

auto getSyms() const {
return llvm::concat<const semantics::Symbol *const>(
inReduction.syms, map.syms, priv.syms, reduction.syms,
taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
}

auto getVars() const {
return llvm::concat<const mlir::Value>(
inReduction.vars, map.vars, priv.vars, reduction.vars,
taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
}
};
} // namespace

static void genOMPDispatch(lower::AbstractConverter &converter,
lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
Expand Down Expand Up @@ -623,50 +579,6 @@ static void genLoopVars(
firOpBuilder.setInsertionPointAfter(storeOp);
}

/// Create an entry block for the given region, including the clause-defined
/// arguments specified.
///
/// \param [in] converter - PFT to MLIR conversion interface.
/// \param [in] args - entry block arguments information for the given
/// operation.
/// \param [in] region - Empty region in which to create the entry block.
static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
const EntryBlockArgs &args,
mlir::Region &region) {
assert(args.isValid() && "invalid args");
assert(region.empty() && "non-empty region");
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

llvm::SmallVector<mlir::Type> types;
llvm::SmallVector<mlir::Location> locs;
unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
args.priv.vars.size() + args.reduction.vars.size() +
args.taskReduction.vars.size() +
args.useDeviceAddr.vars.size() +
args.useDevicePtr.vars.size();
types.reserve(numVars);
locs.reserve(numVars);

auto extractTypeLoc = [&types, &locs](llvm::ArrayRef<mlir::Value> vals) {
llvm::transform(vals, std::back_inserter(types),
[](mlir::Value v) { return v.getType(); });
llvm::transform(vals, std::back_inserter(locs),
[](mlir::Value v) { return v.getLoc(); });
};

// Populate block arguments in clause name alphabetical order to match
// expected order by the BlockArgOpenMPOpInterface.
extractTypeLoc(args.inReduction.vars);
extractTypeLoc(args.map.vars);
extractTypeLoc(args.priv.vars);
extractTypeLoc(args.reduction.vars);
extractTypeLoc(args.taskReduction.vars);
extractTypeLoc(args.useDeviceAddr.vars);
extractTypeLoc(args.useDevicePtr.vars);

return firOpBuilder.createBlock(&region, {}, types, locs);
}

static void
markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
mlir::omp::DeclareTargetCaptureClause captureClause,
Expand Down Expand Up @@ -919,7 +831,7 @@ static void genBodyOfTargetDataOp(
ConstructQueue::const_iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

genEntryBlock(converter, args, dataOp.getRegion());
genEntryBlock(firOpBuilder, args, dataOp.getRegion());
bindEntryBlockArgs(converter, dataOp, args);

// Insert dummy instruction to remember the insertion position. The
Expand Down Expand Up @@ -996,7 +908,7 @@ static void genBodyOfTargetOp(
auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);

mlir::Region &region = targetOp.getRegion();
mlir::Block *entryBlock = genEntryBlock(converter, args, region);
mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region);
bindEntryBlockArgs(converter, targetOp, args);

// Check if cloning the bounds introduced any dependency on the outer region.
Expand Down Expand Up @@ -1122,7 +1034,7 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter,
auto op = firOpBuilder.create<OpTy>(loc, clauseOps);

// Create entry block with arguments.
genEntryBlock(converter, args, op.getRegion());
genEntryBlock(firOpBuilder, args, op.getRegion());

return op;
}
Expand Down Expand Up @@ -1588,7 +1500,7 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
const EntryBlockArgs &args, DataSharingProcessor *dsp,
bool isComposite = false) {
auto genRegionEntryCB = [&](mlir::Operation *op) {
genEntryBlock(converter, args, op->getRegion(0));
genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
bindEntryBlockArgs(
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
return llvm::to_vector(args.getSyms());
Expand Down Expand Up @@ -1661,12 +1573,12 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
args.reduction.syms = reductionSyms;
args.reduction.vars = clauseOps.reductionVars;

genEntryBlock(converter, args, sectionsOp.getRegion());
genEntryBlock(builder, args, sectionsOp.getRegion());
mlir::Operation *terminator =
lower::genOpenMPTerminator(builder, sectionsOp, loc);

auto genRegionEntryCB = [&](mlir::Operation *op) {
genEntryBlock(converter, args, op->getRegion(0));
genEntryBlock(builder, args, op->getRegion(0));
bindEntryBlockArgs(
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
return llvm::to_vector(args.getSyms());
Expand Down Expand Up @@ -1989,7 +1901,7 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
taskArgs.priv.vars = clauseOps.privateVars;

auto genRegionEntryCB = [&](mlir::Operation *op) {
genEntryBlock(converter, taskArgs, op->getRegion(0));
genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0));
bindEntryBlockArgs(converter,
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
taskArgs);
Expand Down
2 changes: 2 additions & 0 deletions flang/lib/Optimizer/OpenMP/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

add_flang_library(FlangOpenMPTransforms
FunctionFiltering.cpp
GenericLoopConversion.cpp
MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
Expand All @@ -25,4 +26,5 @@ add_flang_library(FlangOpenMPTransforms
HLFIRDialect
MLIRIR
MLIRPass
MLIRTransformUtils
)
Loading