Skip to content

Commit 00f9c85

Browse files
authored
[flang] Added fir.is_contiguous_box and fir.box_total_elements ops. (#131047)
These are helper operations to aid with expanding of fir.pack_array.
1 parent 52cd27e commit 00f9c85

File tree

19 files changed

+566
-2
lines changed

19 files changed

+566
-2
lines changed

flang/include/flang/Optimizer/Builder/Runtime/Inquiry.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,15 @@ mlir::Value genSize(fir::FirOpBuilder &builder, mlir::Location loc,
5050
mlir::Value genSizeDim(fir::FirOpBuilder &builder, mlir::Location loc,
5151
mlir::Value array, mlir::Value dim);
5252

53-
/// Generate call to `Is_contiguous` runtime routine.
53+
/// Generate call to `IsContiguous` runtime routine.
5454
mlir::Value genIsContiguous(fir::FirOpBuilder &builder, mlir::Location loc,
5555
mlir::Value array);
5656

57+
/// Generate call to `IsContiguousUpTo` runtime routine.
58+
/// \p dim specifies the dimension up to which contiguity
59+
/// needs to be checked (not exceeding the actual rank of the array).
60+
mlir::Value genIsContiguousUpTo(fir::FirOpBuilder &builder, mlir::Location loc,
61+
mlir::Value array, mlir::Value dim);
62+
5763
} // namespace fir::runtime
5864
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_INQUIRY_H

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3416,4 +3416,34 @@ def fir_UnpackArrayOp
34163416
let hasVerifier = 1;
34173417
}
34183418

3419+
def fir_IsContiguousBoxOp : fir_Op<"is_contiguous_box", [NoMemoryEffect]> {
3420+
let summary = "Returns true if the boxed entity is contiguous";
3421+
let description = [{
3422+
Returns true iff the boxed entity is contiguous:
3423+
* in the leading dimension (if `innermost` attribute is set),
3424+
* in all dimensions (if `innermost` attribute is not set).
3425+
3426+
The input box cannot be absent.
3427+
}];
3428+
let arguments = (ins AnyBoxLike:$box, UnitAttr:$innermost);
3429+
let results = (outs I1);
3430+
3431+
let assemblyFormat = [{
3432+
$box (`innermost` $innermost^):(`whole`)? attr-dict `:` functional-type(operands, results)
3433+
}];
3434+
let hasCanonicalizer = 1;
3435+
}
3436+
3437+
def fir_BoxTotalElementsOp
3438+
: fir_SimpleOneResultOp<"box_total_elements", [NoMemoryEffect]> {
3439+
let summary = "Returns the boxed entity's total size in elements";
3440+
let description = [{
3441+
Returns the boxed entity's total size in elements.
3442+
The input box cannot be absent.
3443+
}];
3444+
let arguments = (ins AnyBoxLike:$box);
3445+
let results = (outs AnyIntegerLike);
3446+
let hasCanonicalizer = 1;
3447+
}
3448+
34193449
#endif

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ namespace fir {
6262
#define GEN_PASS_DECL_COMPILERGENERATEDNAMESCONVERSION
6363
#define GEN_PASS_DECL_SETRUNTIMECALLATTRIBUTES
6464
#define GEN_PASS_DECL_GENRUNTIMECALLSFORTEST
65+
#define GEN_PASS_DECL_SIMPLIFYFIROPERATIONS
6566

6667
#include "flang/Optimizer/Transforms/Passes.h.inc"
6768

@@ -86,6 +87,9 @@ void populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
8687
bool forceLoopToExecuteOnce = false,
8788
bool setNSW = true);
8889

90+
void populateSimplifyFIROperationsPatterns(mlir::RewritePatternSet &patterns,
91+
bool preferInlineImplementation);
92+
8993
// declarative passes
9094
#define GEN_PASS_REGISTRATION
9195
#include "flang/Optimizer/Transforms/Passes.h.inc"

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,4 +486,18 @@ def GenRuntimeCallsForTest
486486
let dependentDialects = ["fir::FIROpsDialect", "mlir::func::FuncDialect"];
487487
}
488488

489+
def SimplifyFIROperations : Pass<"simplify-fir-operations", "mlir::ModuleOp"> {
490+
let summary = "Simplifies complex FIR operations";
491+
let description = [{
492+
Expands complex FIR operations into their equivalent using
493+
FIR, SCF and other usual dialects. It may also generate calls
494+
to Fortran runtime.
495+
}];
496+
497+
let options = [Option<
498+
"preferInlineImplementation", "prefer-inline-implementation", "bool",
499+
/*default=*/"false",
500+
"Prefer expanding without using Fortran runtime calls.">];
501+
}
502+
489503
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

