Skip to content

[flang] Do not inline SUM with invalid DIM argument. #118911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 9, 2024

Conversation

vzakhari
Copy link
Contributor

@vzakhari vzakhari commented Dec 6, 2024

Such SUMs might appear in dead code after constant propagation.
They do not have to be inlined.

Such SUMs might appear in dead code after constant propagation.
They do not have to be inlined.
@vzakhari vzakhari requested a review from jeanPerier December 6, 2024 02:10
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Dec 6, 2024
@llvmbot
Copy link
Member

llvmbot commented Dec 6, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Slava Zakharin (vzakhari)

Changes

Such SUMs might appear in dead code after constant propagation.
They do not have to be inlined.


Full diff: https://github.com/llvm/llvm-project/pull/118911.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp (+13-3)
  • (modified) flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir (+18)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
index 0c34c8221aeda6..ace63a970db932 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
@@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
     mlir::Value mask = sum.getMask();
     mlir::Value dim = sum.getDim();
     int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
-    assert(dimVal > 0 && "DIM must be present and a positive constant");
     mlir::Value resultShape, dimExtent;
     std::tie(resultShape, dimExtent) =
         genResultShape(loc, builder, array, dimVal);
@@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
     mlir::Value inShape = hlfir::genShape(loc, builder, array);
     llvm::SmallVector<mlir::Value> inExtents =
         hlfir::getExplicitExtentsFromShape(inShape, builder);
+    assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
+           "DIM must be present and a positive constant not exceeding "
+           "the array's rank");
     if (inShape.getUses().empty())
       inShape.getDefiningOp()->erase();
 
@@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics
     // would avoid creating a temporary for the elemental array expression.
     target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
       if (mlir::Value dim = sum.getDim()) {
-        if (fir::getIntIfConstant(dim)) {
+        if (auto dimVal = fir::getIntIfConstant(dim)) {
           if (!fir::isa_trivial(sum.getType())) {
             // Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
             // It is only legal when X is 1, and it should probably be
             // canonicalized into SUM(a).
-            return false;
+            fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
+                hlfir::getFortranElementOrSequenceType(
+                    sum.getArray().getType()));
+            if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
+              // Ignore SUMs with illegal DIM values.
+              // They may appear in dead code,
+              // and they do not have to be converted.
+              return false;
+            }
           }
         }
       }
diff --git a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
index 703b6673154f3f..313e54d5d0c4af 100644
--- a/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
+++ b/flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
@@ -411,3 +411,21 @@ func.func @sum_non_const_dim(%arg0: !fir.box<!fir.array<3xi32>>, %dim: i32) {
 // CHECK:           %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
 // CHECK:           return
 // CHECK:         }
+
+// negative: invalid dim==0
+func.func @sum_invalid_dim0(%arg0: !hlfir.expr<2x3xi32>) {
+  %cst = arith.constant 0 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_invalid_dim0(
+// CHECK:           hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+
+// negative: invalid dim>rank
+func.func @sum_invalid_dim_big(%arg0: !hlfir.expr<2x3xi32>) {
+  %cst = arith.constant 3 : i32
+  %res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
+  return
+}
+// CHECK-LABEL:   func.func @sum_invalid_dim_big(
+// CHECK:           hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>

Copy link
Contributor

@jeanPerier jeanPerier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@vzakhari vzakhari merged commit 084451c into llvm:main Dec 9, 2024
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants