Skip to content

[flang] Turn SimplifyHLFIRIntrinsics into a greedy rewriter. #119946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 58 additions & 89 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Location.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace hlfir {
#define GEN_PASS_DEF_SIMPLIFYHLFIRINTRINSICS
Expand All @@ -44,9 +42,15 @@ class TransposeAsElementalConversion
llvm::LogicalResult
matchAndRewrite(hlfir::TransposeOp transpose,
mlir::PatternRewriter &rewriter) const override {
hlfir::ExprType expr = transpose.getType();
// TODO: hlfir.elemental supports polymorphic data types now,
// so this can be supported.
if (expr.isPolymorphic())
return rewriter.notifyMatchFailure(transpose,
"TRANSPOSE of polymorphic type");

mlir::Location loc = transpose.getLoc();
fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
hlfir::ExprType expr = transpose.getType();
mlir::Type elementType = expr.getElementType();
hlfir::Entity array = hlfir::Entity{transpose.getArray()};
mlir::Value resultShape = genResultShape(loc, builder, array);
Expand Down Expand Up @@ -104,15 +108,32 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
llvm::LogicalResult
matchAndRewrite(hlfir::SumOp sum,
mlir::PatternRewriter &rewriter) const override {
if (!simplifySum)
return rewriter.notifyMatchFailure(sum, "SUM simplification is disabled");

hlfir::Entity array = hlfir::Entity{sum.getArray()};
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
mlir::Value dim = sum.getDim();
int64_t dimVal = 0;
if (!isTotalReduction) {
// In case of partial reduction we should ignore the operations
// with invalid DIM values. They may appear in dead code
// after constant propagation.
auto constDim = fir::getIntIfConstant(dim);
if (!constDim)
return rewriter.notifyMatchFailure(sum, "Nonconstant DIM for SUM");
dimVal = *constDim;

if ((dimVal <= 0 || dimVal > array.getRank()))
return rewriter.notifyMatchFailure(
sum, "Invalid DIM for partial SUM reduction");
}

mlir::Location loc = sum.getLoc();
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
hlfir::Entity array = hlfir::Entity{sum.getArray()};
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
int64_t dimVal =
isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);

mlir::Value resultShape, dimExtent;
llvm::SmallVector<mlir::Value> arrayExtents;
if (isTotalReduction)
Expand Down Expand Up @@ -359,27 +380,38 @@ class CShiftAsElementalConversion
public:
using mlir::OpRewritePattern<hlfir::CShiftOp>::OpRewritePattern;

explicit CShiftAsElementalConversion(mlir::MLIRContext *ctx)
: OpRewritePattern(ctx) {
setHasBoundedRewriteRecursion();
}

llvm::LogicalResult
matchAndRewrite(hlfir::CShiftOp cshift,
mlir::PatternRewriter &rewriter) const override {
using Fortran::common::maxRank;

mlir::Location loc = cshift.getLoc();
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(cshift.getType());
assert(expr &&
"expected an expression type for the result of hlfir.cshift");
unsigned arrayRank = expr.getRank();
// When it is a 1D CSHIFT, we may assume that the DIM argument
// (whether it is present or absent) is equal to 1, otherwise,
// the program is illegal.
int64_t dimVal = 1;
if (arrayRank != 1)
if (mlir::Value dim = cshift.getDim()) {
auto constDim = fir::getIntIfConstant(dim);
if (!constDim)
return rewriter.notifyMatchFailure(cshift,
"Nonconstant DIM for CSHIFT");
dimVal = *constDim;
}

if (dimVal <= 0 || dimVal > arrayRank)
return rewriter.notifyMatchFailure(cshift, "Invalid DIM for CSHIFT");

mlir::Location loc = cshift.getLoc();
fir::FirOpBuilder builder{rewriter, cshift.getOperation()};
mlir::Type elementType = expr.getElementType();
hlfir::Entity array = hlfir::Entity{cshift.getArray()};
mlir::Value arrayShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> arrayExtents =
hlfir::getExplicitExtentsFromShape(arrayShape, builder);
unsigned arrayRank = expr.getRank();
llvm::SmallVector<mlir::Value, 1> typeParams;
hlfir::genLengthParameters(loc, builder, array, typeParams);
hlfir::Entity shift = hlfir::Entity{cshift.getShift()};
Expand All @@ -394,20 +426,6 @@ class CShiftAsElementalConversion
shiftVal = builder.createConvert(loc, calcType, shiftVal);
}

int64_t dimVal = 1;
if (arrayRank == 1) {
// When it is a 1D CSHIFT, we may assume that the DIM argument
// (whether it is present or absent) is equal to 1, otherwise,
// the program is illegal.
assert(shiftVal && "SHIFT must be scalar");
} else {
if (mlir::Value dim = cshift.getDim())
dimVal = fir::getIntIfConstant(dim).value_or(0);
assert(dimVal > 0 && dimVal <= arrayRank &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
}

auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange inputIndices) -> hlfir::Entity {
llvm::SmallVector<mlir::Value, maxRank> indices{inputIndices};
Expand Down Expand Up @@ -461,68 +479,19 @@ class SimplifyHLFIRIntrinsics
public:
void runOnOperation() override {
mlir::MLIRContext *context = &getContext();

mlir::GreedyRewriteConfig config;
// Prevent the pattern driver from merging blocks
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;

mlir::RewritePatternSet patterns(context);
patterns.insert<TransposeAsElementalConversion>(context);
patterns.insert<SumAsElementalConversion>(context);
patterns.insert<CShiftAsElementalConversion>(context);
mlir::ConversionTarget target(*context);
// don't transform transpose of polymorphic arrays (not currently supported
// by hlfir.elemental)
target.addDynamicallyLegalOp<hlfir::TransposeOp>(
[](hlfir::TransposeOp transpose) {
return mlir::cast<hlfir::ExprType>(transpose.getType())
.isPolymorphic();
});
// Handle only SUM(DIM=CONSTANT) case for now.
// It may be beneficial to expand the non-DIM case as well.
// E.g. when the input array is an elemental array expression,
// expanding the SUM into a total reduction loop nest
// would avoid creating a temporary for the elemental array expression.
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
if (!simplifySum)
return true;

// Always inline total reductions.
if (hlfir::Entity{sum}.getRank() == 0)
return false;
mlir::Value dim = sum.getDim();
if (!dim)
return false;

if (auto dimVal = fir::getIntIfConstant(dim)) {
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}
return true;
});
target.addDynamicallyLegalOp<hlfir::CShiftOp>([](hlfir::CShiftOp cshift) {
unsigned resultRank = hlfir::Entity{cshift}.getRank();
if (resultRank == 1)
return false;

mlir::Value dim = cshift.getDim();
if (!dim)
return false;

// If DIM is present, then it must be constant to please
// the conversion. In addition, ignore cases with
// illegal DIM values.
if (auto dimVal = fir::getIntIfConstant(dim))
if (*dimVal > 0 && *dimVal <= resultRank)
return false;

return true;
});
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });
if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
std::move(patterns)))) {

if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR intrinsic simplification");
signalPassFailure();
Expand Down
Loading
Loading