flang/include/flang/Runtime/support.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ extern "C" {
3434
// Predicate: is the storage described by a Descriptor contiguous in memory?
3535
bool RTDECL(IsContiguous)(const Descriptor &);
3636

37+
// Predicate: is the storage described by a Descriptor contiguous in memory
38+
// up to the given dimension?
39+
bool RTDECL(IsContiguousUpTo)(const Descriptor &, int);
40+
3741
// Predicate: is this descriptor describing an assumed-size array?
3842
bool RTDECL(IsAssumedSize)(const Descriptor &);
3943

flang/lib/Optimizer/Builder/Runtime/Inquiry.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ mlir::Value fir::runtime::genSize(fir::FirOpBuilder &builder,
9191
return builder.create<fir::CallOp>(loc, sizeFunc, args).getResult(0);
9292
}
9393

94-
/// Generate call to `Is_contiguous` runtime routine.
94+
/// Generate call to `IsContiguous` runtime routine.
9595
mlir::Value fir::runtime::genIsContiguous(fir::FirOpBuilder &builder,
9696
mlir::Location loc,
9797
mlir::Value array) {
@@ -102,6 +102,18 @@ mlir::Value fir::runtime::genIsContiguous(fir::FirOpBuilder &builder,
102102
return builder.create<fir::CallOp>(loc, isContiguousFunc, args).getResult(0);
103103
}
104104

105+
/// Generate call to `IsContiguousUpTo` runtime routine.
106+
mlir::Value fir::runtime::genIsContiguousUpTo(fir::FirOpBuilder &builder,
107+
mlir::Location loc,
108+
mlir::Value array,
109+
mlir::Value dim) {
110+
mlir::func::FuncOp isContiguousFunc =
111+
fir::runtime::getRuntimeFunc<mkRTKey(IsContiguousUpTo)>(loc, builder);
112+
auto fTy = isContiguousFunc.getFunctionType();
113+
auto args = fir::runtime::createArguments(builder, loc, fTy, array, dim);
114+
return builder.create<fir::CallOp>(loc, isContiguousFunc, args).getResult(0);
115+
}
116+
105117
void fir::runtime::genShape(fir::FirOpBuilder &builder, mlir::Location loc,
106118
mlir::Value resultAddr, mlir::Value array,
107119
mlir::Value kind) {

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4671,6 +4671,83 @@ void fir::UnpackArrayOp::getEffects(
46714671
mlir::SideEffects::DefaultResource::get());
46724672
}
46734673

4674+
//===----------------------------------------------------------------------===//
4675+
// IsContiguousBoxOp
4676+
//===----------------------------------------------------------------------===//
4677+
4678+
namespace {
4679+
struct SimplifyIsContiguousBoxOp
4680+
: public mlir::OpRewritePattern<fir::IsContiguousBoxOp> {
4681+
using mlir::OpRewritePattern<fir::IsContiguousBoxOp>::OpRewritePattern;
4682+
mlir::LogicalResult
4683+
matchAndRewrite(fir::IsContiguousBoxOp op,
4684+
mlir::PatternRewriter &rewriter) const override;
4685+
};
4686+
} // namespace
4687+
4688+
mlir::LogicalResult SimplifyIsContiguousBoxOp::matchAndRewrite(
4689+
fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const {
4690+
auto boxType = mlir::cast<fir::BaseBoxType>(op.getBox().getType());
4691+
// Nothing to do for assumed-rank arrays and !fir.box<none>.
4692+
if (boxType.isAssumedRank() || fir::isBoxNone(boxType))
4693+
return mlir::failure();
4694+
4695+
if (fir::getBoxRank(boxType) == 0) {
4696+
// Scalars are always contiguous.
4697+
mlir::Type i1Type = rewriter.getI1Type();
4698+
rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
4699+
op, i1Type, rewriter.getIntegerAttr(i1Type, 1));
4700+
return mlir::success();
4701+
}
4702+
4703+
// TODO: support more patterns, e.g. a result of fir.embox without
4704+
// the slice is contiguous. We can add fir::isSimplyContiguous(box)
4705+
// that walks def-use to figure it out.
4706+
return mlir::failure();
4707+
}
4708+
4709+
void fir::IsContiguousBoxOp::getCanonicalizationPatterns(
4710+
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
4711+
patterns.add<SimplifyIsContiguousBoxOp>(context);
4712+
}
4713+
4714+
//===----------------------------------------------------------------------===//
4715+
// BoxTotalElementsOp
4716+
//===----------------------------------------------------------------------===//
4717+
4718+
namespace {
4719+
struct SimplifyBoxTotalElementsOp
4720+
: public mlir::OpRewritePattern<fir::BoxTotalElementsOp> {
4721+
using mlir::OpRewritePattern<fir::BoxTotalElementsOp>::OpRewritePattern;
4722+
mlir::LogicalResult
4723+
matchAndRewrite(fir::BoxTotalElementsOp op,
4724+
mlir::PatternRewriter &rewriter) const override;
4725+
};
4726+
} // namespace
4727+
4728+
mlir::LogicalResult SimplifyBoxTotalElementsOp::matchAndRewrite(
4729+
fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const {
4730+
auto boxType = mlir::cast<fir::BaseBoxType>(op.getBox().getType());
4731+
// Nothing to do for assumed-rank arrays and !fir.box<none>.
4732+
if (boxType.isAssumedRank() || fir::isBoxNone(boxType))
4733+
return mlir::failure();
4734+
4735+
if (fir::getBoxRank(boxType) == 0) {
4736+
// Scalar: 1 element.
4737+
rewriter.replaceOpWithNewOp<mlir::arith::ConstantOp>(
4738+
op, op.getType(), rewriter.getIntegerAttr(op.getType(), 1));
4739+
return mlir::success();
4740+
}
4741+
4742+
// TODO: support more cases, e.g. !fir.box<!fir.array<10xi32>>.
4743+
return mlir::failure();
4744+
}
4745+
4746+
void fir::BoxTotalElementsOp::getCanonicalizationPatterns(
4747+
mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) {
4748+
patterns.add<SimplifyBoxTotalElementsOp>(context);
4749+
}
4750+
46744751
//===----------------------------------------------------------------------===//
46754752
// FIROpsDialect
46764753
//===----------------------------------------------------------------------===//

flang/lib/Optimizer/Passes/Pipelines.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ void createDefaultFIROptimizerPassPipeline(mlir::PassManager &pm,
198198
pm.addPass(fir::createPolymorphicOpConversion());
199199
pm.addPass(fir::createAssumedRankOpConversion());
200200

201+
// Expand FIR operations that may use SCF dialect for their
202+
// implementation. This is a mandatory pass.
203+
pm.addPass(fir::createSimplifyFIROperations(
204+
{/*preferInlineImplementation=*/pc.OptLevel.isOptimizingForSpeed()}));
205+
201206
if (pc.AliasAnalysis && !disableFirAliasTags && !useOldAliasTags)
202207
pm.addPass(fir::createAddAliasTags());
203208

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ add_flang_library(FIRTransforms
3131
DebugTypeGenerator.cpp
3232
SetRuntimeCallAttributes.cpp
3333
GenRuntimeCallsForTest.cpp
34+
SimplifyFIROperations.cpp
3435

3536
DEPENDS
3637
CUFAttrs
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
//===- SimplifyFIROperations.cpp -- simplify complex FIR operations ------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
//===----------------------------------------------------------------------===//
10+
/// \file
11+
/// This pass transforms some FIR operations into their equivalent
12+
/// implementations using other FIR operations. The transformation
13+
/// can legally use SCF dialect and generate Fortran runtime calls.
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "flang/Optimizer/Builder/FIRBuilder.h"
17+
#include "flang/Optimizer/Builder/Runtime/Inquiry.h"
18+
#include "flang/Optimizer/Builder/Todo.h"
19+
#include "flang/Optimizer/Dialect/FIROps.h"
20+
#include "flang/Optimizer/Transforms/Passes.h"
21+
#include "mlir/Pass/Pass.h"
22+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23+
24+
namespace fir {
25+
#define GEN_PASS_DEF_SIMPLIFYFIROPERATIONS
26+
#include "flang/Optimizer/Transforms/Passes.h.inc"
27+
} // namespace fir
28+
29+
#define DEBUG_TYPE "flang-simplify-fir-operations"
30+
31+
namespace {
32+
/// Pass runner.
33+
class SimplifyFIROperationsPass
34+
: public fir::impl::SimplifyFIROperationsBase<SimplifyFIROperationsPass> {
35+
public:
36+
using fir::impl::SimplifyFIROperationsBase<
37+
SimplifyFIROperationsPass>::SimplifyFIROperationsBase;
38+
39+
void runOnOperation() override final;
40+
};
41+
42+
/// Base class for all conversions holding the pass options.
43+
template <typename Op>
44+
class ConversionBase : public mlir::OpRewritePattern<Op> {
45+
public:
46+
using mlir::OpRewritePattern<Op>::OpRewritePattern;
47+
48+
template <typename... Args>
49+
ConversionBase(mlir::MLIRContext *context, Args &&...args)
50+
: mlir::OpRewritePattern<Op>(context),
51+
options{std::forward<Args>(args)...} {}
52+
53+
mlir::LogicalResult matchAndRewrite(Op,
54+
mlir::PatternRewriter &) const override;
55+
56+
protected:
57+
fir::SimplifyFIROperationsOptions options;
58+
};
59+
60+
/// fir::IsContiguousBoxOp converter.
61+
using IsContiguousBoxCoversion = ConversionBase<fir::IsContiguousBoxOp>;
62+
63+
/// fir::BoxTotalElementsOp converter.
64+
using BoxTotalElementsConversion = ConversionBase<fir::BoxTotalElementsOp>;
65+
} // namespace
66+
67+
/// Generate a call to IsContiguous/IsContiguousUpTo function or an inline
68+
/// sequence reading extents/strides from the box and checking them.
69+
/// This conversion may produce fir.box_elesize and a loop (for assumed
70+
/// rank).
71+
template <>
72+
mlir::LogicalResult IsContiguousBoxCoversion::matchAndRewrite(
73+
fir::IsContiguousBoxOp op, mlir::PatternRewriter &rewriter) const {
74+
mlir::Location loc = op.getLoc();
75+
fir::FirOpBuilder builder(rewriter, op.getOperation());
76+
// TODO: support preferInlineImplementation.
77+
bool doInline = options.preferInlineImplementation && false;
78+
if (!doInline) {
79+
// Generate Fortran runtime call.
80+
mlir::Value result;
81+
if (op.getInnermost()) {
82+
mlir::Value one =
83+
builder.createIntegerConstant(loc, builder.getI32Type(), 1);
84+
result =
85+
fir::runtime::genIsContiguousUpTo(builder, loc, op.getBox(), one);
86+
} else {
87+
result = fir::runtime::genIsContiguous(builder, loc, op.getBox());
88+
}
89+
result = builder.createConvert(loc, op.getType(), result);
90+
rewriter.replaceOp(op, result);
91+
return mlir::success();
92+
}
93+
94+
// Generate inline implementation.
95+
TODO(loc, "inline IsContiguousBoxOp");
96+
return mlir::failure();
97+
}
98+
99+
/// Generate a call to Size runtime function or an inline
100+
/// sequence reading extents from the box an multiplying them.
101+
/// This conversion may produce a loop (for assumed rank).
102+
template <>
103+
mlir::LogicalResult BoxTotalElementsConversion::matchAndRewrite(
104+
fir::BoxTotalElementsOp op, mlir::PatternRewriter &rewriter) const {
105+
mlir::Location loc = op.getLoc();
106+
fir::FirOpBuilder builder(rewriter, op.getOperation());
107+
// TODO: support preferInlineImplementation.
108+
// Reading the extent from the box for 1D arrays probably
109+
// results in less code than the call, so we can always
110+
// inline it.
111+
bool doInline = options.preferInlineImplementation && false;
112+
if (!doInline) {
113+
// Generate Fortran runtime call.
114+
mlir::Value result = fir::runtime::genSize(builder, loc, op.getBox());
115+
result = builder.createConvert(loc, op.getType(), result);
116+
rewriter.replaceOp(op, result);
117+
return mlir::success();
118+
}
119+
120+
// Generate inline implementation.
121+
TODO(loc, "inline BoxTotalElementsOp");
122+
return mlir::failure();
123+
}
124+
125+
void SimplifyFIROperationsPass::runOnOperation() {
126+
mlir::ModuleOp module = getOperation();
127+
mlir::MLIRContext &context = getContext();
128+
mlir::RewritePatternSet patterns(&context);
129+
fir::populateSimplifyFIROperationsPatterns(patterns,
130+
preferInlineImplementation);
131+
mlir::GreedyRewriteConfig config;
132+
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
133+
134+
if (mlir::failed(
135+
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
136+
mlir::emitError(module.getLoc(), DEBUG_TYPE " pass failed");
137+
signalPassFailure();
138+
}
139+
}
140+
141+
void fir::populateSimplifyFIROperationsPatterns(
142+
mlir::RewritePatternSet &patterns, bool preferInlineImplementation) {
143+
patterns.insert<IsContiguousBoxCoversion, BoxTotalElementsConversion>(
144+
patterns.getContext(), preferInlineImplementation);
145+
}

flang/test/Driver/bbc-mlir-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747

4848
! CHECK-NEXT: PolymorphicOpConversion
4949
! CHECK-NEXT: AssumedRankOpConversion
50+
! CHECK-NEXT: SimplifyFIROperations
5051

5152
! CHECK-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
5253
! CHECK-NEXT: 'fir.global' Pipeline

flang/test/Driver/mlir-debug-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
7777

7878
! ALL-NEXT: PolymorphicOpConversion
7979
! ALL-NEXT: AssumedRankOpConversion
80+
! ALL-NEXT: SimplifyFIROperations
8081

8182
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
8283
! ALL-NEXT: 'fir.global' Pipeline

0 commit comments

Comments
 (0)