[flang] Simplify hlfir.sum total reductions. #119482
Conversation
I am trying to switch to keeping the reduction value in a temporary scalar location so that I can use hlfir::genLoopNest easily. This also allows using omp.loop_nest with worksharing for OpenMP.
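As a minimal sketch of the change (distilled from the test updates in the diff below; %elt and the loop bounds are illustrative placeholders, not values from the patch), the reduction loop stops threading the running value through iter_args and instead keeps it in a scalar temporary:

// Before: running value carried as an SSA iter-arg of the loop.
%init = arith.constant 0 : i32
%res0 = fir.do_loop %iv = %c1 to %extent step %c1 unordered
    iter_args(%acc = %init) -> (i32) {
  // %elt: current array element (access elided)
  %new = arith.addi %acc, %elt : i32
  fir.result %new : i32
}

// After: running value kept in a temporary scalar location.
%tmp = fir.alloca i32 {bindc_name = ".sum.reduction"}
fir.store %init to %tmp : !fir.ref<i32>
fir.do_loop %iv = %c1 to %extent step %c1 unordered {
  %acc = fir.load %tmp : !fir.ref<i32>
  %new = arith.addi %acc, %elt : i32
  fir.store %new to %tmp : !fir.ref<i32>
}
%res1 = fir.load %tmp : !fir.ref<i32>

Because the loop no longer yields a value, it fits hlfir::genLoopNest and the OpenMP loop wrappers, which cannot produce results.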
@llvm/pr-subscribers-flang-fir-hlfir
Author: Slava Zakharin (vzakhari)
Changes: I am trying to switch to keeping the reduction value in a temporary scalar location so that I can use hlfir::genLoopNest easily. This also allows using omp.loop_nest with worksharing for OpenMP.
Patch is 46.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/119482.diff
2 Files Affected:
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index b61f9767ccc2b8..2bb1a786f6c12c 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -17,6 +17,7 @@
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinDialect.h"
@@ -105,34 +106,47 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = sum.getLoc();
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
- hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
- assert(expr && "expected an expression type for the result of hlfir.sum");
- mlir::Type elementType = expr.getElementType();
+ mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
hlfir::Entity array = hlfir::Entity{sum.getArray()};
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
- int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
+ bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
+ int64_t dimVal =
+ isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
mlir::Value resultShape, dimExtent;
- std::tie(resultShape, dimExtent) =
- genResultShape(loc, builder, array, dimVal);
+ llvm::SmallVector<mlir::Value> arrayExtents;
+ if (isTotalReduction)
+ arrayExtents = genArrayExtents(loc, builder, array);
+ else
+ std::tie(resultShape, dimExtent) =
+ genResultShapeForPartialReduction(loc, builder, array, dimVal);
+
+ // If the mask is present and is a scalar, then we'd better load its value
+ // outside of the reduction loop making the loop unswitching easier.
+ mlir::Value isPresentPred, maskValue;
+ if (mask) {
+ if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
+ // MASK represented by a box might be dynamically optional,
+ // so we have to check for its presence before accessing it.
+ isPresentPred =
+ builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
+ }
+
+ if (hlfir::Entity{mask}.isScalar())
+ maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
+ }
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange inputIndices) -> hlfir::Entity {
// Loop over all indices in the DIM dimension, and reduce all values.
- // We do not need to create the reduction loop always: if we can
- // slice the input array given the inputIndices, then we can
- // just apply a new SUM operation (total reduction) to the slice.
- // For the time being, generate the explicit loop because the slicing
- // requires generating an elemental operation for the input array
- // (and the mask, if present).
- // TODO: produce the slices and new SUM after adding a pattern
- // for expanding total reduction SUM case.
- mlir::Type indexType = builder.getIndexType();
- auto one = builder.createIntegerConstant(loc, indexType, 1);
- auto ub = builder.createConvert(loc, indexType, dimExtent);
+ // If DIM is not present, do total reduction.
+ // Create temporary scalar for keeping the running reduction value.
+ mlir::Value reductionTemp =
+ builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
// Initial value for the reduction.
mlir::Value initValue = genInitValue(loc, builder, elementType);
+ builder.create<fir::StoreOp>(loc, initValue, reductionTemp);
// The reduction loop may be unordered if FastMathFlags::reassoc
// transformations are allowed. The integer reduction is always
@@ -141,42 +155,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
static_cast<bool>(sum.getFastmath() &
mlir::arith::FastMathFlags::reassoc);
- // If the mask is present and is a scalar, then we'd better load its value
- // outside of the reduction loop making the loop unswitching easier.
- // Maybe it is worth hoisting it from the elemental operation as well.
- mlir::Value isPresentPred, maskValue;
- if (mask) {
- if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
- // MASK represented by a box might be dynamically optional,
- // so we have to check for its presence before accessing it.
- isPresentPred =
- builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
- }
-
- if (hlfir::Entity{mask}.isScalar())
- maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
- }
+ llvm::SmallVector<mlir::Value> extents;
+ if (isTotalReduction)
+ extents = arrayExtents;
+ else
+ extents.push_back(
+ builder.createConvert(loc, builder.getIndexType(), dimExtent));
// NOTE: the outer elemental operation may be lowered into
// omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
// loop may appear disjoint from the workshare loop nest.
- // Moreover, the inner loop is not strictly nested (due to the reduction
- // starting value initialization), and the above omp dialect operations
- // cannot produce results.
- // It is unclear what we should do about it yet.
- auto doLoop = builder.create<fir::DoLoopOp>(
- loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
- mlir::ValueRange{initValue});
-
- // Address the input array using the reduction loop's IV
- // for the DIM dimension.
- mlir::Value iv = doLoop.getInductionVar();
- llvm::SmallVector<mlir::Value> indices{inputIndices};
- indices.insert(indices.begin() + dimVal - 1, iv);
-
- mlir::OpBuilder::InsertionGuard guard(builder);
- builder.setInsertionPointToStart(doLoop.getBody());
- mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
+ bool emitWorkshareLoop =
+ isTotalReduction ? flangomp::shouldUseWorkshareLowering(sum) : false;
+
+ hlfir::LoopNest loopNest = hlfir::genLoopNest(
+ loc, builder, extents, isUnordered, emitWorkshareLoop);
+
+ llvm::SmallVector<mlir::Value> indices;
+ if (isTotalReduction) {
+ indices = loopNest.oneBasedIndices;
+ } else {
+ indices = inputIndices;
+ indices.insert(indices.begin() + dimVal - 1,
+ loopNest.oneBasedIndices[0]);
+ }
+
+ builder.setInsertionPointToStart(loopNest.body);
fir::IfOp ifOp;
if (mask) {
// Make the reduction value update conditional on the value
@@ -188,16 +192,15 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
}
mlir::Value isUnmasked =
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
- ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
- /*withElseRegion=*/true);
- // In the 'else' block return the current reduction value.
- builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
- builder.create<fir::ResultOp>(loc, reductionValue);
+ ifOp = builder.create<fir::IfOp>(loc, isUnmasked,
+ /*withElseRegion=*/false);
// In the 'then' block do the actual addition.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
}
+ mlir::Value reductionValue =
+ builder.create<fir::LoadOp>(loc, reductionTemp);
hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
hlfir::Entity elementValue =
hlfir::loadTrivialScalar(loc, builder, element);
@@ -205,15 +208,18 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
// (e.g. when fast-math is not allowed), but let's start with
// the simple version.
reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
- builder.create<fir::ResultOp>(loc, reductionValue);
-
- if (ifOp) {
- builder.setInsertionPointAfter(ifOp);
- builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
- }
+ builder.create<fir::StoreOp>(loc, reductionValue, reductionTemp);
- return hlfir::Entity{doLoop.getResult(0)};
+ builder.setInsertionPointAfter(loopNest.outerOp);
+ return hlfir::Entity{builder.create<fir::LoadOp>(loc, reductionTemp)};
};
+
+ if (isTotalReduction) {
+ hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
+ rewriter.replaceOp(sum, result);
+ return mlir::success();
+ }
+
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
loc, builder, elementType, resultShape, {}, genKernel,
/*isUnordered=*/true, /*polymorphicMold=*/nullptr,
@@ -229,20 +235,29 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
}
private:
+ static llvm::SmallVector<mlir::Value>
+ genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
+ hlfir::Entity array) {
+ mlir::Value inShape = hlfir::genShape(loc, builder, array);
+ llvm::SmallVector<mlir::Value> inExtents =
+ hlfir::getExplicitExtentsFromShape(inShape, builder);
+ if (inShape.getUses().empty())
+ inShape.getDefiningOp()->erase();
+ return inExtents;
+ }
+
// Return fir.shape specifying the shape of the result
// of a SUM reduction with DIM=dimVal. The second return value
// is the extent of the DIM dimension.
static std::tuple<mlir::Value, mlir::Value>
- genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
- hlfir::Entity array, int64_t dimVal) {
- mlir::Value inShape = hlfir::genShape(loc, builder, array);
+ genResultShapeForPartialReduction(mlir::Location loc,
+ fir::FirOpBuilder &builder,
+ hlfir::Entity array, int64_t dimVal) {
llvm::SmallVector<mlir::Value> inExtents =
- hlfir::getExplicitExtentsFromShape(inShape, builder);
+ genArrayExtents(loc, builder, array);
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
- if (inShape.getUses().empty())
- inShape.getDefiningOp()->erase();
mlir::Value dimExtent = inExtents[dimVal - 1];
inExtents.erase(inExtents.begin() + dimVal - 1);
@@ -355,22 +370,22 @@ class SimplifyHLFIRIntrinsics
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
if (!simplifySum)
return true;
- if (mlir::Value dim = sum.getDim()) {
- if (auto dimVal = fir::getIntIfConstant(dim)) {
- if (!fir::isa_trivial(sum.getType())) {
- // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
- // It is only legal when X is 1, and it should probably be
- // canonicalized into SUM(a).
- fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
- hlfir::getFortranElementOrSequenceType(
- sum.getArray().getType()));
- if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
- // Ignore SUMs with illegal DIM values.
- // They may appear in dead code,
- // and they do not have to be converted.
- return false;
- }
- }
+
+ // Always inline total reductions.
+ if (hlfir::Entity{sum}.getRank() == 0)
+ return false;
+ mlir::Value dim = sum.getDim();
+ if (!dim)
+ return false;
+
+ if (auto dimVal = fir::getIntIfConstant(dim)) {
+ fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
+ hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
+ if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
+ // Ignore SUMs with illegal DIM values.
+ // They may appear in dead code,
+ // and they do not have to be converted.
+ return false;
}
}
return true;
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
index 54a592a66670f1..572b9f0da1e4ab 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -14,9 +14,12 @@ func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<2xi32> {
// CHECK: ^bb0(%[[VAL_6:.*]]: index):
-// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_7:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_3]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK: fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<i32>
+// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK: fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_3]] step %[[VAL_9]] unordered {
+// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
// CHECK: %[[VAL_12:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_13:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_12]] : (!fir.box<!fir.array<2x3xi32>>, index) -> (index, index, index)
// CHECK: %[[VAL_14:.*]] = arith.constant 1 : index
@@ -29,9 +32,10 @@ func.func @sum_box_known_extents(%arg0: !fir.box<!fir.array<2x3xi32>>) {
// CHECK: %[[VAL_21:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_18]], %[[VAL_20]]) : (!fir.box<!fir.array<2x3xi32>>, index, index) -> !fir.ref<i32>
// CHECK: %[[VAL_22:.*]] = fir.load %[[VAL_21]] : !fir.ref<i32>
// CHECK: %[[VAL_23:.*]] = arith.addi %[[VAL_11]], %[[VAL_22]] : i32
-// CHECK: fir.result %[[VAL_23]] : i32
+// CHECK: fir.store %[[VAL_23]] to %[[VAL_7]] : !fir.ref<i32>
// CHECK: }
-// CHECK: hlfir.yield_element %[[VAL_9]] : i32
+// CHECK: %[[VAL_24:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
+// CHECK: hlfir.yield_element %[[VAL_24]] : i32
// CHECK: }
// CHECK: return
// CHECK: }
@@ -50,14 +54,18 @@ func.func @sum_expr_known_extents(%arg0: !hlfir.expr<2x3xi32>) {
// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_5:.*]] = hlfir.elemental %[[VAL_4]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xi32> {
// CHECK: ^bb0(%[[VAL_6:.*]]: index):
-// CHECK: %[[VAL_7:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_7:.*]] = fir.alloca i32 {bindc_name = ".sum.reduction"}
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_9:.*]] = fir.do_loop %[[VAL_10:.*]] = %[[VAL_7]] to %[[VAL_2]] step %[[VAL_7]] unordered iter_args(%[[VAL_11:.*]] = %[[VAL_8]]) -> (i32) {
+// CHECK: fir.store %[[VAL_8]] to %[[VAL_7]] : !fir.ref<i32>
+// CHECK: %[[VAL_9:.*]] = arith.constant 1 : index
+// CHECK: fir.do_loop %[[VAL_10:.*]] = %[[VAL_9]] to %[[VAL_2]] step %[[VAL_9]] unordered {
+// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
// CHECK: %[[VAL_12:.*]] = hlfir.apply %[[VAL_0]], %[[VAL_10]], %[[VAL_6]] : (!hlfir.expr<2x3xi32>, index, index) -> i32
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
-// CHECK: fir.result %[[VAL_13]] : i32
+// CHECK: fir.store %[[VAL_13]] to %[[VAL_7]] : !fir.ref<i32>
// CHECK: }
-// CHECK: hlfir.yield_element %[[VAL_9]] : i32
+// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_7]] : !fir.ref<i32>
+// CHECK: hlfir.yield_element %[[VAL_14]] : i32
// CHECK: }
// CHECK: return
// CHECK: }
@@ -77,12 +85,15 @@ func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<3xcomplex<f64>> {
// CHECK: ^bb0(%[[VAL_7:.*]]: index):
-// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = fir.alloca complex<f64> {bindc_name = ".sum.reduction"}
// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[VAL_10:.*]] = fir.undefined complex<f64>
// CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
// CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_3]]#1 step %[[VAL_8]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (complex<f64>) {
+// CHECK: fir.store %[[VAL_12]] to %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK: %[[VAL_13:.*]] = arith.constant 1 : index
+// CHECK: fir.do_loop %[[VAL_14:.*]] = %[[VAL_13]] to %[[VAL_3]]#1 step %[[VAL_13]] {
+// CHECK: %[[VAL_15:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
// CHECK: %[[VAL_16:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_17:.*]]:3 = fir.box_dims %[[VAL_0]], %[[VAL_16]] : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index) -> (index, index, index)
// CHECK: %[[VAL_18:.*]] = arith.constant 1 : index
@@ -95,9 +106,10 @@ func.func @sum_box_unknown_extent1(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
// CHECK: %[[VAL_25:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_22]], %[[VAL_24]]) : (!fir.box<!fir.array<?x3xcomplex<f64>>>, index, index) -> !fir.ref<complex<f64>>
// CHECK: %[[VAL_26:.*]] = fir.load %[[VAL_25]] : !fir.ref<complex<f64>>
// CHECK: %[[VAL_27:.*]] = fir.addc %[[VAL_15]], %[[VAL_26]] : complex<f64>
-// CHECK: fir.result %[[VAL_27]] : complex<f64>
+// CHECK: fir.store %[[VAL_27]] to %[[VAL_8]] : !fir.ref<complex<f64>>
// CHECK: }
-// CHECK: hlfir.yield_element %[[VAL_13]] : complex<f64>
+// CHECK: %[[VAL_28:.*]] = fir.load %[[VAL_8]] : !fir.ref<complex<f64>>
+// CHECK: hlfir.yield_element %[[VAL_28]] : complex<f64>
// CHECK: }
// CHECK: return
// CHECK: }
@@ -116,12 +128,15 @@ func.func @sum_box_unknown_extent2(%arg0: !fir.box<!fir.array<?x3xcomplex<f64>>>
// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_3]]#1 : (index) -> !fir.shape<1>
// CHECK: %[[VAL_6:.*]] = hlfir.elemental %[[VAL_5]] unordered : (!fir.shape<1>) -> !hlfir.expr<?xcomplex<f64>> {
// CHECK: ^bb0(%[[VAL_7:.*]]: index):
-// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_8:.*]] = fir.alloca complex<f64> {bindc_name = ".sum.reduction"}
// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f64
// CHECK: %[[VAL_10:.*]] = fir.undefined complex<f64>
// CHECK: %[[VAL_11:.*]] = fir.insert_value %[[VAL_10]], %[[VAL_9]], [0 : index] : (complex<f64>, f64) -> complex<f64>
// CHECK: %[[VAL_12:.*]] = fir.insert_value %[[VAL_11]], %[[VAL_9]], [1 : index] : (complex<f64>, f64) -> complex<f64>
-// CHECK: %[[VAL_13:.*]] = fir.do_loop %[[VAL_14:.*]] = %[[VAL_8]] to %[[VAL_4]] step %...
[truncated]
flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
// Create temporary scalar for keeping the running reduction value.
mlir::Value reductionTemp =
    builder.createTemporaryAlloc(loc, elementType, ".sum.reduction");
@jeanPerier, what do you think about calling this outside of genKernel? It looks like it results in stacksave/stackrestore in the stack reclaim pass (after the elemental is transformed into loops), which is not ideal. I think it should be safe to hoist this call provided that the initializing store is kept inside the elemental.
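For context, a rough sketch of the pattern being described (assuming the stack reclaim pass brackets loop-local allocas with the LLVM stack intrinsics; op spelling approximated):

// Elemental after conversion to loops: the temporary lives inside the
// loop body, so the reclaim pass saves/restores the stack each iteration.
fir.do_loop %i = %c1 to %n step %c1 {
  %sp = llvm.intr.stacksave : !llvm.ptr
  %tmp = fir.alloca i32 {bindc_name = ".sum.reduction"}
  // ... inner reduction loop loads/stores %tmp ...
  llvm.intr.stackrestore %sp : !llvm.ptr
}

Hoisting the alloca above the loop (while keeping the initializing store inside) would make the per-iteration save/restore unnecessary.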
Makes some sense to me. The only impact I see is that it may make parallelization of SUM(, DIM) harder, which is otherwise trivial (each thread reduces into an element of the result array).
Can you try what happens with a SUM(, DIM) inside a workshare construct? Since you enable the rewrite to an elemental, I think the elemental-to-omp-loop conversion should kick in, and hoisting the alloca may be bad there.
Maybe the stack reclaim pass should hoist constant-size allocas outside of loops (with the assumption that parallelization of the loops has happened at that point), at least for scalars. This may have an impact on the stack size of course, but for scalars that should be limited.
Since SUM(, DIM) was not parallelized before anyway, your solution would still be acceptable to me though.
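For reference, the suggested experiment corresponds to something like r = sum(a, dim=1) inside !$omp workshare; a rough MLIR shape before this pass runs (types and names are illustrative assumptions):

omp.parallel {
  omp.workshare {
    // SUM with DIM yields a rank-1 array; this pass rewrites it into an
    // hlfir.elemental, which the workshare lowering can turn into
    // omp.workshare.loop_wrapper/omp.loop_nest.
    %r = hlfir.sum %a dim %dim : (!fir.box<!fir.array<?x?xi32>>, i32) -> !hlfir.expr<?xi32>
    omp.terminator
  }
  omp.terminator
}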
fir::FirOpBuilder::getAllocaBlock understands OpenMP operations and should give safe insertion points for moving allocas in the stack reclaim pass. OpenMP parallelisation will all have happened by then.
If OpenMP workshare support is blocking optimizations in earlier passes, please let me know and I will see if I can rethink the design.
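A sketch of where such an insertion point lands inside an OpenMP region (my reading; the exact layout is an assumption): the alloca moves to the entry of the outlined region, so each thread still gets its own temporary:

omp.parallel {
  // getAllocaBlock points here: one allocation per thread, done once,
  // with no per-iteration stack save/restore.
  %tmp = fir.alloca i32 {bindc_name = ".sum.reduction"}
  omp.wsloop {
    omp.loop_nest (%i) : i32 = (%lb) to (%ub) inclusive step (%step) {
      // ... reduction body loads/stores %tmp ...
      omp.yield
    }
  }
  omp.terminator
}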
LGTM
Thanks for the fix!
I tried the following test:
When I place the alloca inside the elemental, I get the following:
So it seems that we end up allocating the temporary in each thread, and the stack is automatically reclaimed after exiting. When I place the alloca outside the elemental:
So besides that, it seems to me hoisting the alloca should be okay, but it might not always be enough to get rid of stacksave/stackrestore (e.g. if we generate the elemental inside another loop, we will end up with the stack bookkeeping anyway). I am now inclined to go back to the SSA reduction, though modifying hlfir::genLoopNest for it. What do you think?
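A sketch of the trade-off being weighed (my reading; layout assumed): once the elemental becomes a worksharing loop, an alloca hoisted out of the loop is no longer per-iteration, so iterations that may run on different threads would share one temporary:

omp.workshare {
  // Hoisted out of the loop: a single temporary shared by all iterations.
  %tmp = fir.alloca i32 {bindc_name = ".sum.reduction"}
  omp.workshare.loop_wrapper {
    omp.loop_nest (%i) : i32 = (%lb) to (%ub) inclusive step (%step) {
      // Each iteration initializes and accumulates into %tmp, so updates
      // from different threads could interleave.
      omp.yield
    }
  }
  omp.terminator
}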
Thanks for taking the time to take an in-depth look at OpenMP. I am worried the added […]. I agree that […].
Pinging @ivanradanov and @Thirumalai-Shaktivel from AMD
I added hlfir::genLoopNestWithReductions. I think […].
This looks great to me. I don't want to hold this up on OpenMP concerns.
I can roughly imagine how this could be implemented for OpenMP (not for this patch). I think ultimately we will need the genBody callback to operate on scalars so that we can generate the right reduction clause implementation (or, alternatively, have an argument saying what the OpenMP intrinsic reduction kind is, then use the lowering code to generate the omp.declare_reduction).
I'm imagining something like:
omp.declare_reduction @something : type init {
^bb0(%unused: type):
  %init_val = ... // The arith.constant passed as the init value for the reduction variable
  omp.yield(%init_val)
} combiner {
^bb0(%lhs: type, %rhs: type):
  // generated by the genBody callback, e.g. for type == i32:
  %res = arith.addi %lhs, %rhs : i32
  omp.yield(%res)
}

func.func [...] {
  %fortran_variable:2 = hlfir.declare [...]
  omp.wsloop reduction(@something %fortran_variable#0 -> %arg0 : !fir.ref<type>) {
    omp.loop_nest [...] { // indices are %arg1...%argn
      %privatized_variable = hlfir.declare %arg0 [...]
      // generated by genLoopWithReductions:
      %rhs = hlfir.designate [...]
      // generated by genBody:
      %res = operation %privatized_variable %rhs : type
      omp.yield
    }
  }
  // The result has been stored to %fortran_variable
}
This would give some loss of generality on the indexing. Maybe there should be two callbacks: one for indexing and one for the "combiner" op. IMO this can be done in a later patch.
Thank you for the example, Tom! Yes, it should be possible to split the body generator into parts so that it works for generating the combiner code and the rhs value code.