Skip to content

Commit 4b964eb

Browse files
authored
Merge pull request #430 from schweitzpgi/ch-embox
Code cleanup for fir.embox op. Use some of MLIR's new facilities, etc…
2 parents c8ec106 + 50765a2 commit 4b964eb

File tree

8 files changed

+156
-257
lines changed

8 files changed

+156
-257
lines changed

flang/include/flang/Lower/Support/BoxValue.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,6 @@ class ExtendedValue : public details::matcher<ExtendedValue> {
250250

251251
friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
252252
const ExtendedValue &);
253-
friend mlir::Value getBase(const ExtendedValue &exv);
254-
friend mlir::Value getLen(const ExtendedValue &exv);
255-
friend ExtendedValue substBase(const ExtendedValue &exv, mlir::Value base);
256253

257254
const VT &matchee() const { return box; }
258255

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 46 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,13 @@ def AnyShapeLike : TypeConstraint<Or<[fir_ShapeType.predicate,
109109
fir_ShapeShiftType.predicate]>, "any legal shape type">;
110110
def AnyShapeType : Type<AnyShapeLike.predicate, "any legal shape type">;
111111
def fir_SliceType : Type<CPred<"$_self.isa<fir::SliceType>()">, "slice type">;
112-
def AnyEmboxLike : TypeConstraint<Or<[AnyIntegerType.predicate,
113-
AnyShapeType.predicate, fir_SliceType.predicate]>,
114-
"any legal embox argument type">;
115-
def AnyEmboxArg : Type<AnyEmboxLike.predicate, "embox argument type">;
112+
def AffineMapAttr : Attr<
113+
CPred<"$_self.isa<mlir::AffineMapAttr>()">, "AffineMap attribute"> {
114+
let storageType = "mlir::AffineMapAttr";
115+
let returnType = "mlir::AffineMap";
116+
let valueType = Index;
117+
let constBuilderCall = "mlir::AffineMapAttr::get($0)";
118+
}
116119

117120
// A type descriptor's type
118121
def fir_TypeDescType : Type<CPred<"$_self.isa<fir::TypeDescType>()">,
@@ -1034,9 +1037,11 @@ def fir_HasValueOp : fir_Op<"has_value", [Terminator, HasParent<"GlobalOp">]> {
10341037
let assemblyFormat = "$resval attr-dict `:` type($resval)";
10351038
}
10361039

1040+
//===------------------------------------------------------------------------===//
10371041
// Operations on !fir.box<T> type objects
1042+
//===------------------------------------------------------------------------===//
10381043

