-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[Flang][OpenMP] NFC: Refactor reduction code #79876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Flang][OpenMP] NFC: Refactor reduction code #79876
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-flang-openmp Author: Kiran Chandramohan (kiranchandramohan) ChangesIntroduces a new enumeration to list all Fortran reduction identifiers. Moves the combiner code-generation into a separate function for possible reuse in array context in future. Patch is 23.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79876.diff 1 Files Affected:
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 7dd25f75d9eb76f..52d222f3d601f6a 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -726,21 +726,59 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
class ReductionProcessor {
public:
- enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
- static IntrinsicProc
+ // TODO: Move this enumeration to the OpenMP dialect
+ enum ReductionIdentifier {
+ ID,
+ USER_DEF_OP,
+ ADD,
+ SUBTRACT,
+ MULTIPLY,
+ AND,
+ OR,
+ EQV,
+ NEQV,
+ MAX,
+ MIN,
+ IAND,
+ IOR,
+ IEOR
+ };
+ static ReductionIdentifier
getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
- auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+ auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
getRealName(pd).ToString())
- .Case("max", IntrinsicProc::MAX)
- .Case("min", IntrinsicProc::MIN)
- .Case("iand", IntrinsicProc::IAND)
- .Case("ior", IntrinsicProc::IOR)
- .Case("ieor", IntrinsicProc::IEOR)
+ .Case("max", ReductionIdentifier::MAX)
+ .Case("min", ReductionIdentifier::MIN)
+ .Case("iand", ReductionIdentifier::IAND)
+ .Case("ior", ReductionIdentifier::IOR)
+ .Case("ieor", ReductionIdentifier::IEOR)
.Default(std::nullopt);
assert(redType && "Invalid Reduction");
return *redType;
}
+ static ReductionIdentifier getReductionType(
+ Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+ switch (intrinsicOp) {
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ return ReductionIdentifier::ADD;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+ return ReductionIdentifier::SUBTRACT;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ return ReductionIdentifier::MULTIPLY;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+ return ReductionIdentifier::AND;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ return ReductionIdentifier::EQV;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+ return ReductionIdentifier::OR;
+ case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return ReductionIdentifier::NEQV;
+ default:
+ llvm_unreachable("unexpected intrinsic operator in reduction");
+ }
+ }
+
static bool supportedIntrinsicProcReduction(
const Fortran::parser::ProcedureDesignator &pd) {
const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
@@ -748,13 +786,13 @@ class ReductionProcessor {
if (!name->symbol->GetUltimate().attrs().test(
Fortran::semantics::Attr::INTRINSIC))
return false;
- auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+ auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
getRealName(name).ToString())
- .Case("max", IntrinsicProc::MAX)
- .Case("min", IntrinsicProc::MIN)
- .Case("iand", IntrinsicProc::IAND)
- .Case("ior", IntrinsicProc::IOR)
- .Case("ieor", IntrinsicProc::IEOR)
+ .Case("max", ReductionIdentifier::MAX)
+ .Case("min", ReductionIdentifier::MIN)
+ .Case("iand", ReductionIdentifier::IAND)
+ .Case("ior", ReductionIdentifier::IOR)
+ .Case("ieor", ReductionIdentifier::IEOR)
.Default(std::nullopt);
if (redType)
return true;
@@ -812,32 +850,30 @@ class ReductionProcessor {
/// reductionOpName. For example:
/// 0 + x = x,
/// 1 * x = x
- static int getOperationIdentity(
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Location loc) {
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ static int getOperationIdentity(ReductionIdentifier redId,
+ mlir::Location loc) {
+ switch (redId) {
+ case ReductionIdentifier::ADD:
+ case ReductionIdentifier::OR:
+ case ReductionIdentifier::NEQV:
return 0;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+ case ReductionIdentifier::MULTIPLY:
+ case ReductionIdentifier::AND:
+ case ReductionIdentifier::EQV:
return 1;
default:
TODO(loc, "Reduction of some intrinsic operators is not supported");
}
}
- static mlir::Value getIntrinsicProcInitValue(
- mlir::Location loc, mlir::Type type,
- const Fortran::parser::ProcedureDesignator &procDesignator,
- fir::FirOpBuilder &builder) {
+ static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
+ ReductionIdentifier redId,
+ fir::FirOpBuilder &builder) {
assert((fir::isa_integer(type) || fir::isa_real(type) ||
type.isa<fir::LogicalType>()) &&
"only integer, logical and real types are currently supported");
- switch (getReductionType(procDesignator)) {
- case IntrinsicProc::MAX: {
+ switch (redId) {
+ case ReductionIdentifier::MAX: {
if (auto ty = type.dyn_cast<mlir::FloatType>()) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
@@ -847,7 +883,7 @@ class ReductionProcessor {
int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, minInt);
}
- case IntrinsicProc::MIN: {
+ case ReductionIdentifier::MIN: {
if (auto ty = type.dyn_cast<mlir::FloatType>()) {
const llvm::fltSemantics &sem = ty.getFloatSemantics();
return builder.createRealConstant(
@@ -857,46 +893,50 @@ class ReductionProcessor {
int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, maxInt);
}
- case IntrinsicProc::IOR: {
+ case ReductionIdentifier::IOR: {
unsigned bits = type.getIntOrFloatBitWidth();
int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, zeroInt);
}
- case IntrinsicProc::IEOR: {
+ case ReductionIdentifier::IEOR: {
unsigned bits = type.getIntOrFloatBitWidth();
int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, zeroInt);
}
- case IntrinsicProc::IAND: {
+ case ReductionIdentifier::IAND: {
unsigned bits = type.getIntOrFloatBitWidth();
int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
return builder.createIntegerConstant(loc, type, allOnInt);
}
- }
- llvm_unreachable("Unknown Reduction Intrinsic");
- }
+ case ReductionIdentifier::ADD:
+ case ReductionIdentifier::MULTIPLY:
+ case ReductionIdentifier::AND:
+ case ReductionIdentifier::OR:
+ case ReductionIdentifier::EQV:
+ case ReductionIdentifier::NEQV:
+ if (type.isa<mlir::FloatType>())
+ return builder.create<mlir::arith::ConstantOp>(
+ loc, type,
+ builder.getFloatAttr(type,
+ (double)getOperationIdentity(redId, loc)));
+
+ if (type.isa<fir::LogicalType>()) {
+ mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
+ loc, builder.getI1Type(),
+ builder.getIntegerAttr(builder.getI1Type(),
+ getOperationIdentity(redId, loc)));
+ return builder.createConvert(loc, type, intConst);
+ }
- static mlir::Value getIntrinsicOpInitValue(
- mlir::Location loc, mlir::Type type,
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- fir::FirOpBuilder &builder) {
- if (type.isa<mlir::FloatType>())
return builder.create<mlir::arith::ConstantOp>(
loc, type,
- builder.getFloatAttr(type,
- (double)getOperationIdentity(intrinsicOp, loc)));
-
- if (type.isa<fir::LogicalType>()) {
- mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
- loc, builder.getI1Type(),
- builder.getIntegerAttr(builder.getI1Type(),
- getOperationIdentity(intrinsicOp, loc)));
- return builder.createConvert(loc, type, intConst);
+ builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
+ case ReductionIdentifier::ID:
+ case ReductionIdentifier::USER_DEF_OP:
+ case ReductionIdentifier::SUBTRACT:
+ TODO(loc, "Reduction of some identifier types is not supported");
}
-
- return builder.create<mlir::arith::ConstantOp>(
- loc, type,
- builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+ llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
}
template <typename FloatOp, typename IntegerOp>
@@ -910,118 +950,46 @@ class ReductionProcessor {
return builder.create<FloatOp>(loc, op1, op2);
}
- /// Creates an OpenMP reduction declaration and inserts it into the provided
- /// symbol table. The declaration has a constant initializer with the neutral
- /// value `initValue`, and the reduction combiner carried over from `reduce`.
- /// TODO: Generalize this for non-integer types, add atomic region.
- static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- const Fortran::parser::ProcedureDesignator &procDesignator,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
-
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
-
- mlir::OpBuilder modBuilder(module.getBodyRegion());
-
- decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
- loc, reductionOpName, type);
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- mlir::Value init =
- getIntrinsicProcInitValue(loc, type, procDesignator, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
-
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
-
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
+ static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
+ mlir::Location loc,
+ ReductionIdentifier redId,
+ mlir::Type type, mlir::Value op1,
+ mlir::Value op2) {
mlir::Value reductionOp;
- switch (getReductionType(procDesignator)) {
- case IntrinsicProc::MAX:
+ switch (redId) {
+ case ReductionIdentifier::MAX:
reductionOp =
getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
builder, type, loc, op1, op2);
break;
- case IntrinsicProc::MIN:
+ case ReductionIdentifier::MIN:
reductionOp =
getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
builder, type, loc, op1, op2);
break;
- case IntrinsicProc::IOR:
+ case ReductionIdentifier::IOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
break;
- case IntrinsicProc::IEOR:
+ case ReductionIdentifier::IEOR:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
break;
- case IntrinsicProc::IAND:
+ case ReductionIdentifier::IAND:
assert((type.isIntOrIndex()) && "only integer is expected");
reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
break;
- }
-
- builder.create<mlir::omp::YieldOp>(loc, reductionOp);
- return decl;
- }
-
- /// Creates an OpenMP reduction declaration and inserts it into the provided
- /// symbol table. The declaration has a constant initializer with the neutral
- /// value `initValue`, and the reduction combiner carried over from `reduce`.
- /// TODO: Generalize this for non-integer types, add atomic region.
- static mlir::omp::ReductionDeclareOp createReductionDecl(
- fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
- Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
- mlir::Type type, mlir::Location loc) {
- mlir::OpBuilder::InsertionGuard guard(builder);
- mlir::ModuleOp module = builder.getModule();
-
- auto decl =
- module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
- if (decl)
- return decl;
-
- mlir::OpBuilder modBuilder(module.getBodyRegion());
-
- decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
- loc, reductionOpName, type);
- builder.createBlock(&decl.getInitializerRegion(),
- decl.getInitializerRegion().end(), {type}, {loc});
- builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
- mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
- builder.create<mlir::omp::YieldOp>(loc, init);
-
- builder.createBlock(&decl.getReductionRegion(),
- decl.getReductionRegion().end(), {type, type},
- {loc, loc});
-
- builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
- mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
- mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
- mlir::Value reductionOp;
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+ case ReductionIdentifier::ADD:
reductionOp =
getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
builder, type, loc, op1, op2);
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+ case ReductionIdentifier::MULTIPLY:
reductionOp =
getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
builder, type, loc, op1, op2);
break;
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
+ case ReductionIdentifier::AND: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
@@ -1031,7 +999,7 @@ class ReductionProcessor {
reductionOp = builder.createConvert(loc, type, andiOp);
break;
}
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
+ case ReductionIdentifier::OR: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
@@ -1040,7 +1008,7 @@ class ReductionProcessor {
reductionOp = builder.createConvert(loc, type, oriOp);
break;
}
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
+ case ReductionIdentifier::EQV: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
@@ -1050,7 +1018,7 @@ class ReductionProcessor {
reductionOp = builder.createConvert(loc, type, cmpiOp);
break;
}
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
+ case ReductionIdentifier::NEQV: {
mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
@@ -1064,7 +1032,46 @@ class ReductionProcessor {
TODO(loc, "Reduction of some intrinsic operators is not supported");
}
+ return reductionOp;
+ }
+
+ /// Creates an OpenMP reduction declaration and inserts it into the provided
+ /// symbol table. The declaration has a constant initializer with the neutral
+ /// value `initValue`, and the reduction combiner carried over from `reduce`.
+ /// TODO: Generalize this for non-integer types, add atomic region.
+ static mlir::omp::ReductionDeclareOp createReductionDecl(
+ fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+ const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::ModuleOp module = builder.getModule();
+
+ auto decl =
+ module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+ if (decl)
+ return decl;
+
+ mlir::OpBuilder modBuilder(module.getBodyRegion());
+
+ decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+ loc, reductionOpName, type);
+ builder.createBlock(&decl.getInitializerRegion(),
+ decl.getInitializerRegion().end(), {type}, {loc});
+ builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+ mlir::Value init = getReductionInitValue(loc, type, redId, builder);
+ builder.create<mlir::omp::YieldOp>(loc, init);
+
+ builder.createBlock(&decl.getReductionRegion(),
+ decl.getReductionRegion().end(), {type, type},
+ {loc, loc});
+
+ builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+ mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+ mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+ mlir::Value reductionOp =
+ createScalarCombiner(builder, loc, redId, type, op1, op2);
builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+
return decl;
}
@@ -1087,15 +1094,15 @@ class ReductionProcessor {
const auto &intrinsicOp{
std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
redDefinedOp->u)};
- switch (intrinsicOp) {
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
- case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+ ReductionIdentifier redId = getReductionType(intrinsicOp);
+ switch (redId) {
+ case ReductionIdentifier::ADD:
+ case ReductionIdentifier::MULTIPLY:
+ case ReductionIdentifier::AND:
+ case ReductionIdentifier::EQV:
+ case ReductionIdentifier::OR:
+ case ReductionIdentifier::NEQV:
break;
-
default:
TODO(currentLocation,
"Reduction of some intrinsic operators is not supported");
@@ -1115,11 +1122,11 @@ class ReductionProcessor {
decl = createReductionDecl(
firOpBuilder,
getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
- intrinsicOp, redType, currentLocation);
+ redId, redType, currentLocation);
else if (redType.isIntOrIndexOrFloat()) {
decl = createReductionDecl(firOpBuilder,
...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Sorry this fell off my radar for so long. No need to re-approve after fixing the minor comment.
Introduces a new enumeration to list all Fortran reduction identifiers. Moves the combiner code-generation into a separate function for possible reuse in array context in future.
8ad53ab
to
ed67a15
Compare
Introduces a new enumeration to list all Fortran reduction identifiers. Moves the combiner code-generation into a separate function for possible reuse in array context in future.