Skip to content

[flang] Added fir.is_contiguous_box and fir.box_total_elements ops. #131047

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion flang/include/flang/Optimizer/Builder/Runtime/Inquiry.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,15 @@ mlir::Value genSize(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value genSizeDim(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value array, mlir::Value dim);

/// Generate call to `Is_contiguous` runtime routine.
/// Generate call to `IsContiguous` runtime routine.
mlir::Value genIsContiguous(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value array);

/// Generate call to `IsContiguousUpTo` runtime routine.
/// \p dim specifies the dimension up to which contiguity
/// needs to be checked (not exceeding the actual rank of the array).
mlir::Value genIsContiguousUpTo(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value array, mlir::Value dim);

} // namespace fir::runtime
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_INQUIRY_H
30 changes: 30 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3416,4 +3416,34 @@ def fir_UnpackArrayOp
let hasVerifier = 1;
}

def fir_IsContiguousBoxOp : fir_Op<"is_contiguous_box", [NoMemoryEffect]> {
let summary = "Returns true if the boxed entity is contiguous";
let description = [{
Returns true iff the boxed entity is contiguous:
* in the leading dimension (if `innermost` attribute is set),
* in all dimensions (if `innermost` attribute is not set).

The input box cannot be absent.
}];
let arguments = (ins AnyBoxLike:$box, UnitAttr:$innermost);
let results = (outs I1);

let assemblyFormat = [{
$box (`innermost` $innermost^):(`whole`)? attr-dict `:` functional-type(operands, results)
}];
let hasCanonicalizer = 1;
}

def fir_BoxTotalElementsOp
: fir_SimpleOneResultOp<"box_total_elements", [NoMemoryEffect]> {
let summary = "Returns the boxed entity's total size in elements";
let description = [{
Returns the boxed entity's total size in elements.
The input box cannot be absent.
}];
let arguments = (ins AnyBoxLike:$box);
let results = (outs AnyIntegerLike);
let hasCanonicalizer = 1;
}

#endif
4 changes: 4 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ namespace fir {
#define GEN_PASS_DECL_COMPILERGENERATEDNAMESCONVERSION
#define GEN_PASS_DECL_SETRUNTIMECALLATTRIBUTES
#define GEN_PASS_DECL_GENRUNTIMECALLSFORTEST
#define GEN_PASS_DECL_SIMPLIFYFIROPERATIONS

#include "flang/Optimizer/Transforms/Passes.h.inc"

Expand All @@ -86,6 +87,9 @@ void populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
bool forceLoopToExecuteOnce = false,
bool setNSW = true);

void populateSimplifyFIROperationsPatterns(mlir::RewritePatternSet &patterns,
bool preferInlineImplementation);

// declarative passes
#define GEN_PASS_REGISTRATION
#include "flang/Optimizer/Transforms/Passes.h.inc"
Expand Down
14 changes: 14 additions & 0 deletions flang/include/flang/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -486,4 +486,18 @@ def GenRuntimeCallsForTest
let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect"];
}

def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
let summary = "Simplifies complex FIR operations";
let description = [{
Expands complex FIR operations into their equivalent using
FIR, SCF and other usual dialects. It may also generate calls
to Fortran runtime.
}];

let options = [Option<
"preferInlineImplementation", "prefer-inline-implementation", "bool",
/*default=*/"false",
"Prefer expanding without using Fortran runtime calls.">];
}

