Skip to content

[Flang][OpenMP] NFC: Refactor reduction code #79876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

kiranchandramohan
Copy link
Contributor

Introduces a new enumeration to list all Fortran reduction identifiers. Moves the combiner code-generation into a separate function for possible reuse in array context in future.

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp labels Jan 29, 2024
@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2024

@llvm/pr-subscribers-flang-fir-hlfir

@llvm/pr-subscribers-flang-openmp

Author: Kiran Chandramohan (kiranchandramohan)

Changes

Introduces a new enumeration to list all Fortran reduction identifiers. Moves the combiner code-generation into a separate function for possible reuse in array context in future.


Patch is 23.40 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79876.diff

1 Files Affected:

  • (modified) flang/lib/Lower/OpenMP.cpp (+170-167)
diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
index 7dd25f75d9eb76f..52d222f3d601f6a 100644
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -726,21 +726,59 @@ static void checkMapType(mlir::Location location, mlir::Type type) {
 
 class ReductionProcessor {
 public:
-  enum IntrinsicProc { MAX, MIN, IAND, IOR, IEOR };
-  static IntrinsicProc
+  // TODO: Move this enumeration to the OpenMP dialect
+  enum ReductionIdentifier {
+    ID,
+    USER_DEF_OP,
+    ADD,
+    SUBTRACT,
+    MULTIPLY,
+    AND,
+    OR,
+    EQV,
+    NEQV,
+    MAX,
+    MIN,
+    IAND,
+    IOR,
+    IEOR
+  };
+  static ReductionIdentifier
   getReductionType(const Fortran::parser::ProcedureDesignator &pd) {
-    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+    auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
                        getRealName(pd).ToString())
-                       .Case("max", IntrinsicProc::MAX)
-                       .Case("min", IntrinsicProc::MIN)
-                       .Case("iand", IntrinsicProc::IAND)
-                       .Case("ior", IntrinsicProc::IOR)
-                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Case("max", ReductionIdentifier::MAX)
+                       .Case("min", ReductionIdentifier::MIN)
+                       .Case("iand", ReductionIdentifier::IAND)
+                       .Case("ior", ReductionIdentifier::IOR)
+                       .Case("ieor", ReductionIdentifier::IEOR)
                        .Default(std::nullopt);
     assert(redType && "Invalid Reduction");
     return *redType;
   }
 
+  static ReductionIdentifier getReductionType(
+      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp) {
+    switch (intrinsicOp) {
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+      return ReductionIdentifier::ADD;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Subtract:
+      return ReductionIdentifier::SUBTRACT;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+      return ReductionIdentifier::MULTIPLY;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
+      return ReductionIdentifier::AND;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+      return ReductionIdentifier::EQV;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
+      return ReductionIdentifier::OR;
+    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      return ReductionIdentifier::NEQV;
+    default:
+      llvm_unreachable("unexpected intrinsic operator in reduction");
+    }
+  }
+
   static bool supportedIntrinsicProcReduction(
       const Fortran::parser::ProcedureDesignator &pd) {
     const auto *name{Fortran::parser::Unwrap<Fortran::parser::Name>(pd)};
@@ -748,13 +786,13 @@ class ReductionProcessor {
     if (!name->symbol->GetUltimate().attrs().test(
             Fortran::semantics::Attr::INTRINSIC))
       return false;
-    auto redType = llvm::StringSwitch<std::optional<IntrinsicProc>>(
+    auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
                        getRealName(name).ToString())
-                       .Case("max", IntrinsicProc::MAX)
-                       .Case("min", IntrinsicProc::MIN)
-                       .Case("iand", IntrinsicProc::IAND)
-                       .Case("ior", IntrinsicProc::IOR)
-                       .Case("ieor", IntrinsicProc::IEOR)
+                       .Case("max", ReductionIdentifier::MAX)
+                       .Case("min", ReductionIdentifier::MIN)
+                       .Case("iand", ReductionIdentifier::IAND)
+                       .Case("ior", ReductionIdentifier::IOR)
+                       .Case("ieor", ReductionIdentifier::IEOR)
                        .Default(std::nullopt);
     if (redType)
       return true;
@@ -812,32 +850,30 @@ class ReductionProcessor {
   /// reductionOpName. For example:
   ///    0 + x = x,
   ///    1 * x = x
-  static int getOperationIdentity(
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Location loc) {
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+  static int getOperationIdentity(ReductionIdentifier redId,
+                                  mlir::Location loc) {
+    switch (redId) {
+    case ReductionIdentifier::ADD:
+    case ReductionIdentifier::OR:
+    case ReductionIdentifier::NEQV:
       return 0;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
+    case ReductionIdentifier::MULTIPLY:
+    case ReductionIdentifier::AND:
+    case ReductionIdentifier::EQV:
       return 1;
     default:
       TODO(loc, "Reduction of some intrinsic operators is not supported");
     }
   }
 
-  static mlir::Value getIntrinsicProcInitValue(
-      mlir::Location loc, mlir::Type type,
-      const Fortran::parser::ProcedureDesignator &procDesignator,
-      fir::FirOpBuilder &builder) {
+  static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
+                                           ReductionIdentifier redId,
+                                           fir::FirOpBuilder &builder) {
     assert((fir::isa_integer(type) || fir::isa_real(type) ||
             type.isa<fir::LogicalType>()) &&
            "only integer, logical and real types are currently supported");
-    switch (getReductionType(procDesignator)) {
-    case IntrinsicProc::MAX: {
+    switch (redId) {
+    case ReductionIdentifier::MAX: {
       if (auto ty = type.dyn_cast<mlir::FloatType>()) {
         const llvm::fltSemantics &sem = ty.getFloatSemantics();
         return builder.createRealConstant(
@@ -847,7 +883,7 @@ class ReductionProcessor {
       int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, minInt);
     }
-    case IntrinsicProc::MIN: {
+    case ReductionIdentifier::MIN: {
       if (auto ty = type.dyn_cast<mlir::FloatType>()) {
         const llvm::fltSemantics &sem = ty.getFloatSemantics();
         return builder.createRealConstant(
@@ -857,46 +893,50 @@ class ReductionProcessor {
       int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, maxInt);
     }
-    case IntrinsicProc::IOR: {
+    case ReductionIdentifier::IOR: {
       unsigned bits = type.getIntOrFloatBitWidth();
       int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, zeroInt);
     }
-    case IntrinsicProc::IEOR: {
+    case ReductionIdentifier::IEOR: {
       unsigned bits = type.getIntOrFloatBitWidth();
       int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, zeroInt);
     }
-    case IntrinsicProc::IAND: {
+    case ReductionIdentifier::IAND: {
       unsigned bits = type.getIntOrFloatBitWidth();
       int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
       return builder.createIntegerConstant(loc, type, allOnInt);
     }
-    }
-    llvm_unreachable("Unknown Reduction Intrinsic");
-  }
+    case ReductionIdentifier::ADD:
+    case ReductionIdentifier::MULTIPLY:
+    case ReductionIdentifier::AND:
+    case ReductionIdentifier::OR:
+    case ReductionIdentifier::EQV:
+    case ReductionIdentifier::NEQV:
+      if (type.isa<mlir::FloatType>())
+        return builder.create<mlir::arith::ConstantOp>(
+            loc, type,
+            builder.getFloatAttr(type,
+                                 (double)getOperationIdentity(redId, loc)));
+
+      if (type.isa<fir::LogicalType>()) {
+        mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
+            loc, builder.getI1Type(),
+            builder.getIntegerAttr(builder.getI1Type(),
+                                   getOperationIdentity(redId, loc)));
+        return builder.createConvert(loc, type, intConst);
+      }
 
-  static mlir::Value getIntrinsicOpInitValue(
-      mlir::Location loc, mlir::Type type,
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      fir::FirOpBuilder &builder) {
-    if (type.isa<mlir::FloatType>())
       return builder.create<mlir::arith::ConstantOp>(
           loc, type,
-          builder.getFloatAttr(type,
-                               (double)getOperationIdentity(intrinsicOp, loc)));
-
-    if (type.isa<fir::LogicalType>()) {
-      mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
-          loc, builder.getI1Type(),
-          builder.getIntegerAttr(builder.getI1Type(),
-                                 getOperationIdentity(intrinsicOp, loc)));
-      return builder.createConvert(loc, type, intConst);
+          builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
+    case ReductionIdentifier::ID:
+    case ReductionIdentifier::USER_DEF_OP:
+    case ReductionIdentifier::SUBTRACT:
+      TODO(loc, "Reduction of some identifier types is not supported");
     }
-
-    return builder.create<mlir::arith::ConstantOp>(
-        loc, type,
-        builder.getIntegerAttr(type, getOperationIdentity(intrinsicOp, loc)));
+    llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
   }
 
   template <typename FloatOp, typename IntegerOp>
@@ -910,118 +950,46 @@ class ReductionProcessor {
     return builder.create<FloatOp>(loc, op1, op2);
   }
 
-  /// Creates an OpenMP reduction declaration and inserts it into the provided
-  /// symbol table. The declaration has a constant initializer with the neutral
-  /// value `initValue`, and the reduction combiner carried over from `reduce`.
-  /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      const Fortran::parser::ProcedureDesignator &procDesignator,
-      mlir::Type type, mlir::Location loc) {
-    mlir::OpBuilder::InsertionGuard guard(builder);
-    mlir::ModuleOp module = builder.getModule();
-
-    auto decl =
-        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-    if (decl)
-      return decl;
-
-    mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
-        loc, reductionOpName, type);
-    builder.createBlock(&decl.getInitializerRegion(),
-                        decl.getInitializerRegion().end(), {type}, {loc});
-    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-    mlir::Value init =
-        getIntrinsicProcInitValue(loc, type, procDesignator, builder);
-    builder.create<mlir::omp::YieldOp>(loc, init);
-
-    builder.createBlock(&decl.getReductionRegion(),
-                        decl.getReductionRegion().end(), {type, type},
-                        {loc, loc});
-
-    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
+  static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
+                                          mlir::Location loc,
+                                          ReductionIdentifier redId,
+                                          mlir::Type type, mlir::Value op1,
+                                          mlir::Value op2) {
     mlir::Value reductionOp;
-    switch (getReductionType(procDesignator)) {
-    case IntrinsicProc::MAX:
+    switch (redId) {
+    case ReductionIdentifier::MAX:
       reductionOp =
           getReductionOperation<mlir::arith::MaximumFOp, mlir::arith::MaxSIOp>(
               builder, type, loc, op1, op2);
       break;
-    case IntrinsicProc::MIN:
+    case ReductionIdentifier::MIN:
       reductionOp =
           getReductionOperation<mlir::arith::MinimumFOp, mlir::arith::MinSIOp>(
               builder, type, loc, op1, op2);
       break;
-    case IntrinsicProc::IOR:
+    case ReductionIdentifier::IOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
       break;
-    case IntrinsicProc::IEOR:
+    case ReductionIdentifier::IEOR:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
       break;
-    case IntrinsicProc::IAND:
+    case ReductionIdentifier::IAND:
       assert((type.isIntOrIndex()) && "only integer is expected");
       reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
       break;
-    }
-
-    builder.create<mlir::omp::YieldOp>(loc, reductionOp);
-    return decl;
-  }
-
-  /// Creates an OpenMP reduction declaration and inserts it into the provided
-  /// symbol table. The declaration has a constant initializer with the neutral
-  /// value `initValue`, and the reduction combiner carried over from `reduce`.
-  /// TODO: Generalize this for non-integer types, add atomic region.
-  static mlir::omp::ReductionDeclareOp createReductionDecl(
-      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
-      Fortran::parser::DefinedOperator::IntrinsicOperator intrinsicOp,
-      mlir::Type type, mlir::Location loc) {
-    mlir::OpBuilder::InsertionGuard guard(builder);
-    mlir::ModuleOp module = builder.getModule();
-
-    auto decl =
-        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
-    if (decl)
-      return decl;
-
-    mlir::OpBuilder modBuilder(module.getBodyRegion());
-
-    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
-        loc, reductionOpName, type);
-    builder.createBlock(&decl.getInitializerRegion(),
-                        decl.getInitializerRegion().end(), {type}, {loc});
-    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
-    mlir::Value init = getIntrinsicOpInitValue(loc, type, intrinsicOp, builder);
-    builder.create<mlir::omp::YieldOp>(loc, init);
-
-    builder.createBlock(&decl.getReductionRegion(),
-                        decl.getReductionRegion().end(), {type, type},
-                        {loc, loc});
-
-    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
-    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
-    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
-
-    mlir::Value reductionOp;
-    switch (intrinsicOp) {
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
+    case ReductionIdentifier::ADD:
       reductionOp =
           getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp>(
               builder, type, loc, op1, op2);
       break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
+    case ReductionIdentifier::MULTIPLY:
       reductionOp =
           getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp>(
               builder, type, loc, op1, op2);
       break;
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::AND: {
+    case ReductionIdentifier::AND: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1031,7 +999,7 @@ class ReductionProcessor {
       reductionOp = builder.createConvert(loc, type, andiOp);
       break;
     }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::OR: {
+    case ReductionIdentifier::OR: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1040,7 +1008,7 @@ class ReductionProcessor {
       reductionOp = builder.createConvert(loc, type, oriOp);
       break;
     }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV: {
+    case ReductionIdentifier::EQV: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1050,7 +1018,7 @@ class ReductionProcessor {
       reductionOp = builder.createConvert(loc, type, cmpiOp);
       break;
     }
-    case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV: {
+    case ReductionIdentifier::NEQV: {
       mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
       mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
 
@@ -1064,7 +1032,46 @@ class ReductionProcessor {
       TODO(loc, "Reduction of some intrinsic operators is not supported");
     }
 
+    return reductionOp;
+  }
+
+  /// Creates an OpenMP reduction declaration and inserts it into the provided
+  /// symbol table. The declaration has a constant initializer with the neutral
+  /// value `initValue`, and the reduction combiner carried over from `reduce`.
+  /// TODO: Generalize this for non-integer types, add atomic region.
+  static mlir::omp::ReductionDeclareOp createReductionDecl(
+      fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
+      const ReductionIdentifier redId, mlir::Type type, mlir::Location loc) {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    mlir::ModuleOp module = builder.getModule();
+
+    auto decl =
+        module.lookupSymbol<mlir::omp::ReductionDeclareOp>(reductionOpName);
+    if (decl)
+      return decl;
+
+    mlir::OpBuilder modBuilder(module.getBodyRegion());
+
+    decl = modBuilder.create<mlir::omp::ReductionDeclareOp>(
+        loc, reductionOpName, type);
+    builder.createBlock(&decl.getInitializerRegion(),
+                        decl.getInitializerRegion().end(), {type}, {loc});
+    builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
+    mlir::Value init = getReductionInitValue(loc, type, redId, builder);
+    builder.create<mlir::omp::YieldOp>(loc, init);
+
+    builder.createBlock(&decl.getReductionRegion(),
+                        decl.getReductionRegion().end(), {type, type},
+                        {loc, loc});
+
+    builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
+    mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
+    mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
+
+    mlir::Value reductionOp =
+        createScalarCombiner(builder, loc, redId, type, op1, op2);
     builder.create<mlir::omp::YieldOp>(loc, reductionOp);
+
     return decl;
   }
 
@@ -1087,15 +1094,15 @@ class ReductionProcessor {
       const auto &intrinsicOp{
           std::get<Fortran::parser::DefinedOperator::IntrinsicOperator>(
               redDefinedOp->u)};
-      switch (intrinsicOp) {
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::Add:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::Multiply:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::AND:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::EQV:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::OR:
-      case Fortran::parser::DefinedOperator::IntrinsicOperator::NEQV:
+      ReductionIdentifier redId = getReductionType(intrinsicOp);
+      switch (redId) {
+      case ReductionIdentifier::ADD:
+      case ReductionIdentifier::MULTIPLY:
+      case ReductionIdentifier::AND:
+      case ReductionIdentifier::EQV:
+      case ReductionIdentifier::OR:
+      case ReductionIdentifier::NEQV:
         break;
-
       default:
         TODO(currentLocation,
              "Reduction of some intrinsic operators is not supported");
@@ -1115,11 +1122,11 @@ class ReductionProcessor {
               decl = createReductionDecl(
                   firOpBuilder,
                   getReductionName(intrinsicOp, firOpBuilder.getI1Type()),
-                  intrinsicOp, redType, currentLocation);
+                  redId, redType, currentLocation);
             else if (redType.isIntOrIndexOrFloat()) {
               decl = createReductionDecl(firOpBuilder,
                                ...
[truncated]

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Sorry this fell off my radar for so long. No need to re-approve after fixing the minor comment.

Introduces a new enumeration to list all Fortran reduction identifiers.
Moves the combiner code-generation into a separate function for possible
reuse in array context in future.
@kiranchandramohan kiranchandramohan force-pushed the reduction-refactor-scalar branch from 8ad53ab to ed67a15 Compare February 9, 2024 16:43
@kiranchandramohan kiranchandramohan merged commit 301f684 into llvm:main Feb 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants