Skip to content

[flang][hlfir][NFC] Fix mlir misuse in LowerHLFIRIntrinsics #83293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 51 additions & 48 deletions flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ namespace {
/// Base class for passes converting transformational intrinsic operations into
/// runtime calls
template <class OP>
class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
class HlfirIntrinsicConversion : public mlir::OpConversionPattern<OP> {
public:
explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
: mlir::OpRewritePattern<OP>{ctx} {
: mlir::OpConversionPattern<OP>{ctx} {
// required for cases where intrinsics are chained together e.g.
// matmul(matmul(a, b), c)
// because converting the inner operation then invalidates the
Expand Down Expand Up @@ -145,7 +145,7 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
void processReturnValue(mlir::Operation *op,
const fir::ExtendedValue &resultExv, bool mustBeFreed,
fir::FirOpBuilder &builder,
mlir::PatternRewriter &rewriter) const {
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Location loc = op->getLoc();

mlir::Value firBase = fir::getBase(resultExv);
Expand Down Expand Up @@ -176,13 +176,9 @@ class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
rewriter.eraseOp(use);
}
}
// TODO: This entire pass should be a greedy pattern rewrite or a manual
// IR traversal. A dialect conversion cannot be used here because
// `replaceAllUsesWith` is not supported. Similarly, `replaceOp` is not
// suitable because "op->getResult(0)" and "base" can have different types.
// In such a case, the dialect conversion will attempt to convert the type,
// but no type converter is specified in this pass. Also note that all
// patterns in this pass are actually rewrite patterns.
// the types might not match exactly (but are safe)
// e.g. !hlfir.expr<?xi32> vs !hlfir.expr<2xi32>
// TODO: is this allowed by MLIR?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think this is allowed. (As long as you don't produce an invalid operation, everything's fine.)

The problem in the current implementation is that op->getResult(0).replaceAllUsesWith(base); bypasses the rewriter API, i.e., making an IR modification without going through the rewriter. The rewriter equivalent would be RewriterBase::replaceAllUsesWith. But that function is not supported by the conversion pattern rewriter yet.

But it should also not be necessary:

    op->getResult(0).replaceAllUsesWith(base);
    rewriter.replaceOp(op, base);

rewriter.replaceOp already performs the replacement, so the previous replaceAllUsesWith is redundant. I tried removing the replaceAllUsesWith but then the dialect conversion no longer succeeds. I think the problem is that whenever the type changes during a rewriter.replaceOp, the dialect conversion tries to build a type conversion op with the type converter. In this case here, no type conversion is needed, but that's not something that the dialect conversion considers. The error that I saw was due to the fact that no type converter is specified. (Take a look at the ConversionOpPattern constructor, you can pass a type converter.) One thing you could try is passing a type converter that always returns the same SSA value, i.e., does not convert anything. Not sure if it will work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your help.

I don't think it works. After following your suggestion I am getting null operand errors from operations referring to the replaced value. It looks like rewriter.replaceOp(op, base) is not replacing cases where the type changed (or maybe the old versions of the operation usage aren't removed before this check runs?).

I guess the right solution will be to use greedy pattern rewriter after all, even though this pass is conceptually performing a partial dialect conversion.

op->getResult(0).replaceAllUsesWith(base);
rewriter.replaceOp(op, base);
}
Expand All @@ -203,48 +199,53 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
typename HlfirIntrinsicConversion<OP>::IntrinsicArgument;
using HlfirIntrinsicConversion<OP>::lowerArguments;
using HlfirIntrinsicConversion<OP>::processReturnValue;
using Adaptor = typename OP::Adaptor;

protected:
auto buildNumericalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
auto buildNumericalArgs(mlir::Operation *operation, Adaptor adaptor,
mlir::Type i32, mlir::Type logicalType,
mlir::PatternRewriter &rewriter,
std::string opName) const {
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
inArgs.push_back({operation.getArray(), operation.getArray().getType()});
inArgs.push_back({operation.getDim(), i32});
inArgs.push_back({operation.getMask(), logicalType});
inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()});
inArgs.push_back({adaptor.getDim(), i32});
inArgs.push_back({adaptor.getMask(), logicalType});
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
return lowerArguments(operation, inArgs, rewriter, argLowering);
};

auto buildMinMaxLocArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
auto buildMinMaxLocArgs(mlir::Operation *operation, Adaptor adaptor,
mlir::Type i32, mlir::Type logicalType,
mlir::PatternRewriter &rewriter, std::string opName,
fir::FirOpBuilder builder) const {
llvm::SmallVector<IntrinsicArgument, 3> inArgs;
inArgs.push_back({operation.getArray(), operation.getArray().getType()});
inArgs.push_back({operation.getDim(), i32});
inArgs.push_back({operation.getMask(), logicalType});
inArgs.push_back({adaptor.getArray(), adaptor.getArray().getType()});
inArgs.push_back({adaptor.getDim(), i32});
inArgs.push_back({adaptor.getMask(), logicalType});
mlir::Value kind = builder.createIntegerConstant(
operation->getLoc(), i32, getKindForType(operation.getType()));
operation->getLoc(), i32,
getKindForType(operation->getResult(0).getType()));
inArgs.push_back({kind, i32});
inArgs.push_back({operation.getBack(), i32});
inArgs.push_back({adaptor.getBack(), i32});
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
return lowerArguments(operation, inArgs, rewriter, argLowering);
};