#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES
4 changes: 4 additions & 0 deletions flang/include/flang/Runtime/support.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ extern "C" {
// Predicate: is the storage described by a Descriptor contiguous in memory?
bool RTDECL(IsContiguous)(const Descriptor &);

// Predicate: is the storage described by a Descriptor contiguous in memory
// up to the given dimension?
bool RTDECL(IsContiguousUpTo)(const Descriptor &, int);

// Predicate: is this descriptor describing an assumed-size array?
bool RTDECL(IsAssumedSize)(const Descriptor &);

Expand Down
14 changes: 13 additions & 1 deletion flang/lib/Optimizer/Builder/Runtime/Inquiry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ mlir::Value fir::runtime::genSize(fir::FirOpBuilder &builder,
return builder.create<fir::CallOp>(loc, sizeFunc, args).getResult(0);
}

/// Generate call to `Is_contiguous` runtime routine.
/// Generate call to `IsContiguous` runtime routine.
mlir::Value fir::runtime::genIsContiguous(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value array) {
Expand All @@ -102,6 +102,18 @@ mlir::Value fir::runtime::genIsContiguous(fir::FirOpBuilder &builder,
return builder.create<fir::CallOp>(loc, isContiguousFunc, args).getResult(0);
}

/// Generate call to `IsContiguousUpTo` runtime routine.
mlir::Value fir::runtime::genIsContiguousUpTo(fir::FirOpBuilder &builder,
mlir::Location loc,
mlir::Value array,
mlir::Value dim) {
mlir::func::FuncOp isContiguousFunc =
fir::runtime::getRuntimeFunc<mkRTKey(IsContiguousUpTo)>(loc, builder);
auto fTy = isContiguousFunc.getFunctionType();
auto args = fir::runtime::createArguments(builder, loc, fTy, array, dim);
return builder.create<fir::CallOp>(loc, isContiguousFunc, args).getResult(0);
}

void fir::runtime::genShape(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value resultAddr, mlir::Value array,
mlir::Value kind) {
Expand Down
77 changes: 77 additions & 0 deletions flang/lib/Optimizer/Dialect/FIROps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4671,6 +4671,83 @@ void fir::UnpackArrayOp::getEffects(
mlir::SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// IsContiguousBoxOp
//===----------------------------------------------------------------------===//

namespace {
struct SimplifyIsContiguousBoxOp
: public mlir::OpRewritePattern<fir::IsContiguousBoxOp> {
using mlir::OpRewritePattern<fir::IsContiguousBoxOp>::OpRewritePattern;
mlir::LogicalResult
matchAndRewrite(fir::IsContiguousBoxOp op,
mlir::PatternRewriter &rewriter) const override;
};
} // namespace

mlir::LogicalResult SimplifyIsContiguousBoxOp::matchAndRewrite(
fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const {
auto boxType = mlir::cast<fir::BaseBoxType>(op.getBox().getType());
// Nothing to do for assumed-rank arrays and !fir.box<none>.
if (boxType.isAssumedRank() || fir::isBoxNone(boxType))
return mlir::failure();

if (fir::getBoxRank(boxType) == 0) {
// Scalars are always contiguous.
mlir::Type i1Type = rewriter.getI1Type();
rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
return mlir::success();
}

// TODO: support more patterns, e.g. a result of fir.embox without
// the slice is contiguous. We can add fir::isSimplyContiguous(box)
// that walks def-use to figure it out.
return mlir::failure();
}

void fir::IsContiguousBoxOp::getCanonicalizationPatterns(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
patterns.add<SimplifyIsContiguousBoxOp>(context);
}

//===----------------------------------------------------------------------===//
// BoxTotalElementsOp
//===----------------------------------------------------------------------===//

namespace {
struct SimplifyBoxTotalElementsOp
: public mlir::OpRewritePattern<fir::BoxTotalElementsOp> {
using mlir::OpRewritePattern<fir::BoxTotalElementsOp>::OpRewritePattern;
mlir::LogicalResult
matchAndRewrite(fir::BoxTotalElementsOp op,
mlir::PatternRewriter &rewriter) const override;
};
} // namespace

mlir::LogicalResult SimplifyBoxTotalElementsOp::matchAndRewrite(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could also canonicalize this for cases where the number of elements is known at compile time inside of the type e.g. !fir.box<!fir.array<10xi32>>

fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const {
auto boxType = mlir::cast<fir::BaseBoxType>(op.getBox().getType());
// Nothing to do for assumed-rank arrays and !fir.box<none>.
if (boxType.isAssumedRank() || fir::isBoxNone(boxType))
return mlir::failure();

if (fir::getBoxRank(boxType) == 0) {
// Scalar: 1 element.
rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
op, op.getType(), rewriter.getIntegerAttr(op.getType(), 1));
return mlir::success();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Result of fir.embox without slices can also fold to true (there are other cases, but it is a good start and some fir::isSimplyConintiguous(box) helper could be added later to try to walk back more, which could be helpful after inlining).


// TODO: support more cases, e.g. !fir.box<!fir.array<10xi32>>.
return mlir::failure();
}

void fir::BoxTotalElementsOp::getCanonicalizationPatterns(
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
patterns.add<SimplifyBoxTotalElementsOp>(context);
}

//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions flang/lib/Optimizer/Passes/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
pm.addPass(fir::createPolymorphicOpConversion());
pm.addPass(fir::createAssumedRankOpConversion());

// Expand FIR operations that may use SCF dialect for their
// implementation. This is a mandatory pass.
pm.addPass(fir::createSimplifyFIROperations(
{/*preferInlineImplementation=*/pc.OptLevel.isOptimizingForSpeed()}));

if (pc.AliasAnalysis && !disableFirAliasTags && !useOldAliasTags)
pm.addPass(fir::createAddAliasTags());

Expand Down
1 change: 1 addition & 0 deletions flang/lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_flang_library(FIRTransforms
DebugTypeGenerator.cpp
SetRuntimeCallAttributes.cpp
GenRuntimeCallsForTest.cpp
SimplifyFIROperations.cpp

DEPENDS
CUFAttrs
Expand Down
145 changes: 145 additions & 0 deletions flang/lib/Optimizer/Transforms/SimplifyFIROperations.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
//===- SimplifyFIROperations.cpp -- simplify complex FIR operations ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
/// \file
/// This pass transforms some FIR operations into their equivalent
/// implementations using other FIR operations. The transformation
/// can legally use SCF dialect and generate Fortran runtime calls.
//===----------------------------------------------------------------------===//

#include "flang/Optimizer/Builder/FIRBuilder.h"
#include "flang/Optimizer/Builder/Runtime/Inquiry.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Optimizer/Dialect/FIROps.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace fir {
#define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS
#include "flang/Optimizer/Transforms/Passes.h.inc"
} // namespace fir

#define DEBUG_TYPE "flang-simplify-fir-operations"

namespace {
/// Pass runner.
class SimplifyFIROperationsPass
: public fir::impl::SimplifyFIROperationsBase<SimplifyFIROperationsPass> {
public:
using fir::impl::SimplifyFIROperationsBase<
SimplifyFIROperationsPass>::SimplifyFIROperationsBase;

void runOnOperation() override final;
};

/// Base class for all conversions holding the pass options.
template <typename Op>
class ConversionBase : public mlir::OpRewritePattern<Op> {
public:
using mlir::OpRewritePattern<Op>::OpRewritePattern;

template <typename... Args>
ConversionBase(mlir::MLIRContext *context, Args &&...args)
: mlir::OpRewritePattern<Op>(context),
options{std::forward<Args>(args)...} {}

mlir::LogicalResult matchAndRewrite(Op,
mlir::PatternRewriter &) const override;

protected:
fir::SimplifyFIROperationsOptions options;
};

/// fir::IsContiguousBoxOp converter.
using IsContiguousBoxCoversion = ConversionBase<fir::IsContiguousBoxOp>;

/// fir::BoxTotalElementsOp converter.
using BoxTotalElementsConversion = ConversionBase<fir::BoxTotalElementsOp>;
} // namespace

/// Generate a call to IsContiguous/IsContiguousUpTo function or an inline
/// sequence reading extents/strides from the box and checking them.
/// This conversion may produce fir.box_elesize and a loop (for assumed
/// rank).
template <>
mlir::LogicalResult IsContiguousBoxCoversion::matchAndRewrite(
fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const {
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, op.getOperation());
// TODO: support preferInlineImplementation.
bool doInline = options.preferInlineImplementation && false;
if (!doInline) {
// Generate Fortran runtime call.
mlir::Value result;
if (op.getInnermost()) {
mlir::Value one =
builder.createIntegerConstant(loc, builder.getI32Type(), 1);
result =
fir::runtime::genIsContiguousUpTo(builder, loc, op.getBox(), one);
} else {
result = fir::runtime::genIsContiguous(builder, loc, op.getBox());
}
result = builder.createConvert(loc, op.getType(), result);
rewriter.replaceOp(op, result);
return mlir::success();
}

// Generate inline implementation.
TODO(loc, "inline IsContiguousBoxOp");
return mlir::failure();
}

/// Generate a call to Size runtime function or an inline
/// sequence reading extents from the box an multiplying them.
/// This conversion may produce a loop (for assumed rank).
template <>
mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite(
fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const {
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, op.getOperation());
// TODO: support preferInlineImplementation.
// Reading the extent from the box for 1D arrays probably
// results in less code than the call, so we can always
// inline it.
bool doInline = options.preferInlineImplementation && false;
if (!doInline) {
// Generate Fortran runtime call.
mlir::Value result = fir::runtime::genSize(builder, loc, op.getBox());
result = builder.createConvert(loc, op.getType(), result);
rewriter.replaceOp(op, result);
return mlir::success();
}

// Generate inline implementation.
TODO(loc, "inline BoxTotalElementsOp");
return mlir::failure();
}

void SimplifyFIROperationsPass::runOnOperation() {
mlir::ModuleOp module = getOperation();
mlir::MLIRContext &context = getContext();
mlir::RewritePatternSet patterns(&context);
fir::populateSimplifyFIROperationsPatterns(patterns,
preferInlineImplementation);
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;

if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
mlir::emitError(module.getLoc(), DEBUG_TYPE " pass failed");
signalPassFailure();
}
}

void fir::populateSimplifyFIROperationsPatterns(
mlir::RewritePatternSet &patterns, bool preferInlineImplementation) {
patterns.insert<IsContiguousBoxCoversion, BoxTotalElementsConversion>(
patterns.getContext(), preferInlineImplementation);
}
1 change: 1 addition & 0 deletions flang/test/Driver/bbc-mlir-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

! CHECK-NEXT: PolymorphicOpConversion
! CHECK-NEXT: AssumedRankOpConversion
! CHECK-NEXT: SimplifyFIROperations

! CHECK-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
! CHECK-NEXT: 'fir.global' Pipeline
Expand Down
1 change: 1 addition & 0 deletions flang/test/Driver/mlir-debug-pass-pipeline.f90
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@

! ALL-NEXT: PolymorphicOpConversion
! ALL-NEXT: AssumedRankOpConversion
! ALL-NEXT: SimplifyFIROperations

! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
! ALL-NEXT: 'fir.global' Pipeline
Expand Down
Loading