1039-
def fir_EmboxOp : fir_Op<"embox", [NoSideEffect]> {
1044+
def fir_EmboxOp : fir_Op<"embox", [NoSideEffect, AttrSizedOperandSegments]> {
10401045
let summary = "boxes a given reference and (optional) dimension information";
10411046

10421047
let description = [{
@@ -1056,51 +1061,43 @@ def fir_EmboxOp : fir_Op<"embox", [NoSideEffect]> {
10561061
information through the use of additional attributes.
10571062
}];
10581063

1059-
let arguments = (ins AnyReferenceLike:$memref, Variadic<AnyEmboxArg>:$args);
1064+
let arguments = (ins
1065+
AnyReferenceLike:$memref,
1066+
Optional<AnyShapeType>:$shape,
1067+
Optional<fir_SliceType>:$slice,
1068+
Variadic<AnyIntegerType>:$lenParams,
1069+
OptionalAttr<AffineMapAttr>:$accessMap
1070+
);
10601071

10611072
let results = (outs fir_BoxType);
10621073

1063-
let parser = "return parseEmboxOp(parser, result);";
1074+
let builders = [
1075+
OpBuilder<"mlir::OpBuilder &builder, mlir::OperationState &state,"
1076+
"llvm::ArrayRef<mlir::Type> resultTypes, mlir::Value memref,"
1077+
"mlir::Value shape = {}, mlir::Value slice = {},"
1078+
"mlir::ValueRange lenParams = {}", [{
1079+
return build(builder, state, resultTypes, memref, shape, slice,
1080+
lenParams, mlir::AffineMapAttr{}); }]>
1081+
];
10641082

1065-
let printer = [{
1066-
p << getOperationName() << ' ';
1067-
p.printOperand(memref());
1068-
if (auto shape = getShape()) {
1069-
p << '(';
1070-
p.printOperand(shape);
1071-
p << ')';
1072-
}
1073-
if (auto slice = getSlice()) {
1074-
p << '[';
1075-
p.printOperand(slice);
1076-
p << ']';
1077-
}
1078-
if (auto map = getAttr(layoutName()))
1079-
p << " map " << map;
1080-
if (hasLenParams()) {
1081-
p << " typeparams ";
1082-
p.printOperands(getLenParams());
1083-
}
1084-
p.printOptionalAttrDict(getAttrs(), {layoutName(), lenpName(), shapeName(),
1085-
sliceName()});
1086-
p << " : ";
1087-
p.printFunctionalType(getOperation());
1083+
let assemblyFormat = [{
1084+
$memref (`(` $shape^ `)`)? (`[` $slice^ `]`)? (`typeparams` $lenParams^)? (`map` $accessMap^)? attr-dict `:` functional-type(operands, results)
10881085
}];
10891086

10901087
let verifier = [{
10911088
auto eleTy = fir::dyn_cast_ptrEleTy(memref().getType());
10921089
if (!eleTy)
10931090
return emitOpError("must embox a memory reference type");
10941091
if (hasLenParams()) {
1095-
auto lenParams = numLenParams();
1092+
auto lenPs = numLenParams();
10961093
if (auto rt = eleTy.dyn_cast<fir::RecordType>()) {
1097-
if (lenParams != rt.getNumLenParams())
1094+
if (lenPs != rt.getNumLenParams())
10981095
return emitOpError("number of LEN params does not correspond"
10991096
" to the !fir.type type");
11001097
} else {
11011098
return emitOpError("LEN parameters require !fir.type type");
11021099
}
1103-
for (auto lp : getLenParams())
1100+
for (auto lp : lenParams())
11041101
if (!fir::isa_integer(lp.getType()))
11051102
return emitOpError("LEN parameters must be integral type");
11061103
}
@@ -1118,42 +1115,10 @@ def fir_EmboxOp : fir_Op<"embox", [NoSideEffect]> {
11181115
}];
11191116

11201117
let extraClassDeclaration = [{
1121-
static constexpr llvm::StringRef layoutName() { return "layout_map"; }
1122-
static constexpr llvm::StringRef lenpName() { return "len_param_count"; }
1123-
static constexpr llvm::StringRef shapeName() { return "shape"; }
1124-
static constexpr llvm::StringRef sliceName() { return "slice"; }
1125-
1126-
mlir::Value getShape() {
1127-
if (auto x = getAttrOfType<mlir::UnitAttr>(shapeName()))
1128-
return *std::next(operand_begin());
1129-
return {};
1130-
}
1131-
1132-
mlir::Value getSlice() {
1133-
if (auto x = getAttrOfType<mlir::UnitAttr>(sliceName())) {
1134-
auto iter = std::next(operand_begin());
1135-
if (getShape())
1136-
iter = std::next(iter);
1137-
return *iter;
1138-
}
1139-
return {};
1140-
}
1141-
1142-
bool hasLenParams() { return bool{getAttr(lenpName())}; }
1143-
unsigned numLenParams() {
1144-
if (auto x = getAttrOfType<mlir::IntegerAttr>(lenpName()))
1145-
return x.getInt();
1146-
return 0;
1147-
}
1148-
1149-
operand_range getLenParams() {
1150-
auto iter = std::next(operand_begin());
1151-
if (getShape())
1152-
iter = std::next(iter);
1153-
if (getSlice())
1154-
iter = std::next(iter);
1155-
return {iter, operand_end()};
1156-
}
1118+
mlir::Value getShape() { return shape(); }
1119+
mlir::Value getSlice() { return slice(); }
1120+
bool hasLenParams() { return !lenParams().empty(); }
1121+
unsigned numLenParams() { return lenParams().size(); }
11571122
}];
11581123
}
11591124

@@ -1576,7 +1541,9 @@ def fir_BoxTypeDescOp : fir_SimpleOneResultOp<"box_tdesc", [NoSideEffect]> {
15761541
}
15771542

15781543
// Record and array type operations
1579-
def fir_ArrayCoorOp : fir_Op<"array_coor", [NoSideEffect, AttrSizedOperandSegments]> {
1544+
def fir_ArrayCoorOp : fir_Op<"array_coor",
1545+
[NoSideEffect, AttrSizedOperandSegments]> {
1546+
15801547
let summary = "Find the coordinate of an element of an array";
15811548

15821549
let description = [{
@@ -1601,14 +1568,18 @@ def fir_ArrayCoorOp : fir_Op<"array_coor", [NoSideEffect, AttrSizedOperandSegmen
16011568
```
16021569
}];
16031570

1604-
let arguments = (ins AnyReferenceLike:$memref,
1605-
Optional<AnyShapeType>:$shape, Optional<fir_SliceType>:$slice,
1606-
Variadic<AnyCoordinateType>:$indices,
1607-
Variadic<AnyIntegerType>:$lenParams);
1571+
let arguments = (ins
1572+
AnyReferenceLike:$memref,
1573+
Optional<AnyShapeType>:$shape,
1574+
Optional<fir_SliceType>:$slice,
1575+
Variadic<AnyCoordinateType>:$indices,
1576+
Variadic<AnyIntegerType>:$lenParams
1577+
);
16081578

16091579
let results = (outs fir_ReferenceType);
1580+
16101581
let assemblyFormat = [{
1611-
$memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams` $lenParams^)? `:` functional-type(operands, results) attr-dict
1582+
$memref (`(`$shape^`)`)? (`[`$slice^`]`)? $indices (`typeparams` $lenParams^)? attr-dict `:` functional-type(operands, results)
16121583
}];
16131584

16141585
let verifier = [{

flang/include/flang/Optimizer/Support/Matcher.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
#include <variant>
1717

18-
// Boilerplate CRTP class for a simplified type-casing syntactic sugar.
18+
// Boilerplate CRTP class for a simplified type-casing syntactic sugar. This
19+
// lets one write pattern matchers using a more compact syntax.
1920
namespace fir::details {
2021
// clang-format off
2122
template<class... Ts> struct matches : Ts... { using Ts::operator()...; };

flang/lib/Lower/ConvertExpr.cpp

Lines changed: 63 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -273,16 +273,14 @@ class ExprLowering {
273273

274274
fir::ExtendedValue getExValue(const Fortran::lower::SymbolBox &symBox) {
275275
using T = fir::ExtendedValue;
276-
return std::visit(
277-
Fortran::common::visitors{
278-
[](const Fortran::lower::SymbolBox::Intrinsic &box) -> T {
279-
return box.getAddr();
280-
},
281-
[](const auto &box) -> T { return box; },
282-
[](const Fortran::lower::SymbolBox::None &) -> T {
283-
llvm_unreachable("symbol not mapped");
284-
}},
285-
symBox.box);
276+
return symBox.match(
277+
[](const Fortran::lower::SymbolBox::Intrinsic &box) -> T {
278+
return box.getAddr();
279+
},
280+
[](const Fortran::lower::SymbolBox::None &) -> T {
281+
llvm_unreachable("symbol not mapped");
282+
},
283+
[](const auto &box) -> T { return box; });
286284
}
287285

288286
/// Returns a reference to a symbol or its box/boxChar descriptor if it has
@@ -1098,25 +1096,23 @@ class ExprLowering {
10981096
// We need some context here, since we could also box as an argument
10991097
llvm::report_fatal_error("TODO: array slice not supported");
11001098
};
1101-
return std::visit(
1102-
Fortran::common::visitors{
1103-
[&](const Fortran::lower::SymbolBox::FullDim &arr) {
1104-
if (!inArrayContext() && isSlice(aref))
1105-
return genArraySlice(arr);
1106-
return genFullDim(arr, one);
1107-
},
1108-
[&](const Fortran::lower::SymbolBox::CharFullDim &arr) {
1109-
return genFullDim(arr, arr.getLen());
1110-
},
1111-
[&](const Fortran::lower::SymbolBox::Derived &arr) {
1112-
TODO();
1113-
return mlir::Value{};
1114-
},
1115-
[&](const auto &) {
1116-
TODO();
1117-
return mlir::Value{};
1118-
}},
1119-
si.box);
1099+
return si.match(
1100+
[&](const Fortran::lower::SymbolBox::FullDim &arr) {
1101+
if (!inArrayContext() && isSlice(aref))
1102+
return genArraySlice(arr);
1103+
return genFullDim(arr, one);
1104+
},
1105+
[&](const Fortran::lower::SymbolBox::CharFullDim &arr) {
1106+
return genFullDim(arr, arr.getLen());
1107+
},
1108+
[&](const Fortran::lower::SymbolBox::Derived &arr) {
1109+
TODO();
1110+
return mlir::Value{};
1111+
},
1112+
[&](const auto &) {
1113+
TODO();
1114+
return mlir::Value{};
1115+
});
11201116
}
11211117

11221118
fir::ExtendedValue genArrayCoorOp(const Fortran::lower::SymbolBox &si,
@@ -1163,28 +1159,26 @@ class ExprLowering {
11631159
return builder.create<fir::ArrayCoorOp>(
11641160
loc, refTy, addr, shape, mlir::Value{}, arrayCoorArgs, ValueRange());
11651161
};
1166-
return std::visit(
1167-
Fortran::common::visitors{
1168-
[&](const Fortran::lower::SymbolBox::FullDim &arr) {
1169-
if (!inArrayContext() && isSlice(aref)) {
1170-
TODO();
1171-
return mlir::Value{};
1172-
}
1173-
return genWithShape(arr);
1174-
},
1175-
[&](const Fortran::lower::SymbolBox::CharFullDim &arr) {
1176-
TODO();
1177-
return mlir::Value{};
1178-
},
1179-
[&](const Fortran::lower::SymbolBox::Derived &arr) {
1180-
TODO();
1181-
return mlir::Value{};
1182-
},
1183-
[&](const auto &) {
1184-
TODO();
1185-
return mlir::Value{};
1186-
}},
1187-
si.box);
1162+
return si.match(
1163+
[&](const Fortran::lower::SymbolBox::FullDim &arr) {
1164+
if (!inArrayContext() && isSlice(aref)) {
1165+
TODO();
1166+
return mlir::Value{};
1167+
}
1168+
return genWithShape(arr);
1169+
},
1170+
[&](const Fortran::lower::SymbolBox::CharFullDim &arr) {
1171+
TODO();
1172+
return mlir::Value{};
1173+
},
1174+
[&](const Fortran::lower::SymbolBox::Derived &arr) {
1175+
TODO();
1176+
return mlir::Value{};
1177+
},
1178+
[&](const auto &) {
1179+
TODO();
1180+
return mlir::Value{};
1181+
});
11881182
}
11891183

11901184
// Return the coordinate of the array reference
@@ -1643,31 +1637,22 @@ fir::ExtendedValue Fortran::lower::createStringLiteral(
16431637
//===----------------------------------------------------------------------===//
16441638

16451639
mlir::Value fir::getBase(const fir::ExtendedValue &exv) {
1646-
return std::visit(Fortran::common::visitors{
1647-
[](const fir::UnboxedValue &x) { return x; },
1648-
[](const auto &x) { return x.getAddr(); },
1649-
},
1650-
exv.box);
1640+
return exv.match([](const fir::UnboxedValue &x) { return x; },
1641+
[](const auto &x) { return x.getAddr(); });
16511642
}
16521643

16531644
mlir::Value fir::getLen(const fir::ExtendedValue &exv) {
1654-
return std::visit(
1655-
Fortran::common::visitors{
1656-
[](const fir::CharBoxValue &x) { return x.getLen(); },
1657-
[](const fir::CharArrayBoxValue &x) { return x.getLen(); },
1658-
[](const fir::BoxValue &x) { return x.getLen(); },
1659-
[](const auto &) { return mlir::Value{}; }},
1660-
exv.box);
1645+
return exv.match([](const fir::CharBoxValue &x) { return x.getLen(); },
1646+
[](const fir::CharArrayBoxValue &x) { return x.getLen(); },
1647+
[](const fir::BoxValue &x) { return x.getLen(); },
1648+
[](const auto &) { return mlir::Value{}; });
16611649
}
16621650

1663-
fir::ExtendedValue fir::substBase(const fir::ExtendedValue &ex,
1651+
fir::ExtendedValue fir::substBase(const fir::ExtendedValue &exv,
16641652
mlir::Value base) {
1665-
return std::visit(
1666-
Fortran::common::visitors{
1667-
[&](const fir::UnboxedValue &x) { return fir::ExtendedValue(base); },
1668-
[&](const auto &x) { return fir::ExtendedValue(x.clone(base)); },
1669-
},
1670-
ex.box);
1653+
return exv.match(
1654+
[=](const fir::UnboxedValue &x) { return fir::ExtendedValue(base); },
1655+
[=](const auto &x) { return fir::ExtendedValue(x.clone(base)); });
16711656
}
16721657

16731658
llvm::raw_ostream &fir::operator<<(llvm::raw_ostream &os,
@@ -1750,14 +1735,13 @@ void Fortran::lower::SymMap::dump() const {
17501735
auto &os = llvm::errs();
17511736
for (auto iter : symbolMap) {
17521737
os << "symbol [" << *iter.first << "] ->\n\t";
1753-
std::visit(Fortran::common::visitors{
1754-
[&](const Fortran::lower::SymbolBox::None &box) {
1755-
os << "** symbol not properly mapped **\n";
1756-
},
1757-
[&](const Fortran::lower::SymbolBox::Intrinsic &val) {
1758-
os << val.getAddr() << '\n';
1759-
},
1760-
[&](const auto &box) { os << box << '\n'; }},
1761-
iter.second.box);
1738+
iter.second.match(
1739+
[&](const Fortran::lower::SymbolBox::None &box) {
1740+
os << "** symbol not properly mapped **\n";
1741+
},
1742+
[&](const Fortran::lower::SymbolBox::Intrinsic &val) {
1743+
os << val.getAddr() << '\n';
1744+
},
1745+
[&](const auto &box) { os << box << '\n'; });
17621746
}
17631747
}

0 commit comments

Comments
 (0)