Skip to content

Commit 5312063

Browse files
committed
[mlir:OpAsm] Factor out the common bits of (Op/Dialect)Asm(Parser/Printer)
This has a few benefits: * It allows for defining parsers/printer code blocks that can be shared between operations and attribute/types. * It removes the weird duplication of generic parser/printer hooks, which means that newly added hooks only require touching one class. Differential Revision: https://reviews.llvm.org/D110375
1 parent 62cc6b0 commit 5312063

File tree

8 files changed

+1014
-1378
lines changed

8 files changed

+1014
-1378
lines changed

mlir/include/mlir/IR/DialectImplementation.h

Lines changed: 8 additions & 347 deletions
Original file line numberDiff line numberDiff line change
@@ -15,375 +15,36 @@
1515
#define MLIR_IR_DIALECTIMPLEMENTATION_H
1616

1717
#include "mlir/IR/OpImplementation.h"
18-
#include "llvm/ADT/Twine.h"
19-
#include "llvm/Support/SMLoc.h"
20-
#include "llvm/Support/raw_ostream.h"
2118

2219
namespace mlir {
2320

24-
class Builder;
25-
2621
//===----------------------------------------------------------------------===//
2722
// DialectAsmPrinter
2823
//===----------------------------------------------------------------------===//
2924

3025
/// This is a pure-virtual base class that exposes the asmprinter hooks
3126
/// necessary to implement a custom printAttribute/printType() method on a
3227
/// dialect.
33-
class DialectAsmPrinter {
28+
class DialectAsmPrinter : public AsmPrinter {
3429
public:
35-
DialectAsmPrinter() {}
36-
virtual ~DialectAsmPrinter();
37-
virtual raw_ostream &getStream() const = 0;
38-
39-
/// Print the given attribute to the stream.
40-
virtual void printAttribute(Attribute attr) = 0;
41-
42-
/// Print the given attribute without its type. The corresponding parser must
43-
/// provide a valid type for the attribute.
44-
virtual void printAttributeWithoutType(Attribute attr) = 0;
45-
46-
/// Print the given floating point value in a stabilized form that can be
47-
/// roundtripped through the IR. This is the companion to the 'parseFloat'
48-
/// hook on the DialectAsmParser.
49-
virtual void printFloat(const APFloat &value) = 0;
50-
51-
/// Print the given type to the stream.
52-
virtual void printType(Type type) = 0;
53-
54-
private:
55-
DialectAsmPrinter(const DialectAsmPrinter &) = delete;
56-
void operator=(const DialectAsmPrinter &) = delete;
30+
using AsmPrinter::AsmPrinter;
31+
~DialectAsmPrinter() override;
5732
};
5833

59-
// Make the implementations convenient to use.
60-
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) {
61-
p.printAttribute(attr);
62-
return p;
63-
}
64-
65-
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p,
66-
const APFloat &value) {
67-
p.printFloat(value);
68-
return p;
69-
}
70-
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) {
71-
return p << APFloat(value);
72-
}
73-
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) {
74-
return p << APFloat(value);
75-
}
76-
77-
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) {
78-
p.printType(type);
79-
return p;
80-
}
81-
82-
// Support printing anything that isn't convertible to one of the above types,
83-
// even if it isn't exactly one of them. For example, we want to print
84-
// FunctionType with the Type version above, not have it match this.
85-
template <typename T, typename std::enable_if<
86-
!std::is_convertible<T &, Attribute &>::value &&
87-
!std::is_convertible<T &, Type &>::value &&
88-
!std::is_convertible<T &, APFloat &>::value &&
89-
!llvm::is_one_of<T, double, float>::value,
90-
T>::type * = nullptr>
91-
inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) {
92-
p.getStream() << other;
93-
return p;
94-
}
95-
9634
//===----------------------------------------------------------------------===//
9735
// DialectAsmParser
9836
//===----------------------------------------------------------------------===//
9937

