Skip to content

Commit 084451c

Browse files
authored
[flang] Do not inline SUM with invalid DIM argument. (#118911)
Such SUMs might appear in dead code after constant propagation. They do not have to be inlined.
1 parent 1ca3927 commit 084451c

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
108108
mlir::Value mask = sum.getMask();
109109
mlir::Value dim = sum.getDim();
110110
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
111-
assert(dimVal > 0 && "DIM must be present and a positive constant");
112111
mlir::Value resultShape, dimExtent;
113112
std::tie(resultShape, dimExtent) =
114113
genResultShape(loc, builder, array, dimVal);
@@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
235234
mlir::Value inShape = hlfir::genShape(loc, builder, array);
236235
llvm::SmallVector<mlir::Value> inExtents =
237236
hlfir::getExplicitExtentsFromShape(inShape, builder);
237+
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
238+
"DIM must be present and a positive constant not exceeding "
239+
"the array's rank");
238240
if (inShape.getUses().empty())
239241
inShape.getDefiningOp()->erase();
240242

@@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics
348350
// would avoid creating a temporary for the elemental array expression.
349351
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
350352
if (mlir::Value dim = sum.getDim()) {
351-
if (fir::getIntIfConstant(dim)) {
353+
if (auto dimVal = fir::getIntIfConstant(dim)) {
352354
if (!fir::isa_trivial(sum.getType())) {
353355
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
354356
// It is only legal when X is 1, and it should probably be
355357
// canonicalized into SUM(a).
356-
return false;
358+
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
359+
hlfir::getFortranElementOrSequenceType(
360+
sum.getArray().getType()));
361+
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
362+
// Ignore SUMs with illegal DIM values.
363+
// They may appear in dead code,
364+
// and they do not have to be converted.
365+
return false;
366+
}
357367
}
358368
}
359369
}

flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,3 +411,21 @@ func.func @sum_non_const_dim(%arg0: !fir.box<!fir.array<3xi32>>, %dim: i32) {
411411
// CHECK: %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
412412
// CHECK: return
413413
// CHECK: }
414+
415+
// negative: invalid dim==0
416+
func.func @sum_invalid_dim0(%arg0: !hlfir.expr<2x3xi32>) {
417+
%cst = arith.constant 0 : i32
418+
%res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
419+
return
420+
}
421+
// CHECK-LABEL: func.func @sum_invalid_dim0(
422+
// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
423+
424+
// negative: invalid dim>rank
425+
func.func @sum_invalid_dim_big(%arg0: !hlfir.expr<2x3xi32>) {
426+
%cst = arith.constant 3 : i32
427+
%res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
428+
return
429+
}
430+
// CHECK-LABEL: func.func @sum_invalid_dim_big(
431+
// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>

0 commit comments

Comments
 (0)