Skip to content

Commit 80e2a89

Browse files
committed
[flang] add FIR to FIR pass to lower assumed-rank operations
Add pass to lower assumed-rank operations. The current patch adds codegen for fir.rebox_assumed_rank. It will be the pass lowering fir.select_rank. fir.rebox_assumed_rank is lowered to a call to CopyAndUpdateDescriptor runtime API. Note that the lowering ends-up allocating two new descriptor at the LLVM level (one alloca created by the pass for the CopyAndUpdateDescriptor result descriptor argument, the second one is created by the fir.load of the result descriptor in codegen). LLVM is currently unable to properly optimize and merge those allocas. The "nocapture" attribute added to CopyAndUpdateDescriptor arguments gives part of the information to LLVM, but the fir.load codegen of descriptors must be updated to use llvm.memcpy instead of llvm.load+store to allow LLVM to optimize it. This will be done in later patch.
1 parent b0b3596 commit 80e2a89

File tree

17 files changed

+361
-2
lines changed

17 files changed

+361
-2
lines changed

flang/include/flang/Optimizer/Builder/FIRBuilder.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
5050
mlir::SymbolTable *symbolTable = nullptr)
5151
: OpBuilder{op, /*listener=*/this}, kindMap{std::move(kindMap)},
5252
symbolTable{symbolTable} {}
53-
explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap)
54-
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)} {
53+
explicit FirOpBuilder(mlir::OpBuilder &builder, fir::KindMapping kindMap,
54+
mlir::SymbolTable *symbolTable = nullptr)
55+
: OpBuilder(builder), OpBuilder::Listener(), kindMap{std::move(kindMap)},
56+
symbolTable{symbolTable} {
5557
setListener(this);
5658
}
5759
explicit FirOpBuilder(mlir::OpBuilder &builder, mlir::ModuleOp mod)

flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@ constexpr TypeBuilderFunc getModel<signed char>() {
130130
};
131131
}
132132
template <>
133+
constexpr TypeBuilderFunc getModel<unsigned char>() {
134+
return [](mlir::MLIRContext *context) -> mlir::Type {
135+
return mlir::IntegerType::get(context, 8 * sizeof(unsigned char));
136+
};
137+
}
138+
template <>
133139
constexpr TypeBuilderFunc getModel<void *>() {
134140
return [](mlir::MLIRContext *context) -> mlir::Type {
135141
return fir::LLVMPointerType::get(context,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===-- Support.h - generate support runtime API calls ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef FORTRAN_OPTIMIZER_BUILDER_RUNTIME_SUPPORT_H
10+
#define FORTRAN_OPTIMIZER_BUILDER_RUNTIME_SUPPORT_H
11+
12+
namespace mlir {
13+
class Value;
14+
class Location;
15+
} // namespace mlir
16+
17+
namespace fir {
18+
class FirOpBuilder;
19+
}
20+
21+
namespace fir::runtime {
22+
23+
/// Generate call to `CopyAndUpdateDescriptor` runtime routine.
24+
void genCopyAndUpdateDescriptor(fir::FirOpBuilder &builder, mlir::Location loc,
25+
mlir::Value to, mlir::Value from,
26+
mlir::Value newDynamicType,
27+
mlir::Value newAttribute,
28+
mlir::Value newLowerBounds);
29+
30+
} // namespace fir::runtime
31+
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_SUPPORT_H

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class BaseBoxType : public mlir::Type {
5353
/// Return the same type, except for the shape, that is taken the shape
5454
/// of shapeMold.
5555
BaseBoxType getBoxTypeWithNewShape(mlir::Type shapeMold) const;
56+
BaseBoxType getBoxTypeWithNewShape(int rank) const;
5657

5758
/// Methods for support type inquiry through isa, cast, and dyn_cast.
5859
static bool classof(mlir::Type type);

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ namespace fir {
3636
#define GEN_PASS_DECL_AFFINEDIALECTDEMOTION
3737
#define GEN_PASS_DECL_ANNOTATECONSTANTOPERANDS
3838
#define GEN_PASS_DECL_ARRAYVALUECOPY
39+
#define GEN_PASS_DECL_ASSUMEDRANKOPCONVERSION
3940
#define GEN_PASS_DECL_CHARACTERCONVERSION
4041
#define GEN_PASS_DECL_CFGCONVERSION
4142
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,4 +402,16 @@ def FunctionAttr : Pass<"function-attr", "mlir::func::FuncOp"> {
402402
let constructor = "::fir::createFunctionAttrPass()";
403403
}
404404

405+
def AssumedRankOpConversion : Pass<"fir-assumed-rank-op", "mlir::ModuleOp"> {
406+
let summary =
407+
"Simplify operations on assumed-rank types";
408+
let description = [{
409+
This pass breaks up the lowering of operations on assumed-rank types by
410+
introducing an intermediate FIR level that simplifies code generation.
411+
}];
412+
let dependentDialects = [
413+
"fir::FIROpsDialect", "mlir::func::FuncDialect"
414+
];
415+
}
416+
405417
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

flang/include/flang/Tools/CLOptions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ inline void createDefaultFIROptimizerPassPipeline(
292292

293293
// Polymorphic types
294294
pm.addPass(fir::createPolymorphicOpConversion());
295+
pm.addPass(fir::createAssumedRankOpConversion());
295296

296297
if (pc.AliasAnalysis && !disableFirAliasTags && !useOldAliasTags)
297298
pm.addPass(fir::createAddAliasTags());

flang/lib/Optimizer/Builder/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ add_flang_library(FIRBuilder
2929
Runtime/Ragged.cpp
3030
Runtime/Reduction.cpp
3131
Runtime/Stop.cpp
32+
Runtime/Support.cpp
3233
Runtime/TemporaryStack.cpp
3334
Runtime/Transformational.cpp
3435
TemporaryStorage.cpp
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//===-- Support.cpp - generate support runtime API calls --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Optimizer/Builder/Runtime/Support.h"
10+
#include "flang/Optimizer/Builder/FIRBuilder.h"
11+
#include "flang/Optimizer/Builder/Runtime/RTBuilder.h"
12+
#include "flang/Runtime/support.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14+
15+
using namespace Fortran::runtime;
16+
17+
template <>
18+
constexpr fir::runtime::TypeBuilderFunc
19+
fir::runtime::getModel<Fortran::runtime::LowerBoundModifier>() {
20+
return [](mlir::MLIRContext *context) -> mlir::Type {
21+
return mlir::IntegerType::get(
22+
context, sizeof(Fortran::runtime::LowerBoundModifier) * 8);
23+
};
24+
}
25+
26+
void fir::runtime::genCopyAndUpdateDescriptor(fir::FirOpBuilder &builder,
27+
mlir::Location loc,
28+
mlir::Value to, mlir::Value from,
29+
mlir::Value newDynamicType,
30+
mlir::Value newAttribute,
31+
mlir::Value newLowerBounds) {
32+
mlir::func::FuncOp func =
33+
fir::runtime::getRuntimeFunc<mkRTKey(CopyAndUpdateDescriptor)>(loc,
34+
builder);
35+
auto fTy = func.getFunctionType();
36+
auto args =
37+
fir::runtime::createArguments(builder, loc, fTy, to, from, newDynamicType,
38+
newAttribute, newLowerBounds);
39+
llvm::StringRef noCapture = mlir::LLVM::LLVMDialect::getNoCaptureAttrName();
40+
if (!func.getArgAttr(0, noCapture)) {
41+
mlir::UnitAttr unitAttr = mlir::UnitAttr::get(func.getContext());
42+
func.setArgAttr(0, noCapture, unitAttr);
43+
func.setArgAttr(1, noCapture, unitAttr);
44+
}
45+
builder.create<fir::CallOp>(loc, func, args);
46+
}

flang/lib/Optimizer/Dialect/FIRType.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,17 @@ fir::BaseBoxType::getBoxTypeWithNewShape(mlir::Type shapeMold) const {
13241324
return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
13251325
}
13261326

1327+
fir::BaseBoxType fir::BaseBoxType::getBoxTypeWithNewShape(int rank) const {
1328+
std::optional<fir::SequenceType::ShapeRef> newShape;
1329+
fir::SequenceType::Shape shapeVector;
1330+
if (rank > 0) {
1331+
shapeVector =
1332+
fir::SequenceType::Shape(rank, fir::SequenceType::getUnknownExtent());
1333+
newShape = shapeVector;
1334+
}
1335+
return mlir::cast<fir::BaseBoxType>(changeTypeShape(*this, newShape));
1336+
}
1337+
13271338
bool fir::BaseBoxType::isAssumedRank() const {
13281339
if (auto seqTy =
13291340
mlir::dyn_cast<fir::SequenceType>(fir::unwrapRefType(getEleTy())))
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
//===-- AssumedRankOpConversion.cpp ---------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "flang/Common/Fortran.h"
10+
#include "flang/Lower/BuiltinModules.h"
11+
#include "flang/Optimizer/Builder/FIRBuilder.h"
12+
#include "flang/Optimizer/Builder/Runtime/Support.h"
13+
#include "flang/Optimizer/Builder/Todo.h"
14+
#include "flang/Optimizer/Dialect/FIRDialect.h"
15+
#include "flang/Optimizer/Dialect/FIROps.h"
16+
#include "flang/Optimizer/Support/TypeCode.h"
17+
#include "flang/Optimizer/Support/Utils.h"
18+
#include "flang/Optimizer/Transforms/Passes.h"
19+
#include "flang/Runtime/support.h"
20+
#include "mlir/Dialect/Func/IR/FuncOps.h"
21+
#include "mlir/Pass/Pass.h"
22+
#include "mlir/Transforms/DialectConversion.h"
23+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24+
25+
namespace fir {
26+
#define GEN_PASS_DEF_ASSUMEDRANKOPCONVERSION
27+
#include "flang/Optimizer/Transforms/Passes.h.inc"
28+
} // namespace fir
29+
30+
using namespace fir;
31+
using namespace mlir;
32+
33+
namespace {
34+
35+
static int getCFIAttribute(mlir::Type boxType) {
36+
if (fir::isAllocatableType(boxType))
37+
return CFI_attribute_allocatable;
38+
if (fir::isPointerType(boxType))
39+
return CFI_attribute_pointer;
40+
return CFI_attribute_other;
41+
}
42+
43+
static Fortran::runtime::LowerBoundModifier
44+
getLowerBoundModifier(fir::LowerBoundModifierAttribute modifier) {
45+
switch (modifier) {
46+
case fir::LowerBoundModifierAttribute::Preserve:
47+
return Fortran::runtime::LowerBoundModifier::Preserve;
48+
case fir::LowerBoundModifierAttribute::SetToOnes:
49+
return Fortran::runtime::LowerBoundModifier::SetToOnes;
50+
case fir::LowerBoundModifierAttribute::SetToZeroes:
51+
return Fortran::runtime::LowerBoundModifier::SetToZeroes;
52+
}
53+
llvm_unreachable("bad modifier code");
54+
}
55+
56+
class ReboxAssumedRankConv
57+
: public mlir::OpRewritePattern<fir::ReboxAssumedRankOp> {
58+
public:
59+
using OpRewritePattern::OpRewritePattern;
60+
61+
ReboxAssumedRankConv(mlir::MLIRContext *context,
62+
mlir::SymbolTable *symbolTable, fir::KindMapping kindMap)
63+
: mlir::OpRewritePattern<fir::ReboxAssumedRankOp>(context),
64+
symbolTable{symbolTable}, kindMap{kindMap} {};
65+
66+
mlir::LogicalResult
67+
matchAndRewrite(fir::ReboxAssumedRankOp rebox,
68+
mlir::PatternRewriter &rewriter) const override {
69+
fir::FirOpBuilder builder{rewriter, kindMap, symbolTable};
70+
mlir::Location loc = rebox.getLoc();
71+
auto newBoxType = mlir::cast<fir::BaseBoxType>(rebox.getType());
72+
mlir::Type newMaxRankBoxType =
73+
newBoxType.getBoxTypeWithNewShape(Fortran::common::maxRank);
74+
// CopyAndUpdateDescriptor FIR interface requires loading
75+
// !fir.ref<fir.box> input which is expensive with assumed-rank. It could
76+
// be best to add an entry point that takes a non "const" from to cover
77+
// this case, but it would be good to indicate to LLVM that from does not
78+
// get modified.
79+
if (fir::isBoxAddress(rebox.getBox().getType()))
80+
TODO(loc, "fir.rebox_assumed_rank codegen with fir.ref<fir.box<>> input");
81+
mlir::Value tempDesc = builder.createTemporary(loc, newMaxRankBoxType);
82+
mlir::Value newDtype;
83+
mlir::Type newEleType = newBoxType.unwrapInnerType();
84+
auto oldBoxType = mlir::cast<fir::BaseBoxType>(
85+
fir::unwrapRefType(rebox.getBox().getType()));
86+
auto newDerivedType = mlir::dyn_cast<fir::RecordType>(newEleType);
87+
if (newDerivedType && (newEleType != oldBoxType.unwrapInnerType()) &&
88+
!fir::isPolymorphicType(newBoxType)) {
89+
newDtype = builder.create<fir::TypeDescOp>(
90+
loc, mlir::TypeAttr::get(newDerivedType));
91+
} else {
92+
newDtype = builder.createNullConstant(loc);
93+
}
94+
mlir::Value newAttribute = builder.createIntegerConstant(
95+
loc, builder.getIntegerType(8), getCFIAttribute(newBoxType));
96+
int lbsModifierCode =
97+
static_cast<int>(getLowerBoundModifier(rebox.getLbsModifier()));
98+
mlir::Value lowerBoundModifier = builder.createIntegerConstant(
99+
loc, builder.getIntegerType(32), lbsModifierCode);
100+
fir::runtime::genCopyAndUpdateDescriptor(builder, loc, tempDesc,
101+
rebox.getBox(), newDtype,
102+
newAttribute, lowerBoundModifier);
103+
104+
mlir::Value descValue = builder.create<fir::LoadOp>(loc, tempDesc);
105+
mlir::Value castDesc = builder.createConvert(loc, newBoxType, descValue);
106+
rewriter.replaceOp(rebox, castDesc);
107+
return mlir::success();
108+
}
109+
110+
private:
111+
mlir::SymbolTable *symbolTable = nullptr;
112+
fir::KindMapping kindMap;
113+
};
114+
115+
/// Convert FIR structured control flow ops to CFG ops.
116+
class AssumedRankOpConversion
117+
: public fir::impl::AssumedRankOpConversionBase<AssumedRankOpConversion> {
118+
public:
119+
void runOnOperation() override {
120+
auto *context = &getContext();
121+
mlir::ModuleOp mod = getOperation();
122+
mlir::SymbolTable symbolTable(mod);
123+
fir::KindMapping kindMap = fir::getKindMapping(mod);
124+
mlir::RewritePatternSet patterns(context);
125+
patterns.insert<ReboxAssumedRankConv>(context, &symbolTable, kindMap);
126+
mlir::GreedyRewriteConfig config;
127+
config.enableRegionSimplification = false;
128+
(void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config);
129+
}
130+
};
131+
} // namespace

flang/lib/Optimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_flang_library(FIRTransforms
44
AffinePromotion.cpp
55
AffineDemotion.cpp
66
AnnotateConstant.cpp
7+
AssumedRankOpConversion.cpp
78
CharacterConversion.cpp
89
ControlFlowConverter.cpp
910
ArrayValueCopy.cpp

flang/test/Driver/bbc-mlir-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
! CHECK-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
4747

4848
! CHECK-NEXT: PolymorphicOpConversion
49+
! CHECK-NEXT: AssumedRankOpConversion
4950

5051
! CHECK-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
5152
! CHECK-NEXT: 'fir.global' Pipeline

flang/test/Driver/mlir-debug-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
7474

7575
! ALL-NEXT: PolymorphicOpConversion
76+
! ALL-NEXT: AssumedRankOpConversion
7677

7778
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']
7879
! ALL-NEXT: 'fir.global' Pipeline

flang/test/Driver/mlir-pass-pipeline.f90

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
! ALL-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
8181

8282
! ALL-NEXT: PolymorphicOpConversion
83+
! ALL-NEXT: AssumedRankOpConversion
8384
! O2-NEXT: AddAliasTags
8485

8586
! ALL-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']

flang/test/Fir/basic-program.fir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ func.func @_QQmain() {
8080
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
8181

8282
// PASSES-NEXT: PolymorphicOpConversion
83+
// PASSES-NEXT: AssumedRankOpConversion
8384
// PASSES-NEXT: AddAliasTags
8485

8586
// PASSES-NEXT: Pipeline Collection : ['fir.global', 'func.func', 'omp.declare_reduction', 'omp.private']

0 commit comments

Comments
 (0)