100-
/// The DialectAsmParser has methods for interacting with the asm parser:
101-
/// parsing things from it, emitting errors etc. It has an intentionally
102-
/// high-level API that is designed to reduce/constrain syntax innovation in
103-
/// individual attributes or types.
104-
class DialectAsmParser {
38+
/// The DialectAsmParser has methods for interacting with the asm parser when
39+
/// parsing attributes and types.
40+
class DialectAsmParser : public AsmParser {
10541
public:
106-
virtual ~DialectAsmParser();
107-
108-
/// Emit a diagnostic at the specified location and return failure.
109-
virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
110-
const Twine &message = {}) = 0;
111-
112-
/// Return a builder which provides useful access to MLIRContext, global
113-
/// objects like types and attributes.
114-
virtual Builder &getBuilder() const = 0;
115-
116-
/// Get the location of the next token and store it into the argument. This
117-
/// always succeeds.
118-
virtual llvm::SMLoc getCurrentLocation() = 0;
119-
ParseResult getCurrentLocation(llvm::SMLoc *loc) {
120-
*loc = getCurrentLocation();
121-
return success();
122-
}
123-
124-
/// Return the location of the original name token.
125-
virtual llvm::SMLoc getNameLoc() const = 0;
126-
127-
/// Re-encode the given source location as an MLIR location and return it.
128-
/// Note: This method should only be used when a `Location` is necessary, as
129-
/// the encoding process is not efficient. In other cases a more suitable
130-
/// alternative should be used, such as the `getChecked` methods defined
131-
/// below.
132-
virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
42+
using AsmParser::AsmParser;
43+
~DialectAsmParser() override;
13344

13445
/// Returns the full specification of the symbol being parsed. This allows for
13546
/// using a separate parser if necessary.
13647
virtual StringRef getFullSymbolSpec() const = 0;
137-
138-
// These methods emit an error and return failure or success. This allows
139-
// these to be chained together into a linear sequence of || expressions in
140-
// many cases.
141-
142-
/// Parse a floating point value from the stream.
143-
virtual ParseResult parseFloat(double &result) = 0;
144-
145-
/// Parse an integer value from the stream.
146-
template <typename IntT>
147-
ParseResult parseInteger(IntT &result) {
148-
auto loc = getCurrentLocation();
149-
OptionalParseResult parseResult = parseOptionalInteger(result);
150-
if (!parseResult.hasValue())
151-
return emitError(loc, "expected integer value");
152-
return *parseResult;
153-
}
154-
155-
/// Parse an optional integer value from the stream.
156-
virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
157-
158-
template <typename IntT>
159-
OptionalParseResult parseOptionalInteger(IntT &result) {
160-
auto loc = getCurrentLocation();
161-
162-
// Parse the unsigned variant.
163-
APInt uintResult;
164-
OptionalParseResult parseResult = parseOptionalInteger(uintResult);
165-
if (!parseResult.hasValue() || failed(*parseResult))
166-
return parseResult;
167-
168-
// Try to convert to the provided integer type. sextOrTrunc is correct even
169-
// for unsigned types because parseOptionalInteger ensures the sign bit is
170-
// zero for non-negated integers.
171-
result =
172-
(IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
173-
if (APInt(uintResult.getBitWidth(), result) != uintResult)
174-
return emitError(loc, "integer value too large");
175-
return success();
176-
}
177-
178-
/// Invoke the `getChecked` method of the given Attribute or Type class, using
179-
/// the provided location to emit errors in the case of failure. Note that
180-
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
181-
/// context parameter.
182-
template <typename T, typename... ParamsT>
183-
T getChecked(llvm::SMLoc loc, ParamsT &&... params) {
184-
return T::getChecked([&] { return emitError(loc); },
185-
std::forward<ParamsT>(params)...);
186-
}
187-
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
188-
/// errors.
189-
template <typename T, typename... ParamsT>
190-
T getChecked(ParamsT &&... params) {
191-
return T::getChecked([&] { return emitError(getNameLoc()); },
192-
std::forward<ParamsT>(params)...);
193-
}
194-
195-
//===--------------------------------------------------------------------===//
196-
// Token Parsing
197-
//===--------------------------------------------------------------------===//
198-
199-
/// Parse a '->' token.
200-
virtual ParseResult parseArrow() = 0;
201-
202-
/// Parse a '->' token if present
203-
virtual ParseResult parseOptionalArrow() = 0;
204-
205-
/// Parse a '{' token.
206-
virtual ParseResult parseLBrace() = 0;
207-
208-
/// Parse a '{' token if present
209-
virtual ParseResult parseOptionalLBrace() = 0;
210-
211-
/// Parse a `}` token.
212-
virtual ParseResult parseRBrace() = 0;
213-
214-
/// Parse a `}` token if present
215-
virtual ParseResult parseOptionalRBrace() = 0;
216-
217-
/// Parse a `:` token.
218-
virtual ParseResult parseColon() = 0;
219-
220-
/// Parse a `:` token if present.
221-
virtual ParseResult parseOptionalColon() = 0;
222-
223-
/// Parse a `,` token.
224-
virtual ParseResult parseComma() = 0;
225-
226-
/// Parse a `,` token if present.
227-
virtual ParseResult parseOptionalComma() = 0;
228-
229-
/// Parse a `=` token.
230-
virtual ParseResult parseEqual() = 0;
231-
232-
/// Parse a `=` token if present.
233-
virtual ParseResult parseOptionalEqual() = 0;
234-
235-
/// Parse a quoted string token.
236-
ParseResult parseString(std::string *string) {
237-
auto loc = getCurrentLocation();
238-
if (parseOptionalString(string))
239-
return emitError(loc, "expected string");
240-
return success();
241-
}
242-
243-
/// Parse a quoted string token if present.
244-
virtual ParseResult parseOptionalString(std::string *string) = 0;
245-
246-
/// Parse a given keyword.
247-
ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
248-
auto loc = getCurrentLocation();
249-
if (parseOptionalKeyword(keyword))
250-
return emitError(loc, "expected '") << keyword << "'" << msg;
251-
return success();
252-
}
253-
254-
/// Parse a keyword into 'keyword'.
255-
ParseResult parseKeyword(StringRef *keyword) {
256-
auto loc = getCurrentLocation();
257-
if (parseOptionalKeyword(keyword))
258-
return emitError(loc, "expected valid keyword");
259-
return success();
260-
}
261-
262-
/// Parse the given keyword if present.
263-
virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
264-
265-
/// Parse a keyword, if present, into 'keyword'.
266-
virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
267-
268-
/// Parse a '<' token.
269-
virtual ParseResult parseLess() = 0;
270-
271-
/// Parse a `<` token if present.
272-
virtual ParseResult parseOptionalLess() = 0;
273-
274-
/// Parse a '>' token.
275-
virtual ParseResult parseGreater() = 0;
276-
277-
/// Parse a `>` token if present.
278-
virtual ParseResult parseOptionalGreater() = 0;
279-
280-
/// Parse a `(` token.
281-
virtual ParseResult parseLParen() = 0;
282-
283-
/// Parse a `(` token if present.
284-
virtual ParseResult parseOptionalLParen() = 0;
285-
286-
/// Parse a `)` token.
287-
virtual ParseResult parseRParen() = 0;
288-
289-
/// Parse a `)` token if present.
290-
virtual ParseResult parseOptionalRParen() = 0;
291-
292-
/// Parse a `[` token.
293-
virtual ParseResult parseLSquare() = 0;
294-
295-
/// Parse a `[` token if present.
296-
virtual ParseResult parseOptionalLSquare() = 0;
297-
298-
/// Parse a `]` token.
299-
virtual ParseResult parseRSquare() = 0;
300-
301-
/// Parse a `]` token if present.
302-
virtual ParseResult parseOptionalRSquare() = 0;
303-
304-
/// Parse a `...` token if present;
305-
virtual ParseResult parseOptionalEllipsis() = 0;
306-
307-
/// Parse a `?` token.
308-
virtual ParseResult parseOptionalQuestion() = 0;
309-
310-
/// Parse a `*` token.
311-
virtual ParseResult parseOptionalStar() = 0;
312-
313-
//===--------------------------------------------------------------------===//
314-
// Attribute Parsing
315-
//===--------------------------------------------------------------------===//
316-
317-
/// Parse an arbitrary attribute and return it in result.
318-
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
319-
320-
/// Parse an attribute of a specific kind and type.
321-
template <typename AttrType>
322-
ParseResult parseAttribute(AttrType &result, Type type = {}) {
323-
llvm::SMLoc loc = getCurrentLocation();
324-
325-
// Parse any kind of attribute.
326-
Attribute attr;
327-
if (parseAttribute(attr, type))
328-
return failure();
329-
330-
// Check for the right kind of attribute.
331-
result = attr.dyn_cast<AttrType>();
332-
if (!result)
333-
return emitError(loc, "invalid kind of attribute specified");
334-
return success();
335-
}
336-
337-
/// Parse an affine map instance into 'map'.
338-
virtual ParseResult parseAffineMap(AffineMap &map) = 0;
339-
340-
/// Parse an integer set instance into 'set'.
341-
virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
342-
343-
//===--------------------------------------------------------------------===//
344-
// Type Parsing
345-
//===--------------------------------------------------------------------===//
346-
347-
/// Parse a type.
348-
virtual ParseResult parseType(Type &result) = 0;
349-
350-
/// Parse a type of a specific kind, e.g. a FunctionType.
351-
template <typename TypeType>
352-
ParseResult parseType(TypeType &result) {
353-
llvm::SMLoc loc = getCurrentLocation();
354-
355-
// Parse any kind of type.
356-
Type type;
357-
if (parseType(type))
358-
return failure();
359-
360-
// Check for the right kind of attribute.
361-
result = type.dyn_cast<TypeType>();
362-
if (!result)
363-
return emitError(loc, "invalid kind of type specified");
364-
return success();
365-
}
366-
367-
/// Parse a type if present.
368-
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
369-
370-
/// Parse a 'x' separated dimension list. This populates the dimension list,
371-
/// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
372-
/// `?` otherwise.
373-
///
374-
/// dimension-list ::= (dimension `x`)*
375-
/// dimension ::= `?` | integer
376-
///
377-
/// When `allowDynamic` is not set, this is used to parse:
378-
///
379-
/// static-dimension-list ::= (integer `x`)*
380-
virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
381-
bool allowDynamic = true) = 0;
382-
383-
/// Parse an 'x' token in a dimension list, handling the case where the x is
384-
/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
385-
/// next token.
386-
virtual ParseResult parseXInDimensionList() = 0;
38748
};
38849

38950
} // end namespace mlir

0 commit comments

Comments
 (0)