Skip to content

Commit 1e42f7a

Browse files
committed
OpaqueType: Use format string
1 parent 2015abf commit 1e42f7a

File tree

6 files changed

+141
-12
lines changed

6 files changed

+141
-12
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,20 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
9999
```
100100
}];
101101

102-
let parameters = (ins StringRefParameter<"the opaque value">:$value);
103-
let assemblyFormat = "`<` $value `>`";
102+
let parameters = (ins StringRefParameter<"the opaque value">:$value,
103+
OptionalArrayRefParameter<"Type">:$fmtArgs);
104+
let assemblyFormat = "`<` $value (`,` custom<VariadicFmtArgs>($fmtArgs)^)? `>`";
104105
let genVerifyDecl = 1;
106+
107+
let builders = [TypeBuilder<(ins "::llvm::StringRef":$value), [{ return $_get($_ctxt, value, SmallVector<Type>{}); }] >];
108+
109+
let extraClassDeclaration = [{
110+
// Either a literal string, or an placeholder for the fmtArgs.
111+
struct Placeholder {};
112+
using ReplacementItem = std::variant<StringRef, Placeholder>;
113+
114+
FailureOr<SmallVector<ReplacementItem>> parseFormatString();
115+
}];
105116
}
106117

107118
def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 66 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/ADT/StringExtras.h"
2121
#include "llvm/ADT/TypeSwitch.h"
2222
#include "llvm/Support/Casting.h"
23+
#include "llvm/Support/FormatVariadic.h"
2324

2425
using namespace mlir;
2526
using namespace mlir::emitc;
@@ -929,22 +930,44 @@ LogicalResult emitc::VerbatimOp::verify() {
929930
return success();
930931
}
931932

933+
static ParseResult parseVariadicFmtArgs(AsmParser &p,
934+
SmallVector<Type> &params) {
935+
Type type;
936+
if (p.parseType(type))
937+
return failure();
938+
939+
params.push_back(type);
940+
while (succeeded(p.parseOptionalComma())) {
941+
if (p.parseType(type))
942+
return failure();
943+
params.push_back(type);
944+
}
945+
946+
return success();
947+
}
948+
949+
static void printVariadicFmtArgs(AsmPrinter &p, ArrayRef<Type> params) {
950+
llvm::interleaveComma(params, p, [&](Type type) { p.printType(type); });
951+
}
952+
932953
/// Parse a format string and return a list of its parts.
933954
/// A part is either a StringRef that has to be printed as-is, or
934955
/// a Placeholder which requires printing the next operand of the VerbatimOp.
935956
/// In the format string, all `{}` are replaced by Placeholders, except if the
936-
/// `{` is escaped by `{{` - then it doesn't start a placeholder.
937-
FailureOr<SmallVector<emitc::VerbatimOp::ReplacementItem>>
938-
emitc::VerbatimOp::parseFormatString() {
939-
SmallVector<ReplacementItem> items;
957+
/// `{` is escaped by `{{` - then it doesn't start a placeholder
958+
template <typename Op, class ArgType>
959+
FailureOr<SmallVector<typename Op::ReplacementItem>>
960+
parseFormatString(StringRef toParse, ArgType fmtArgs,
961+
std::optional<llvm::function_ref<mlir::InFlightDiagnostic()>>
962+
emitError = {}) {
963+
SmallVector<typename Op::ReplacementItem> items;
940964

941965
// If there are not operands, the format string is not interpreted.
942-
if (getFmtArgs().empty()) {
943-
items.push_back(getValue());
966+
if (fmtArgs.empty()) {
967+
items.push_back(toParse);
944968
return items;
945969
}
946970

947-
StringRef toParse = getValue();
948971
while (!toParse.empty()) {
949972
size_t idx = toParse.find('{');
950973
if (idx == StringRef::npos) {
@@ -972,15 +995,28 @@ emitc::VerbatimOp::parseFormatString() {
972995
continue;
973996
}
974997
if (nextChar == '}') {
975-
items.push_back(Placeholder{});
998+
items.push_back(typename Op::Placeholder{});
976999
toParse = toParse.drop_front(2);
9771000
continue;
9781001
}
979-
return emitOpError() << "expected '}' after unescaped '{'";
1002+
1003+
if (emitError.has_value()) {
1004+
return (*emitError)() << "expected '}' after unescaped '{'";
1005+
}
1006+
return failure();
9801007
}
9811008
return items;
9821009
}
9831010

1011+
FailureOr<SmallVector<emitc::VerbatimOp::ReplacementItem>>
1012+
emitc::VerbatimOp::parseFormatString() {
1013+
auto errorCallback = [&]() -> InFlightDiagnostic {
1014+
return this->emitError();
1015+
};
1016+
return ::parseFormatString<emitc::VerbatimOp>(getValue(), getFmtArgs(),
1017+
errorCallback);
1018+
}
1019+
9841020
//===----------------------------------------------------------------------===//
9851021
// EmitC Enums
9861022
//===----------------------------------------------------------------------===//
@@ -1072,17 +1108,37 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
10721108

10731109
LogicalResult mlir::emitc::OpaqueType::verify(
10741110
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1075-
llvm::StringRef value) {
1111+
llvm::StringRef value, ArrayRef<Type> fmtArgs) {
10761112
if (value.empty()) {
10771113
return emitError() << "expected non empty string in !emitc.opaque type";
10781114
}
10791115
if (value.back() == '*') {
10801116
return emitError() << "pointer not allowed as outer type with "
10811117
"!emitc.opaque, use !emitc.ptr instead";
10821118
}
1119+
1120+
FailureOr<SmallVector<ReplacementItem>> fmt =
1121+
::parseFormatString<emitc::OpaqueType>(value, fmtArgs, emitError);
1122+
if (failed(fmt))
1123+
return failure();
1124+
1125+
size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) {
1126+
return std::holds_alternative<Placeholder>(item);
1127+
});
1128+
1129+
if (numPlaceholders != fmtArgs.size()) {
1130+
return emitError()
1131+
<< "requires operands for each placeholder in the format string";
1132+
}
1133+
10831134
return success();
10841135
}
10851136

1137+
FailureOr<SmallVector<emitc::OpaqueType::ReplacementItem>>
1138+
emitc::OpaqueType::parseFormatString() {
1139+
return ::parseFormatString<emitc::OpaqueType>(getValue(), getFmtArgs());
1140+
}
1141+
10861142
//===----------------------------------------------------------------------===//
10871143
// GlobalOp
10881144
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1728,6 +1728,24 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
17281728
if (auto tType = dyn_cast<TupleType>(type))
17291729
return emitTupleType(loc, tType.getTypes());
17301730
if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1731+
FailureOr<SmallVector<emitc::OpaqueType::ReplacementItem>> items =
1732+
oType.parseFormatString();
1733+
if (failed(items))
1734+
return failure();
1735+
1736+
auto fmtArg = oType.getFmtArgs().begin();
1737+
for (emitc::OpaqueType::ReplacementItem &item : *items) {
1738+
if (auto *str = std::get_if<StringRef>(&item)) {
1739+
os << *str;
1740+
} else {
1741+
if (failed(emitType(loc, *fmtArg++))) {
1742+
return failure();
1743+
}
1744+
}
1745+
}
1746+
1747+
return success();
1748+
17311749
os << oType.getValue();
17321750
return success();
17331751
}

mlir/test/Dialect/EmitC/invalid_types.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,34 @@ func.func @illegal_opaque_type_2() {
1414

1515
// -----
1616

17+
func.func @illegal_opaque_type() {
18+
// expected-error @+1 {{expected non-function type}}
19+
%1 = "emitc.variable"(){value = "42" : !emitc.opaque<"{}, {}", "string">} : () -> !emitc.opaque<"mytype">
20+
}
21+
22+
// -----
23+
24+
func.func @illegal_opaque_type() {
25+
// expected-error @+1 {{requires operands for each placeholder in the format string}}
26+
%1 = "emitc.variable"(){value = "42" : !emitc.opaque<"a", f32>} : () -> !emitc.opaque<"mytype">
27+
}
28+
29+
// -----
30+
31+
func.func @illegal_opaque_type() {
32+
// expected-error @+1 {{requires operands for each placeholder in the format string}}
33+
%1 = "emitc.variable"(){value = "42" : !emitc.opaque<"{}, {}", f32>} : () -> !emitc.opaque<"mytype">
34+
}
35+
36+
// -----
37+
38+
func.func @illegal_opaque_type() {
39+
// expected-error @+1 {{expected '}' after unescaped '{'}}
40+
%1 = "emitc.variable"(){value = "42" : !emitc.opaque<"{ ", i32>} : () -> !emitc.opaque<"mytype">
41+
}
42+
43+
// -----
44+
1745
func.func @illegal_array_missing_spec(
1846
// expected-error @+1 {{expected non-function type}}
1947
%arg0: !emitc.array<>) {

mlir/test/Dialect/EmitC/types.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ func.func @opaque_types() {
3838
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector<std::string>">]} : () -> ()
3939
// CHECK-NEXT: !emitc.opaque<"SmallVector<int*, 4>">
4040
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"SmallVector<int*, 4>">]} : () -> ()
41+
// CHECK-NEXT: !emitc.opaque<"{}", i32>
42+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", i32>>]} : () -> ()
43+
// CHECK-NEXT: !emitc.opaque<"{}, {}", i32, f32>]
44+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}, {}", i32, f32>>]} : () -> ()
45+
// CHECK-NEXT: !emitc.opaque<"{}"
46+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}">>]} : () -> ()
4147

4248
return
4349
}

mlir/test/Target/Cpp/types.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ func.func @opaque_types() {
1212
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"status_t">>]} : () -> ()
1313
// CHECK-NEXT: f<std::vector<std::string>>();
1414
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector<std::string>">]} : () -> ()
15+
// CHECK: f<float>()
16+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", f32>>]} : () -> ()
17+
// CHECK: f<int16_t {>();
18+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{} {{", si16>>]} : () -> ()
19+
// CHECK: f<int8_t {>();
20+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{} {", i8>>]} : () -> ()
21+
// CHECK: f<status_t>();
22+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", !emitc<opaque<"status_t">> >>]} : () -> ()
23+
// CHECK: f<top<nested<float>,int32_t>>();
24+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"top<{},{}>", !emitc<opaque<"nested<{}>", f32>>, i32>>]} : () -> ()
1525

1626
return
1727
}

0 commit comments

Comments
 (0)