auto buildLogicalArgs(OP operation, mlir::Type i32, mlir::Type logicalType,
auto buildLogicalArgs(mlir::Operation *operation, Adaptor adaptor,
mlir::Type i32, mlir::Type logicalType,
mlir::PatternRewriter &rewriter,
std::string opName) const {
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
inArgs.push_back({operation.getMask(), logicalType});
inArgs.push_back({operation.getDim(), i32});
inArgs.push_back({adaptor.getMask(), logicalType});
inArgs.push_back({adaptor.getDim(), i32});
auto *argLowering = fir::getIntrinsicArgumentLowering(opName);
return lowerArguments(operation, inArgs, rewriter, argLowering);
};

public:
mlir::LogicalResult
matchAndRewrite(OP operation,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(OP operation, Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
std::string opName;
if constexpr (std::is_same_v<OP, hlfir::SumOp>) {
opName = "sum";
Expand Down Expand Up @@ -279,13 +280,15 @@ class HlfirReductionIntrinsicConversion : public HlfirIntrinsicConversion<OP> {
std::is_same_v<OP, hlfir::ProductOp> ||
std::is_same_v<OP, hlfir::MaxvalOp> ||
std::is_same_v<OP, hlfir::MinvalOp>) {
args = buildNumericalArgs(operation, i32, logicalType, rewriter, opName);
args = buildNumericalArgs(operation, adaptor, i32, logicalType, rewriter,
opName);
} else if constexpr (std::is_same_v<OP, hlfir::MinlocOp> ||
std::is_same_v<OP, hlfir::MaxlocOp>) {
args = buildMinMaxLocArgs(operation, i32, logicalType, rewriter, opName,
builder);
args = buildMinMaxLocArgs(operation, adaptor, i32, logicalType, rewriter,
opName, builder);
} else {
args = buildLogicalArgs(operation, i32, logicalType, rewriter, opName);
args = buildLogicalArgs(operation, adaptor, i32, logicalType, rewriter,
opName);
}

mlir::Type scalarResultType =
Expand Down Expand Up @@ -319,8 +322,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
using HlfirIntrinsicConversion<hlfir::CountOp>::HlfirIntrinsicConversion;

mlir::LogicalResult
matchAndRewrite(hlfir::CountOp count,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(hlfir::CountOp count, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, count.getOperation()};
const mlir::Location &loc = count->getLoc();

Expand All @@ -329,8 +332,8 @@ struct CountOpConversion : public HlfirIntrinsicConversion<hlfir::CountOp> {
builder.getContext(), builder.getKindMap().defaultLogicalKind());

llvm::SmallVector<IntrinsicArgument, 3> inArgs;
inArgs.push_back({count.getMask(), logicalType});
inArgs.push_back({count.getDim(), i32});
inArgs.push_back({adaptor.getMask(), logicalType});
inArgs.push_back({adaptor.getDim(), i32});
mlir::Value kind = builder.createIntegerConstant(
count->getLoc(), i32, getKindForType(count.getType()));
inArgs.push_back({kind, i32});
Expand All @@ -353,13 +356,13 @@ struct MatmulOpConversion : public HlfirIntrinsicConversion<hlfir::MatmulOp> {
using HlfirIntrinsicConversion<hlfir::MatmulOp>::HlfirIntrinsicConversion;

mlir::LogicalResult
matchAndRewrite(hlfir::MatmulOp matmul,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(hlfir::MatmulOp matmul, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, matmul.getOperation()};
const mlir::Location &loc = matmul->getLoc();

mlir::Value lhs = matmul.getLhs();
mlir::Value rhs = matmul.getRhs();
mlir::Value lhs = adaptor.getLhs();
mlir::Value rhs = adaptor.getRhs();
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
inArgs.push_back({lhs, lhs.getType()});
inArgs.push_back({rhs, rhs.getType()});
Expand All @@ -384,13 +387,13 @@ struct DotProductOpConversion
using HlfirIntrinsicConversion<hlfir::DotProductOp>::HlfirIntrinsicConversion;

mlir::LogicalResult
matchAndRewrite(hlfir::DotProductOp dotProduct,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(hlfir::DotProductOp dotProduct, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, dotProduct.getOperation()};
const mlir::Location &loc = dotProduct->getLoc();

mlir::Value lhs = dotProduct.getLhs();
mlir::Value rhs = dotProduct.getRhs();
mlir::Value lhs = adaptor.getLhs();
mlir::Value rhs = adaptor.getRhs();
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
inArgs.push_back({lhs, lhs.getType()});
inArgs.push_back({rhs, rhs.getType()});
Expand All @@ -415,12 +418,12 @@ class TransposeOpConversion
using HlfirIntrinsicConversion<hlfir::TransposeOp>::HlfirIntrinsicConversion;

mlir::LogicalResult
matchAndRewrite(hlfir::TransposeOp transpose,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(hlfir::TransposeOp transpose, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, transpose.getOperation()};
const mlir::Location &loc = transpose->getLoc();

mlir::Value arg = transpose.getArray();
mlir::Value arg = adaptor.getArray();
llvm::SmallVector<IntrinsicArgument, 1> inArgs;
inArgs.push_back({arg, arg.getType()});

Expand All @@ -445,13 +448,13 @@ struct MatmulTransposeOpConversion
hlfir::MatmulTransposeOp>::HlfirIntrinsicConversion;

mlir::LogicalResult
matchAndRewrite(hlfir::MatmulTransposeOp multranspose,
mlir::PatternRewriter &rewriter) const override {
matchAndRewrite(hlfir::MatmulTransposeOp multranspose, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
fir::FirOpBuilder builder{rewriter, multranspose.getOperation()};
const mlir::Location &loc = multranspose->getLoc();

mlir::Value lhs = multranspose.getLhs();
mlir::Value rhs = multranspose.getRhs();
mlir::Value lhs = adaptor.getLhs();
mlir::Value rhs = adaptor.getRhs();
llvm::SmallVector<IntrinsicArgument, 2> inArgs;
inArgs.push_back({lhs, lhs.getType()});
inArgs.push_back({rhs, rhs.getType()});
Expand Down