Skip to content

[flang] Do not inline SUM with invalid DIM argument. #118911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
assert(dimVal > 0 && "DIM must be present and a positive constant");
mlir::Value resultShape, dimExtent;
std::tie(resultShape, dimExtent) =
genResultShape(loc, builder, array, dimVal);
Expand Down Expand Up @@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();

Expand Down Expand Up @@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics
// would avoid creating a temporary for the elemental array expression.
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
if (mlir::Value dim = sum.getDim()) {
if (fir::getIntIfConstant(dim)) {
if (auto dimVal = fir::getIntIfConstant(dim)) {
if (!fir::isa_trivial(sum.getType())) {
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
// It is only legal when X is 1, and it should probably be
// canonicalized into SUM(a).
return false;
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(
sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions flang/test/HLFIR/simplify-hlfir-intrinsics-sum.fir
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,21 @@ func.func @sum_non_const_dim(%arg0: !fir.box<!fir.array<3xi32>>, %dim: i32) {
// CHECK: %[[VAL_2:.*]] = hlfir.sum %[[VAL_0]] dim %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>, i32) -> i32
// CHECK: return
// CHECK: }

// negative: invalid dim==0
func.func @sum_invalid_dim0(%arg0: !hlfir.expr<2x3xi32>) {
%cst = arith.constant 0 : i32
%res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
return
}
// CHECK-LABEL: func.func @sum_invalid_dim0(
// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>

// negative: invalid dim>rank
func.func @sum_invalid_dim_big(%arg0: !hlfir.expr<2x3xi32>) {
%cst = arith.constant 3 : i32
%res = hlfir.sum %arg0 dim %cst : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
return
}
// CHECK-LABEL: func.func @sum_invalid_dim_big(
// CHECK: hlfir.sum %{{.*}} dim %{{.*}} : (!hlfir.expr<2x3xi32>, i32) -> !hlfir.expr<3xi32>
Loading