Skip to content

Commit ad4697c

Browse files
authored
OpaqueType with format strings (#391)
OpaqueType: Use format string
1 parent 4b36487 commit ad4697c

File tree

8 files changed

+182
-58
lines changed

8 files changed

+182
-58
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ bool isPointerWideType(mlir::Type type);
5252
/// Give the name of the EmitC reference attribute.
5353
StringRef getReferenceAttributeName();
5454

55+
// Either a literal string, or an placeholder for the fmtArgs.
56+
struct Placeholder {};
57+
using ReplacementItem = std::variant<StringRef, Placeholder>;
58+
5559
} // namespace emitc
5660
} // namespace mlir
5761

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,11 +1168,7 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> {
11681168
}];
11691169

11701170
let extraClassDeclaration = [{
1171-
// Either a literal string, or an placeholder for the fmtArgs.
1172-
struct Placeholder {};
1173-
using ReplacementItem = std::variant<StringRef, Placeholder>;
1174-
1175-
FailureOr<SmallVector<ReplacementItem>> parseFormatString();
1171+
FailureOr<SmallVector<::mlir::emitc::ReplacementItem>> parseFormatString();
11761172
}];
11771173

11781174
let arguments = (ins StrAttr:$value,

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,16 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> {
9999
```
100100
}];
101101

102-
let parameters = (ins StringRefParameter<"the opaque value">:$value);
103-
let assemblyFormat = "`<` $value `>`";
102+
let parameters = (ins StringRefParameter<"the opaque value">:$value,
103+
OptionalArrayRefParameter<"Type">:$fmtArgs);
104+
let assemblyFormat = "`<` $value (`,` custom<VariadicTypeFmtArgs>($fmtArgs)^)? `>`";
104105
let genVerifyDecl = 1;
106+
107+
let builders = [TypeBuilder<(ins "::llvm::StringRef":$value), [{ return $_get($_ctxt, value, SmallVector<Type>{}); }] >];
108+
109+
let extraClassDeclaration = [{
110+
FailureOr<SmallVector<::mlir::emitc::ReplacementItem>> parseFormatString();
111+
}];
105112
}
106113

107114
def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> {

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 105 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/ADT/StringExtras.h"
2121
#include "llvm/ADT/TypeSwitch.h"
2222
#include "llvm/Support/Casting.h"
23+
#include "llvm/Support/FormatVariadic.h"
2324

2425
using namespace mlir;
2526
using namespace mlir::emitc;
@@ -154,6 +155,64 @@ static LogicalResult verifyInitializationAttribute(Operation *op,
154155
return success();
155156
}
156157

158+
/// Parse a format string and return a list of its parts.
159+
/// A part is either a StringRef that has to be printed as-is, or
160+
/// a Placeholder which requires printing the next operand of the VerbatimOp.
161+
/// In the format string, all `{}` are replaced by Placeholders, except if the
162+
/// `{` is escaped by `{{` - then it doesn't start a placeholder.
163+
template <class ArgType>
164+
FailureOr<SmallVector<ReplacementItem>>
165+
parseFormatString(StringRef toParse, ArgType fmtArgs,
166+
std::optional<llvm::function_ref<mlir::InFlightDiagnostic()>>
167+
emitError = {}) {
168+
SmallVector<ReplacementItem> items;
169+
170+
// If there are not operands, the format string is not interpreted.
171+
if (fmtArgs.empty()) {
172+
items.push_back(toParse);
173+
return items;
174+
}
175+
176+
while (!toParse.empty()) {
177+
size_t idx = toParse.find('{');
178+
if (idx == StringRef::npos) {
179+
// No '{'
180+
items.push_back(toParse);
181+
break;
182+
}
183+
if (idx > 0) {
184+
// Take all chars excluding the '{'.
185+
items.push_back(toParse.take_front(idx));
186+
toParse = toParse.drop_front(idx);
187+
continue;
188+
}
189+
if (toParse.size() < 2) {
190+
// '{' is last character
191+
items.push_back(toParse);
192+
break;
193+
}
194+
// toParse contains at least two characters and starts with `{`.
195+
char nextChar = toParse[1];
196+
if (nextChar == '{') {
197+
// Double '{{' -> '{' (escaping).
198+
items.push_back(toParse.take_front(1));
199+
toParse = toParse.drop_front(2);
200+
continue;
201+
}
202+
if (nextChar == '}') {
203+
items.push_back(Placeholder{});
204+
toParse = toParse.drop_front(2);
205+
continue;
206+
}
207+
208+
if (emitError.has_value()) {
209+
return (*emitError)() << "expected '}' after unescaped '{'";
210+
}
211+
return failure();
212+
}
213+
return items;
214+
}
215+
157216
//===----------------------------------------------------------------------===//
158217
// AddOp
159218
//===----------------------------------------------------------------------===//
@@ -914,7 +973,11 @@ LogicalResult emitc::SubscriptOp::verify() {
914973
//===----------------------------------------------------------------------===//
915974

916975
LogicalResult emitc::VerbatimOp::verify() {
917-
FailureOr<SmallVector<ReplacementItem>> fmt = parseFormatString();
976+
auto errorCallback = [&]() -> InFlightDiagnostic {
977+
return this->emitOpError();
978+
};
979+
FailureOr<SmallVector<ReplacementItem>> fmt =
980+
::parseFormatString(getValue(), getFmtArgs(), errorCallback);
918981
if (failed(fmt))
919982
return failure();
920983

@@ -929,56 +992,29 @@ LogicalResult emitc::VerbatimOp::verify() {
929992
return success();
930993
}
931994

932-
/// Parse a format string and return a list of its parts.
933-
/// A part is either a StringRef that has to be printed as-is, or
934-
/// a Placeholder which requires printing the next operand of the VerbatimOp.
935-
/// In the format string, all `{}` are replaced by Placeholders, except if the
936-
/// `{` is escaped by `{{` - then it doesn't start a placeholder.
937-
FailureOr<SmallVector<emitc::VerbatimOp::ReplacementItem>>
938-
emitc::VerbatimOp::parseFormatString() {
939-
SmallVector<ReplacementItem> items;
995+
static ParseResult parseVariadicTypeFmtArgs(AsmParser &p,
996+
SmallVector<Type> &params) {
997+
Type type;
998+
if (p.parseType(type))
999+
return failure();
9401000

941-
// If there are not operands, the format string is not interpreted.
942-
if (getFmtArgs().empty()) {
943-
items.push_back(getValue());
944-
return items;
1001+
params.push_back(type);
1002+
while (succeeded(p.parseOptionalComma())) {
1003+
if (p.parseType(type))
1004+
return failure();
1005+
params.push_back(type);
9451006
}
9461007

947-
StringRef toParse = getValue();
948-
while (!toParse.empty()) {
949-
size_t idx = toParse.find('{');
950-
if (idx == StringRef::npos) {
951-
// No '{'
952-
items.push_back(toParse);
953-
break;
954-
}
955-
if (idx > 0) {
956-
// Take all chars excluding the '{'.
957-
items.push_back(toParse.take_front(idx));
958-
toParse = toParse.drop_front(idx);
959-
continue;
960-
}
961-
if (toParse.size() < 2) {
962-
// '{' is last character
963-
items.push_back(toParse);
964-
break;
965-
}
966-
// toParse contains at least two characters and starts with `{`.
967-
char nextChar = toParse[1];
968-
if (nextChar == '{') {
969-
// Double '{{' -> '{' (escaping).
970-
items.push_back(toParse.take_front(1));
971-
toParse = toParse.drop_front(2);
972-
continue;
973-
}
974-
if (nextChar == '}') {
975-
items.push_back(Placeholder{});
976-
toParse = toParse.drop_front(2);
977-
continue;
978-
}
979-
return emitOpError() << "expected '}' after unescaped '{'";
980-
}
981-
return items;
1008+
return success();
1009+
}
1010+
1011+
static void printVariadicTypeFmtArgs(AsmPrinter &p, ArrayRef<Type> params) {
1012+
llvm::interleaveComma(params, p, [&](Type type) { p.printType(type); });
1013+
}
1014+
1015+
FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
1016+
// Error checking is done in verify.
1017+
return ::parseFormatString(getValue(), getFmtArgs());
9821018
}
9831019

9841020
//===----------------------------------------------------------------------===//
@@ -1072,17 +1108,37 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
10721108

10731109
LogicalResult mlir::emitc::OpaqueType::verify(
10741110
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1075-
llvm::StringRef value) {
1111+
llvm::StringRef value, ArrayRef<Type> fmtArgs) {
10761112
if (value.empty()) {
10771113
return emitError() << "expected non empty string in !emitc.opaque type";
10781114
}
10791115
if (value.back() == '*') {
10801116
return emitError() << "pointer not allowed as outer type with "
10811117
"!emitc.opaque, use !emitc.ptr instead";
10821118
}
1119+
1120+
FailureOr<SmallVector<ReplacementItem>> fmt =
1121+
::parseFormatString(value, fmtArgs, emitError);
1122+
if (failed(fmt))
1123+
return failure();
1124+
1125+
size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) {
1126+
return std::holds_alternative<Placeholder>(item);
1127+
});
1128+
1129+
if (numPlaceholders != fmtArgs.size()) {
1130+
return emitError()
1131+
<< "requires operands for each placeholder in the format string";
1132+
}
1133+
10831134
return success();
10841135
}
10851136

1137+
FailureOr<SmallVector<ReplacementItem>> emitc::OpaqueType::parseFormatString() {
1138+
// Error checking is done in verify.
1139+
return ::parseFormatString(getValue(), getFmtArgs());
1140+
}
1141+
10861142
//===----------------------------------------------------------------------===//
10871143
// GlobalOp
10881144
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,14 +512,14 @@ static LogicalResult printOperation(CppEmitter &emitter,
512512
emitc::VerbatimOp verbatimOp) {
513513
raw_ostream &os = emitter.ostream();
514514

515-
FailureOr<SmallVector<emitc::VerbatimOp::ReplacementItem>> items =
515+
FailureOr<SmallVector<ReplacementItem>> items =
516516
verbatimOp.parseFormatString();
517517
if (failed(items))
518518
return failure();
519519

520520
auto fmtArg = verbatimOp.getFmtArgs().begin();
521521

522-
for (emitc::VerbatimOp::ReplacementItem &item : *items) {
522+
for (ReplacementItem &item : *items) {
523523
if (auto *str = std::get_if<StringRef>(&item)) {
524524
os << *str;
525525
} else {
@@ -1728,6 +1728,23 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) {
17281728
if (auto tType = dyn_cast<TupleType>(type))
17291729
return emitTupleType(loc, tType.getTypes());
17301730
if (auto oType = dyn_cast<emitc::OpaqueType>(type)) {
1731+
FailureOr<SmallVector<ReplacementItem>> items = oType.parseFormatString();
1732+
if (failed(items))
1733+
return failure();
1734+
1735+
auto fmtArg = oType.getFmtArgs().begin();
1736+
for (ReplacementItem &item : *items) {
1737+
if (auto *str = std::get_if<StringRef>(&item)) {
1738+
os << *str;
1739+
} else {
1740+
if (failed(emitType(loc, *fmtArg++))) {
1741+
return failure();
1742+
}
1743+
}
1744+
}
1745+
1746+
return success();
1747+
17311748
os << oType.getValue();
17321749
return success();
17331750
}

mlir/test/Dialect/EmitC/invalid_types.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,34 @@ func.func @illegal_opaque_type_2() {
1414

1515
// -----
1616

17+
// expected-error @+1 {{expected non-function type}}
18+
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", "string">) {
19+
return
20+
}
21+
22+
// -----
23+
24+
// expected-error @+1 {{requires operands for each placeholder in the format string}}
25+
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"a", f32>) {
26+
return
27+
}
28+
29+
// -----
30+
31+
// expected-error @+1 {{requires operands for each placeholder in the format string}}
32+
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", f32>) {
33+
return
34+
}
35+
36+
// -----
37+
38+
// expected-error @+1 {{expected '}' after unescaped '{'}}
39+
func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{ ", i32>) {
40+
return
41+
}
42+
43+
// -----
44+
1745
func.func @illegal_array_missing_spec(
1846
// expected-error @+1 {{expected non-function type}}
1947
%arg0: !emitc.array<>) {

mlir/test/Dialect/EmitC/types.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ func.func @opaque_types() {
3838
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector<std::string>">]} : () -> ()
3939
// CHECK-NEXT: !emitc.opaque<"SmallVector<int*, 4>">
4040
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"SmallVector<int*, 4>">]} : () -> ()
41+
// CHECK-NEXT: !emitc.opaque<"{}", i32>
42+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", i32>>]} : () -> ()
43+
// CHECK-NEXT: !emitc.opaque<"{}, {}", i32, f32>]
44+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}, {}", i32, f32>>]} : () -> ()
45+
// CHECK-NEXT: !emitc.opaque<"{}"
46+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}">>]} : () -> ()
4147

4248
return
4349
}

mlir/test/Target/Cpp/types.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ func.func @opaque_types() {
1212
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"status_t">>]} : () -> ()
1313
// CHECK-NEXT: f<std::vector<std::string>>();
1414
emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector<std::string>">]} : () -> ()
15+
// CHECK: f<float>()
16+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", f32>>]} : () -> ()
17+
// CHECK: f<int16_t {>();
18+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{} {{", si16>>]} : () -> ()
19+
// CHECK: f<int8_t {>();
20+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{} {", i8>>]} : () -> ()
21+
// CHECK: f<status_t>();
22+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"{}", !emitc<opaque<"status_t">> >>]} : () -> ()
23+
// CHECK: f<top<nested<float>,int32_t>>();
24+
emitc.call_opaque "f"() {template_args = [!emitc<opaque<"top<{},{}>", !emitc<opaque<"nested<{}>", f32>>, i32>>]} : () -> ()
1525

1626
return
1727
}

0 commit comments

Comments
 (0)