
[flang] Simplify hlfir.sum total reductions. #119482


Merged
3 commits merged on Dec 13, 2024

Conversation

vzakhari
Contributor

I am trying to switch to keeping the reduction value in a temporary
scalar location so that I can use hlfir::genLoopNest easily.
This also allows using omp.loop_nest with worksharing for OpenMP.

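A minimal C++ sketch of the in-memory reduction shape described above, condensing what the patch below does (the standalone helper function and its parameter list are illustrative; genScalarAdd is the pass's existing helper, and the other calls are the builder/HLFIR APIs used in the diff):

  // Sketch only: the running value lives in a scalar temporary instead of
  // fir.do_loop iter_args, so hlfir::genLoopNest can be reused as-is.
  static mlir::Value genInMemorySum(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
                                    hlfir::Entity array, mlir::Type elementType,
                                    mlir::Value initValue,
                                    mlir::ValueRange extents, bool isUnordered,
                                    bool emitWorkshareLoop) {
    mlir::Value reductionTemp =
        builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
    builder.create<fir::StoreOp>(loc, initValue, reductionTemp);

    hlfir::LoopNest loopNest = hlfir::genLoopNest(
        loc, builder, extents, isUnordered, emitWorkshareLoop);
    builder.setInsertionPointToStart(loopNest.body);

    // Load, accumulate, and store back on every iteration.
    hlfir::Entity element =
        hlfir::getElementAt(loc, builder, array, loopNest.oneBasedIndices);
    hlfir::Entity elementValue =
        hlfir::loadTrivialScalar(loc, builder, element);
    mlir::Value acc = builder.create<fir::LoadOp>(loc, reductionTemp);
    mlir::Value updated = genScalarAdd(loc, builder, acc, elementValue);
    builder.create<fir::StoreOp>(loc, updated, reductionTemp);

    // The reduced value is read back after the whole loop nest.
    builder.setInsertionPointAfter(loopNest.outerOp);
    return builder.create<fir::LoadOp>(loc, reductionTemp);
  }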
@vzakhari vzakhari requested review from tblah and jeanPerier December 11, 2024 01:52
@llvmbot llvmbot added the flang (Flang issues not falling into any other category) and flang:fir-hlfir labels on Dec 11, 2024
@llvmbot
Member

llvmbot commented Dec 11, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Changes

I am trying to switch to keeping the reduction value in a temporary
scalar location so that I can use hlfir::genLoopNest easily.
This also allows using omp.loop_nest with worksharing for OpenMP.


Patch is 46.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119482.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+98-83)
  • (modified) flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir (+163-126)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index b61f9767ccc2b8..2bb1a786f6c12c 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -17,6 +17,7 @@
 #include "flang/Optimizer/HLFIR/HLFIRDialect.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -105,34 +106,47 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
                   mlir::PatternRewriter &rewriter) const override {
     mlir::Location loc = sum.getLoc();
     fir::FirOpBuilder builder{rewriter, sum.getOperation()};
-    hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
-    assert(expr && "expected an expression type for the result of hlfir.sum");
-    mlir::Type elementType = expr.getElementType();
+    mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
     hlfir::Entity array = hlfir::Entity{sum.getArray()};
     mlir::Value mask = sum.getMask();
     mlir::Value dim = sum.getDim();
-    int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+    bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
+    int64_t dimVal =
+        isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
     mlir::Value resultShape, dimExtent;
-    std::tie(resultShape, dimExtent) =
-        genResultShape(loc, builder, array, dimVal);
+    llvm::SmallVector<mlir::Value> arrayExtents;
+    if (isTotalReduction)
+      arrayExtents = genArrayExtents(loc, builder, array);
+    else
+      std::tie(resultShape, dimExtent) =
+          genResultShapeForPartialReduction(loc, builder, array, dimVal);
+
+    // If the mask is present and is a scalar, then we'd better load its value
+    // outside of the reduction loop making the loop unswitching easier.
+    mlir::Value isPresentPred, maskValue;
+    if (mask) {
+      if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
+        // MASK represented by a box might be dynamically optional,
+        // so we have to check for its presence before accessing it.
+        isPresentPred =
+            builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
+      }
+
+      if (hlfir::Entity{mask}.isScalar())
+        maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
+    }
 
     auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                          mlir::ValueRange inputIndices) -> hlfir::Entity {
       // Loop over all indices in the DIM dimension, and reduce all values.
-      // We do not need to create the reduction loop always: if we can
-      // slice the input array given the inputIndices, then we can
-      // just apply a new SUM operation (total reduction) to the slice.
-      // For the time being, generate the explicit loop because the slicing
-      // requires generating an elemental operation for the input array
-      // (and the mask, if present).
-      // TODO: produce the slices and new SUM after adding a pattern
-      // for expanding total reduction SUM case.
-      mlir::Type indexType = builder.getIndexType();
-      auto one = builder.createIntegerConstant(loc, indexType, 1);
-      auto ub = builder.createConvert(loc, indexType, dimExtent);
+      // If DIM is not present, do total reduction.
 
+      // Create temporary scalar for keeping the running reduction value.
+      mlir::Value reductionTemp =
+          builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
       // Initial value for the reduction.
       mlir::Value initValue = genInitValue(loc, builder, elementType);
+      builder.create<fir::StoreOp>(loc, initValue, reductionTemp);
 
       // The reduction loop may be unordered if FastMathFlags::reassoc
       // transformations are allowed. The integer reduction is always
@@ -141,42 +155,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
                          static_cast<bool>(sum.getFastmath() &
                                            mlir::arith::FastMathFlags::reassoc);
 
-      // If the mask is present and is a scalar, then we'd better load its value
-      // outside of the reduction loop making the loop unswitching easier.
-      // Maybe it is worth hoisting it from the elemental operation as well.
-      mlir::Value isPresentPred, maskValue;
-      if (mask) {
-        if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
-          // MASK represented by a box might be dynamically optional,
-          // so we have to check for its presence before accessing it.
-          isPresentPred =
-              builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
-        }
-
-        if (hlfir::Entity{mask}.isScalar())
-          maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
-      }
+      llvm::SmallVector<mlir::Value> extents;
+      if (isTotalReduction)
+        extents = arrayExtents;
+      else
+        extents.push_back(
+            builder.createConvert(loc, builder.getIndexType(), dimExtent));
 
       // NOTE: the outer elemental operation may be lowered into
       // omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
       // loop may appear disjoint from the workshare loop nest.
-      // Moreover, the inner loop is not strictly nested (due to the reduction
-      // starting value initialization), and the above omp dialect operations
-      // cannot produce results.
-      // It is unclear what we should do about it yet.
-      auto doLoop = builder.create<fir::DoLoopOp>(
-          loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
-          mlir::ValueRange{initValue});
-
-      // Address the input array using the reduction loop's IV
-      // for the DIM dimension.
-      mlir::Value iv = doLoop.getInductionVar();
-      llvm::SmallVector<mlir::Value> indices{inputIndices};
-      indices.insert(indices.begin() + dimVal - 1, iv);
-
-      mlir::OpBuilder::InsertionGuard guard(builder);
-      builder.setInsertionPointToStart(doLoop.getBody());
-      mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
+      bool emitWorkshareLoop =
+          isTotalReduction ? flangomp::shouldUseWorkshareLowering(sum) : false;
+
+      hlfir::LoopNest loopNest = hlfir::genLoopNest(
+          loc, builder, extents, isUnordered, emitWorkshareLoop);
+
+      llvm::SmallVector<mlir::Value> indices;
+      if (isTotalReduction) {
+        indices = loopNest.oneBasedIndices;
+      } else {
+        indices = inputIndices;
+        indices.insert(indices.begin() + dimVal - 1,
+                       loopNest.oneBasedIndices[0]);
+      }
+
+      builder.setInsertionPointToStart(loopNest.body);
       fir::IfOp ifOp;
       if (mask) {
         // Make the reduction value update conditional on the value
@@ -188,16 +192,15 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
         }
         mlir::Value isUnmasked =
             builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
-        ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
-                                         /*withElseRegion=*/true);
-        // In the 'else' block return the current reduction value.
-        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-        builder.create<fir::ResultOp>(loc, reductionValue);
+        ifOp = builder.create<fir::IfOp>(loc, isUnmasked,
+                                         /*withElseRegion=*/false);
 
         // In the 'then' block do the actual addition.
         builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
       }
 
+      mlir::Value reductionValue =
+          builder.create<fir::LoadOp>(loc, reductionTemp);
       hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
       hlfir::Entity elementValue =
           hlfir::loadTrivialScalar(loc, builder, element);
@@ -205,15 +208,18 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
       // (e.g. when fast-math is not allowed), but let's start with
       // the simple version.
       reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
-      builder.create<fir::ResultOp>(loc, reductionValue);
-
-      if (ifOp) {
-        builder.setInsertionPointAfter(ifOp);
-        builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
-      }
+      builder.create<fir::StoreOp>(loc, reductionValue, reductionTemp);
 
-      return hlfir::Entity{doLoop.getResult(0)};
+      builder.setInsertionPointAfter(loopNest.outerOp);
+      return hlfir::Entity{builder.create<fir::LoadOp>(loc, reductionTemp)};
     };
+
+    if (isTotalReduction) {
+      hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
+      rewriter.replaceOp(sum, result);
+      return mlir::success();
+    }
+
     hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
         loc, builder, elementType, resultShape, {}, genKernel,
         /*isUnordered=*/true, /*polymorphicMold=*/nullptr,
@@ -229,20 +235,29 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
   }
 
 private:
+  static llvm::SmallVector<mlir::Value>
+  genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
+                  hlfir::Entity array) {
+    mlir::Value inShape = hlfir::genShape(loc, builder, array);
+    llvm::SmallVector<mlir::Value> inExtents =
+        hlfir::getExplicitExtentsFromShape(inShape, builder);
+    if (inShape.getUses().empty())
+      inShape.getDefiningOp()->erase();
+    return inExtents;
+  }
+
   // Return fir.shape specifying the shape of the result
   // of a SUM reduction with DIM=dimVal. The second return value
   // is the extent of the DIM dimension.
   static std::tuple<mlir::Value, mlir::Value>
-  genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
-                 hlfir::Entity array, int64_t dimVal) {
-    mlir::Value inShape = hlfir::genShape(loc, builder, array);
+  genResultShapeForPartialReduction(mlir::Location loc,
+                                    fir::FirOpBuilder &builder,
+                                    hlfir::Entity array, int64_t dimVal) {
     llvm::SmallVector<mlir::Value> inExtents =
-        hlfir::getExplicitExtentsFromShape(inShape, builder);
+        genArrayExtents(loc, builder, array);
     assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
            "DIM must be present and a positive constant not exceeding "
            "the array's rank");
-    if (inShape.getUses().empty())
-      inShape.getDefiningOp()->erase();
 
     mlir::Value dimExtent = inExtents[dimVal - 1];
     inExtents.erase(inExtents.begin() + dimVal - 1);
@@ -355,22 +370,22 @@ class SimplifyHLFIRIntrinsics
     target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
       if (!simplifySum)
         return true;
-      if (mlir::Value dim = sum.getDim()) {
-        if (auto dimVal = fir::getIntIfConstant(dim)) {
-          if (!fir::isa_trivial(sum.getType())) {
-            // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
-            // It is only legal when X is 1, and it should probably be
-            // canonicalized into SUM(a).
-            fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
-                hlfir::getFortranElementOrSequenceType(
-                    sum.getArray().getType()));
-            if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
-              // Ignore SUMs with illegal DIM values.
-              // They may appear in dead code,
-              // and they do not have to be converted.
-              return false;
-            }
-          }
+
+      // Always inline total reductions.
+      if (hlfir::Entity{sum}.getRank() == 0)
+        return false;
+      mlir::Value dim = sum.getDim();
+      if (!dim)
+        return false;
+
+      if (auto dimVal = fir::getIntIfConstant(dim)) {
+        fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
+            hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
+        if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
+          // Ignore SUMs with illegal DIM values.
+          // They may appear in dead code,
+          // and they do not have to be converted.
+          return false;
         }
       }
       return true;
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
index 54a592a66670f1..572b9f0da1e4ab 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -14,9 +14,12 @@ func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
 // CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xi32> {
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index):
-// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_7:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK:             fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_3]] step %[[VAL_9]] unordered {
+// CHECK:               %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
 // CHECK:               %[[VAL_12:.*]] = arith.constant 0 : index
 // CHECK:               %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_12]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
 // CHECK:               %[[VAL_14:.*]] = arith.constant 1 : index
@@ -29,9 +32,10 @@ func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
 // CHECK:               %[[VAL_21:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_18]], %[[VAL_20]])  : (!fir.box<!fir.array<2x3xi32>>, index, index) -> !fir.ref<i32>
 // CHECK:               %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
 // CHECK:               %[[VAL_23:.*]] = arith.addi %[[VAL_11]], %[[VAL_22]] : i32
-// CHECK:               fir.result %[[VAL_23]] : i32
+// CHECK:               fir.store %[[VAL_23]] to %[[VAL_7]] : !fir.ref<i32>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
+// CHECK:             %[[VAL_24:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             hlfir.yield_element %[[VAL_24]] : i32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -50,14 +54,18 @@ func.func @sum_expr_known_extents(%arg0: !hlfir.expr<2x3xi32>) {
 // CHECK:           %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xi32> {
 // CHECK:           ^bb0(%[[VAL_6:.*]]: index):
-// CHECK:             %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_7:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK:             %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK:             fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_2]] step %[[VAL_9]] unordered {
+// CHECK:               %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
 // CHECK:               %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xi32>, index, index) -> i32
 // CHECK:               %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
-// CHECK:               fir.result %[[VAL_13]] : i32
+// CHECK:               fir.store %[[VAL_13]] to %[[VAL_7]] : !fir.ref<i32>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_9]] : i32
+// CHECK:             %[[VAL_14:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
+// CHECK:             hlfir.yield_element %[[VAL_14]] : i32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -77,12 +85,15 @@ func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xcomplex<f64>> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = fir.alloca complex<f64> {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
 // CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
 // CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_3]]#1 step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+// CHECK:             fir.store %[[VAL_12]] to %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:             %[[VAL_13:.*]] = arith.constant 1 : index
+// CHECK:             fir.do_loop %[[VAL_14:.*]] = %[[VAL_13]] to %[[VAL_3]]#1 step %[[VAL_13]] {
+// CHECK:               %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_16:.*]] = arith.constant 0 : index
 // CHECK:               %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
 // CHECK:               %[[VAL_18:.*]] = arith.constant 1 : index
@@ -95,9 +106,10 @@ func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:               %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]])  : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
 // CHECK:               %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
-// CHECK:               fir.result %[[VAL_27]] : complex<f64>
+// CHECK:               fir.store %[[VAL_27]] to %[[VAL_8]] : !fir.ref<complex<f64>>
 // CHECK:             }
-// CHECK:             hlfir.yield_element %[[VAL_13]] : complex<f64>
+// CHECK:             %[[VAL_28:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK:             hlfir.yield_element %[[VAL_28]] : complex<f64>
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
@@ -116,12 +128,15 @@ func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
 // CHECK:           %[[VAL_5:.*]] = fir.shape %[[VAL_3]]#1 : (index) -> !fir.shape<1>
 // CHECK:           %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xcomplex<f64>> {
 // CHECK:           ^bb0(%[[VAL_7:.*]]: index):
-// CHECK:             %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK:             %[[VAL_8:.*]] = fir.alloca complex<f64> {bindc_name = ".sum.reduction"}
 // CHECK:             %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
 // CHECK:             %[[VAL_10:.*]] = fir.undefined complex<f64>
 // CHECK:             %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
 // CHECK:             %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK:             %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_4]] step %...
[truncated]

@vzakhari vzakhari requested a review from tblah December 12, 2024 00:48

// Create temporary scalar for keeping the running reduction value.
mlir::Value reductionTemp =
builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
Contributor Author


@jeanPerier, what do you think about calling this outside of genKernel? It looks like it results in stacksave/stackrestore in the stack reclaim pass (after the elemental is transformed into loops), which is not ideal. I think it should be safe to hoist this call provided that the initializing store is kept inside the elemental.

Contributor


Makes some sense to me; the only impact I see is that it may make it harder to parallelize SUM(, DIM), which is otherwise trivial (each thread reduces into an element of the result array).

Can you try what happens with a SUM(, DIM) inside a workshare construct? Since you enable the rewrite to an elemental, I think the elemental-to-omp-loop conversion should kick in, and hoisting the alloca may be bad there.

Maybe the stack reclaim pass should hoist constant-size allocas outside of loops (with the assumption that parallelization of the loops has already happened at that point), at least for scalars. This may of course have an impact on the stack size, but for scalars that should be limited.

Since SUM(, DIM) was not parallelized before anyway, your solution would still be acceptable to me though.

Contributor


fir::FirOpBuilder::getAllocaBlock understands OpenMP operations and should give safe insertion points to move allocas in the stack reclaim pass. OpenMP parallelisation will all have happened by then.

If OpenMP workshare support is blocking optimizations in earlier passes please let me know and I will see if I can rethink the design.
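A tiny, hedged sketch of the kind of hoisting being described (illustrative only; the stack reclaim pass has its own placement logic, this just shows getAllocaBlock providing the insertion point):

  // Place a scalar temporary in the enclosing alloca block (function entry,
  // or the innermost OpenMP region) instead of inside a loop body.
  static mlir::Value hoistScalarTemp(mlir::Location loc,
                                     fir::FirOpBuilder &builder,
                                     mlir::Type elementType) {
    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(builder.getAllocaBlock());
    return builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
  }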

Contributor

@jeanPerier jeanPerier left a comment


LGTM

Contributor

@tblah tblah left a comment


Thanks for the fix!

@vzakhari
Contributor Author

I tried the following test:

  subroutine test(x, y)
    real :: x(:,:), y(:)
    !$omp parallel workshare
    y = sum(x,dim=1)
    !$omp end parallel workshare
  end subroutine test

When I place the alloca inside the elemental, I get the following:

// -----// IR Dump After OptimizedBufferization (opt-bufferization) //----- //
  omp.parallel {
    omp.workshare {
      omp.workshare.loop_wrapper {
        omp.loop_nest (%arg2) : index = (%c1) to (%4#1) inclusive step (%c1) {
          %5 = fir.alloca f32 {bindc_name = ".sum.reduction", pinned}
          fir.do_loop

// -----// IR Dump After LowerWorkshare (lower-workshare) //----- //
    omp.parallel {
      omp.wsloop nowait {
        omp.loop_nest (%arg2) : index = (%c1) to (%6#1) inclusive step (%c1) {
          %7 = fir.alloca f32 {bindc_name = ".sum.reduction", pinned}
          fir.do_loop

// -----// IR Dump Before LLVMAddComdats (llvm-add-comdats) //----- //
    omp.parallel {
      %52 = llvm.alloca %51 x f32 {bindc_name = ".sum.reduction", pinned} : (i64) -> !llvm.ptr
      omp.wsloop nowait {
        omp.loop_nest (%arg2) : i64 = (%5) to (%62) inclusive step (%5) {
          llvm.store %4, %52 {tbaa = [#tbaa_tag3]} : f32, !llvm.ptr

; *** IR Dump After Annotation2MetadataPass on [module] ***
define internal void @_QFPtest..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #1 {
omp.par.region1:                                  ; preds = %omp.par.region
  %2 = alloca float, i64 1, align 4
...
define internal void @_QFPtest(ptr %0, ptr %1) #0 {
  call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @_QFPtest..omp_par, ptr %structArg)

So it seems that we end up allocating the temporary in each thread, and then the stack is automatically reclaimed after exiting _QFPtest..omp_par. It also looks fine to me at the MLIR level.

When I place the alloca outside the elemental:

// -----// IR Dump After OptimizedBufferization (opt-bufferization) //----- //
  omp.parallel {
    omp.workshare {
      %3 = fir.alloca f32 {bindc_name = ".sum.reduction", pinned}
      omp.workshare.loop_wrapper {
        omp.loop_nest (%arg2) : index = (%c1) to (%5#1) inclusive step (%c1) {
          fir.store %cst to %3 : !fir.ref<f32>
          fir.do_loop

// -----// IR Dump After LowerWorkshare (lower-workshare) //----- //
    omp.parallel {
      %5 = fir.alloca f32 {bindc_name = ".sum.reduction", pinned}
      omp.single copyprivate(%5 -> @_workshare_copy_f32 : !fir.ref<f32>) {
        omp.terminator
      }
      omp.wsloop nowait {
        omp.loop_nest (%arg2) : index = (%c1) to (%7#1) inclusive step (%c1) {
          fir.store %cst to %5 : !fir.ref<f32>
          fir.do_loop

// -----// IR Dump Before LLVMAddComdats (llvm-add-comdats) //----- //
    omp.parallel {
      %52 = llvm.alloca %51 x f32 {bindc_name = ".sum.reduction", pinned} : (i64) -> !llvm.ptr
      omp.single copyprivate(%52 -> @_workshare_copy_f32 : !llvm.ptr) {
        omp.terminator
      }
      omp.wsloop nowait {
        omp.loop_nest (%arg2) : i64 = (%5) to (%63) inclusive step (%5) {
          llvm.store %4, %52 {tbaa = [#tbaa_tag4]} : f32, !llvm.ptr

; *** IR Dump After Annotation2MetadataPass on [module] ***
define internal void @_QFPtest..omp_par(ptr noalias %tid.addr, ptr noalias %zero.addr, ptr %0) #1 {
omp.par.region1:                                  ; preds = %omp.par.region
  %2 = alloca float, i64 1, align 4
  %3 = alloca i32, align 4
  store i32 0, ptr %3, align 4
  %omp_global_thread_num2 = call i32 @__kmpc_global_thread_num(ptr @1)
  %4 = call i32 @__kmpc_single(ptr @1, i32 %omp_global_thread_num2)
  %5 = icmp ne i32 %4, 0
  br i1 %5, label %omp_region.body, label %omp_region.end

omp_region.end:                                   ; preds = %omp.par.region1, %omp.region.cont3
  %omp_global_thread_num4 = call i32 @__kmpc_global_thread_num(ptr @1)
  %6 = load i32, ptr %3, align 4
  call void @__kmpc_copyprivate(ptr @1, i32 %omp_global_thread_num4, i64 0, ptr %2, ptr @_workshare_copy_f32, i32 %6)
...
define internal void @_QFPtest(ptr %0, ptr %1) #0 {
  call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr @1, i32 1, ptr @_QFPtest..omp_par, ptr %structArg)

So besides the single copyprivate we end up with the same behavior, I think.

It seems to me hoisting the alloca should be okay, but it might not always be enough to get rid of stacksave/stackrestore (e.g. if we generate the elemental inside another loop, we will end up with the stack bookkeeping anyway).

I am now inclined to go back to the SSA reduction, though modifying hlfir::genLoopNest for this will look awkward (because the related OpenMP constructs do not support SSA reductions). So I will probably add something like hlfir::genLoopNestWithSSAReduction that only generates fir.do_loops, and then we can try to merge the two while also resolving this for OpenMP.

What do you think?
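For reference, the "SSA reduction" being discussed is essentially the fir.do_loop iter_args shape that the pre-patch code emitted; a minimal sketch for a 1-D total reduction (the helper wrapper is illustrative, genScalarAdd is the pass's existing helper):

  // The running value is threaded through the loop as an iter_arg; no alloca,
  // load, or store is involved.
  static mlir::Value genSsaSum(mlir::Location loc, fir::FirOpBuilder &builder,
                               hlfir::Entity array, mlir::Value extent,
                               mlir::Value initValue, bool isUnordered) {
    mlir::Type indexType = builder.getIndexType();
    mlir::Value one = builder.createIntegerConstant(loc, indexType, 1);
    mlir::Value ub = builder.createConvert(loc, indexType, extent);
    auto doLoop = builder.create<fir::DoLoopOp>(
        loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
        mlir::ValueRange{initValue});

    mlir::OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(doLoop.getBody());
    mlir::Value acc = doLoop.getRegionIterArgs()[0];
    hlfir::Entity element = hlfir::getElementAt(
        loc, builder, array, mlir::ValueRange{doLoop.getInductionVar()});
    hlfir::Entity elementValue =
        hlfir::loadTrivialScalar(loc, builder, element);
    builder.create<fir::ResultOp>(
        loc, genScalarAdd(loc, builder, acc, elementValue));

    // The reduced value is the loop's result.
    return doLoop.getResult(0);
  }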

@tblah
Contributor

tblah commented Dec 13, 2024

What do you think?

Thanks for taking the time to take an in-depth look at OpenMP.

I am worried the added single copyprivate could create a performance problem, although the real OpenMP applications I have looked at so far do not use any transformational intrinsics inside of workshare.

I agree that hlfir::genLoopNestWithSSAReduction will be a good idea, and we can (eventually) use a custom implementation for OpenMP using omp.wsloop reduction(...).

Pinging @ivanradanov and @Thirumalai-Shaktivel from AMD

@vzakhari
Contributor Author

I added genLoopNestWithReductions in the latest commit. It actually allows using either SSA or in-memory reductions, and the implementation of genLoopNestWithReductions may choose one or the other depending on the context. This requires that the innermost loop body generator is passed the initial values of the iter-args (the reductions in this case) and returns the updated ("incremented") values, so that genLoopNestWithReductions can generate the proper updates of the reduction SSA/in-memory values depending on which loop operations are actually generated. As a result, the body generator does not need to be aware of what terminators the particular loop operations use or how the reduction values are updated.

I think genLoopNestWithReductions is more generic than genLoopNest, but I did not want to replace all uses of genLoopNest right now; moreover, genLoopNestWithReductions does not yet support generating proper worksharing operations.
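A rough illustration of how a caller might drive such an interface; the callback and call shapes below are an assumption based on the description above (not necessarily the exact signature added in the commit), and they assume the array entity, extents, and builder from the surrounding pass:

  // Assumed shape, for illustration only: the body generator receives the
  // current reduction values and returns the updated ones, so it does not
  // need to know whether the loop uses iter_args or an in-memory temporary.
  auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                     mlir::ValueRange oneBasedIndices,
                     mlir::ValueRange reductionArgs)
      -> llvm::SmallVector<mlir::Value> {
    hlfir::Entity element =
        hlfir::getElementAt(loc, builder, array, oneBasedIndices);
    hlfir::Entity elementValue =
        hlfir::loadTrivialScalar(loc, builder, element);
    return {genScalarAdd(loc, builder, reductionArgs[0], elementValue)};
  };

  // Hypothetical call site: one reduction value, initialized to the SUM identity.
  llvm::SmallVector<mlir::Value> results = hlfir::genLoopNestWithReductions(
      loc, builder, extents, /*reductionInits=*/mlir::ValueRange{initValue},
      genBody, /*isUnordered=*/isUnordered);
  mlir::Value totalSum = results[0];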

@vzakhari vzakhari requested review from jeanPerier and tblah December 13, 2024 16:58
Contributor

@tblah tblah left a comment


This looks great to me. I don't want to hold this up on OpenMP concerns.

I can roughly imagine how this could be implemented for OpenMP (not for this patch). I think ultimately we will need the genBody callback to operate on scalars so that we can generate the right reduction clause implementation (or, alternatively, have an argument saying what the OpenMP intrinsic reduction kind is and then use the lowering code to generate the omp.declare_reduction).

I'm imagining something like

omp.declare_reduction @something : type init {
  ^bb0(%unused: type):
    %init_val = // The arith.constant passed as the init value for the reduction variable
    omp.yield(%init_val)
} combiner {
^bb0(%lhs: type, %rhs: type):
  // generated by the genBody callback. e.g. for type == i32
  %res = arith.addi %lhs, %rhs : i32
  omp.yield(%res)
}

func.func [...] {
  %fortran_variable:2 = hlfir.declare [...]
  omp.wsloop reduction(@something %fortran_variable#0 -> %arg0 : !fir.ref<type>) {
    omp.loop_nest /*indices are %arg1...%argn*/ {
      %privatized_variable = hlfir.declare %arg0 [...]
      // generated by genLoopNestWithReductions:
      %rhs = hlfir.designate [...]
      // generated by genBody:
      %res = operation %privatized_variable %rhs : type
      omp.yield
    }
  }
  // The result has been stored to %fortran_variable
}

This would give some loss of generality on the indexing. Maybe there should be two callbacks: one for indexing and one for the "combiner" op. IMO this can be done in a later patch.
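One hedged way to picture the two-callback split mentioned above (hypothetical callback types, just to make the idea concrete; nothing here is part of the patch):

  // Hypothetical split: one callback does the indexing and produces the
  // element value, the other combines it with the running reduction value.
  using ElementGenerator = std::function<mlir::Value(
      mlir::Location, fir::FirOpBuilder &, mlir::ValueRange /*oneBasedIndices*/)>;
  using CombinerGenerator = std::function<mlir::Value(
      mlir::Location, fir::FirOpBuilder &, mlir::Value /*acc*/, mlir::Value /*elem*/)>;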

@vzakhari
Contributor Author

Thank you for the example, Tom! Yes, it should be possible to split the body generator into parts so that it works for generating the combiner code and the rhs value code.

@vzakhari vzakhari merged commit a00946f into llvm:main Dec 13, 2024
8 checks passed