Skip to content

Commit 2402bcc

Browse files
authored
[flang] Turn SimplifyHLFIRIntrinsics into a greedy rewriter. (#119946)
This is almost an NFC, except that folding changed ordering of some operations.
1 parent f239922 commit 2402bcc

File tree

4 files changed

+242
-331
lines changed

4 files changed

+242
-331
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 58 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
#include "flang/Optimizer/HLFIR/HLFIROps.h"
2020
#include "flang/Optimizer/HLFIR/Passes.h"
2121
#include "mlir/Dialect/Arith/IR/Arith.h"
22-
#include "mlir/Dialect/Func/IR/FuncOps.h"
23-
#include "mlir/IR/BuiltinDialect.h"
2422
#include "mlir/IR/Location.h"
2523
#include "mlir/Pass/Pass.h"
26-
#include "mlir/Transforms/DialectConversion.h"
24+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2725

2826
namespace hlfir {
2927
#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
@@ -45,9 +43,15 @@ class TransposeAsElementalConversion
4543
llvm::LogicalResult
4644
matchAndRewrite(hlfir::TransposeOp transpose,
4745
mlir::PatternRewriter &rewriter) const override {
46+
hlfir::ExprType expr = transpose.getType();
47+
// TODO: hlfir.elemental supports polymorphic data types now,
48+
// so this can be supported.
49+
if (expr.isPolymorphic())
50+
return rewriter.notifyMatchFailure(transpose,
51+
"TRANSPOSE of polymorphic type");
52+
4853
mlir::Location loc = transpose.getLoc();
4954
fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
50-
hlfir::ExprType expr = transpose.getType();
5155
mlir::Type elementType = expr.getElementType();
5256
hlfir::Entity array = hlfir::Entity{transpose.getArray()};
5357
mlir::Value resultShape = genResultShape(loc, builder, array);
@@ -105,15 +109,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
105109
llvm::LogicalResult
106110
matchAndRewrite(hlfir::SumOp sum,
107111
mlir::PatternRewriter &rewriter) const override {
112+
if (!simplifySum)
113+
return rewriter.notifyMatchFailure(sum, "SUM simplification is disabled");
114+
115+
hlfir::Entity array = hlfir::Entity{sum.getArray()};
116+
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
117+
mlir::Value dim = sum.getDim();
118+
int64_t dimVal = 0;
119+
if (!isTotalReduction) {
120+
// In case of partial reduction we should ignore the operations
121+
// with invalid DIM values. They may appear in dead code
122+
// after constant propagation.
123+
auto constDim = fir::getIntIfConstant(dim);
124+
if (!constDim)
125+
return rewriter.notifyMatchFailure(sum, "Nonconstant DIM for SUM");
126+
dimVal = *constDim;
127+
128+
if ((dimVal <= 0 || dimVal > array.getRank()))
129+
return rewriter.notifyMatchFailure(
130+
sum, "Invalid DIM for partial SUM reduction");
131+
}
132+
108133
mlir::Location loc = sum.getLoc();
109134
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
110135
mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
111-
hlfir::Entity array = hlfir::Entity{sum.getArray()};
112136
mlir::Value mask = sum.getMask();
113-
mlir::Value dim = sum.getDim();
114-
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
115-
int64_t dimVal =
116-
isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
137+
117138
mlir::Value resultShape, dimExtent;
118139
llvm::SmallVector<mlir::Value> arrayExtents;
119140
if (isTotalReduction)
@@ -360,27 +381,38 @@ class CShiftAsElementalConversion
360381
public:
361382
using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;
362383

363-
explicit CShiftAsElementalConversion(mlir::MLIRContext *ctx)
364-
: OpRewritePattern(ctx) {
365-
setHasBoundedRewriteRecursion();
366-
}
367-
368384
llvm::LogicalResult
369385
matchAndRewrite(hlfir::CShiftOp cshift,
370386
mlir::PatternRewriter &rewriter) const override {
371387
using Fortran::common::maxRank;
372388

373-
mlir::Location loc = cshift.getLoc();
374-
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
375389
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
376390
assert(expr &&
377391
"expected an expression type for the result of hlfir.cshift");
392+
unsigned arrayRank = expr.getRank();
393+
// When it is a 1D CSHIFT, we may assume that the DIM argument
394+
// (whether it is present or absent) is equal to 1, otherwise,
395+
// the program is illegal.
396+
int64_t dimVal = 1;
397+
if (arrayRank != 1)
398+
if (mlir::Value dim = cshift.getDim()) {
399+
auto constDim = fir::getIntIfConstant(dim);
400+
if (!constDim)
401+
return rewriter.notifyMatchFailure(cshift,
402+
"Nonconstant DIM for CSHIFT");
403+
dimVal = *constDim;
404+
}
405+
406+
if (dimVal <= 0 || dimVal > arrayRank)
407+
return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT");
408+
409+
mlir::Location loc = cshift.getLoc();
410+
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
378411
mlir::Type elementType = expr.getElementType();
379412
hlfir::Entity array = hlfir::Entity{cshift.getArray()};
380413
mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
381414
llvm::SmallVector<mlir::Value> arrayExtents =
382415
hlfir::getExplicitExtentsFromShape(arrayShape, builder);
383-
unsigned arrayRank = expr.getRank();
384416
llvm::SmallVector<mlir::Value, 1> typeParams;
385417
hlfir::genLengthParameters(loc, builder, array, typeParams);
386418
hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
@@ -395,20 +427,6 @@ class CShiftAsElementalConversion
395427
shiftVal = builder.createConvert(loc, calcType, shiftVal);
396428
}
397429

398-
int64_t dimVal = 1;
399-
if (arrayRank == 1) {
400-
// When it is a 1D CSHIFT, we may assume that the DIM argument
401-
// (whether it is present or absent) is equal to 1, otherwise,
402-
// the program is illegal.
403-
assert(shiftVal && "SHIFT must be scalar");
404-
} else {
405-
if (mlir::Value dim = cshift.getDim())
406-
dimVal = fir::getIntIfConstant(dim).value_or(0);
407-
assert(dimVal > 0 && dimVal <= arrayRank &&
408-
"DIM must be present and a positive constant not exceeding "
409-
"the array's rank");
410-
}
411-
412430
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
413431
mlir::ValueRange inputIndices) -> hlfir::Entity {
414432
llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
@@ -462,68 +480,19 @@ class SimplifyHLFIRIntrinsics
462480
public:
463481
void runOnOperation() override {
464482
mlir::MLIRContext *context = &getContext();
483+
484+
mlir::GreedyRewriteConfig config;
485+
// Prevent the pattern driver from merging blocks
486+
config.enableRegionSimplification =
487+
mlir::GreedySimplifyRegionLevel::Disabled;
488+
465489
mlir::RewritePatternSet patterns(context);
466490
patterns.insert<TransposeAsElementalConversion>(context);
467491
patterns.insert<SumAsElementalConversion>(context);
468492
patterns.insert<CShiftAsElementalConversion>(context);
469-
mlir::ConversionTarget target(*context);
470-
// don't transform transpose of polymorphic arrays (not currently supported
471-
// by hlfir.elemental)
472-
target.addDynamicallyLegalOp<hlfir::TransposeOp>(
473-
[](hlfir::TransposeOp transpose) {
474-
return mlir::cast<hlfir::ExprType>(transpose.getType())
475-
.isPolymorphic();
476-
});
477-
// Handle only SUM(DIM=CONSTANT) case for now.
478-
// It may be beneficial to expand the non-DIM case as well.
479-
// E.g. when the input array is an elemental array expression,
480-
// expanding the SUM into a total reduction loop nest
481-
// would avoid creating a temporary for the elemental array expression.
482-
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
483-
if (!simplifySum)
484-
return true;
485-
486-
// Always inline total reductions.
487-
if (hlfir::Entity{sum}.getRank() == 0)
488-
return false;
489-
mlir::Value dim = sum.getDim();
490-
if (!dim)
491-
return false;
492-
493-
if (auto dimVal = fir::getIntIfConstant(dim)) {
494-
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
495-
hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
496-
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
497-
// Ignore SUMs with illegal DIM values.
498-
// They may appear in dead code,
499-
// and they do not have to be converted.
500-
return false;
501-
}
502-
}
503-
return true;
504-
});
505-
target.addDynamicallyLegalOp<hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
506-
unsigned resultRank = hlfir::Entity{cshift}.getRank();
507-
if (resultRank == 1)
508-
return false;
509-
510-
mlir::Value dim = cshift.getDim();
511-
if (!dim)
512-
return false;
513-
514-
// If DIM is present, then it must be constant to please
515-
// the conversion. In addition, ignore cases with
516-
// illegal DIM values.
517-
if (auto dimVal = fir::getIntIfConstant(dim))
518-
if (*dimVal > 0 && *dimVal <= resultRank)
519-
return false;
520-
521-
return true;
522-
});
523-
target.markUnknownOpDynamicallyLegal(
524-
[](mlir::Operation *) { return true; });
525-
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
526-
std::move(patterns)))) {
493+
494+
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
495+
getOperation(), std::move(patterns), config))) {
527496
mlir::emitError(getOperation()->getLoc(),
528497
"failure in HLFIR intrinsic simplification");
529498
signalPassFailure();

0 commit comments

Comments
 (0)