@@ -108,7 +108,6 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
108
108
mlir::Value mask = sum.getMask ();
109
109
mlir::Value dim = sum.getDim ();
110
110
int64_t dimVal = fir::getIntIfConstant (dim).value_or (0 );
111
- assert (dimVal > 0 && " DIM must be present and a positive constant" );
112
111
mlir::Value resultShape, dimExtent;
113
112
std::tie (resultShape, dimExtent) =
114
113
genResultShape (loc, builder, array, dimVal);
@@ -235,6 +234,9 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
235
234
mlir::Value inShape = hlfir::genShape (loc, builder, array);
236
235
llvm::SmallVector<mlir::Value> inExtents =
237
236
hlfir::getExplicitExtentsFromShape (inShape, builder);
237
+ assert (dimVal > 0 && dimVal <= static_cast <int64_t >(inExtents.size ()) &&
238
+ " DIM must be present and a positive constant not exceeding "
239
+ " the array's rank" );
238
240
if (inShape.getUses ().empty ())
239
241
inShape.getDefiningOp ()->erase ();
240
242
@@ -348,12 +350,20 @@ class SimplifyHLFIRIntrinsics
348
350
// would avoid creating a temporary for the elemental array expression.
349
351
target.addDynamicallyLegalOp <hlfir::SumOp>([](hlfir::SumOp sum) {
350
352
if (mlir::Value dim = sum.getDim ()) {
351
- if (fir::getIntIfConstant (dim)) {
353
+ if (auto dimVal = fir::getIntIfConstant (dim)) {
352
354
if (!fir::isa_trivial (sum.getType ())) {
353
355
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
354
356
// It is only legal when X is 1, and it should probably be
355
357
// canonicalized into SUM(a).
356
- return false ;
358
+ fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
359
+ hlfir::getFortranElementOrSequenceType (
360
+ sum.getArray ().getType ()));
361
+ if (*dimVal > 0 && *dimVal <= arrayTy.getDimension ()) {
362
+ // Ignore SUMs with illegal DIM values.
363
+ // They may appear in dead code,
364
+ // and they do not have to be converted.
365
+ return false ;
366
+ }
357
367
}
358
368
}
359
369
}
0 commit comments