|
15 | 15 | #define MLIR_IR_DIALECTIMPLEMENTATION_H
|
16 | 16 |
|
17 | 17 | #include "mlir/IR/OpImplementation.h"
|
18 |
| -#include "llvm/ADT/Twine.h" |
19 |
| -#include "llvm/Support/SMLoc.h" |
20 |
| -#include "llvm/Support/raw_ostream.h" |
21 | 18 |
|
22 | 19 | namespace mlir {
|
23 | 20 |
|
24 |
| -class Builder; |
25 |
| - |
26 | 21 | //===----------------------------------------------------------------------===//
|
27 | 22 | // DialectAsmPrinter
|
28 | 23 | //===----------------------------------------------------------------------===//
|
29 | 24 |
|
30 | 25 | /// This is a pure-virtual base class that exposes the asmprinter hooks
|
31 | 26 | /// necessary to implement a custom printAttribute/printType() method on a
|
32 | 27 | /// dialect.
|
33 |
| -class DialectAsmPrinter { |
| 28 | +class DialectAsmPrinter : public AsmPrinter { |
34 | 29 | public:
|
35 |
| - DialectAsmPrinter() {} |
36 |
| - virtual ~DialectAsmPrinter(); |
37 |
| - virtual raw_ostream &getStream() const = 0; |
38 |
| - |
39 |
| - /// Print the given attribute to the stream. |
40 |
| - virtual void printAttribute(Attribute attr) = 0; |
41 |
| - |
42 |
| - /// Print the given attribute without its type. The corresponding parser must |
43 |
| - /// provide a valid type for the attribute. |
44 |
| - virtual void printAttributeWithoutType(Attribute attr) = 0; |
45 |
| - |
46 |
| - /// Print the given floating point value in a stabilized form that can be |
47 |
| - /// roundtripped through the IR. This is the companion to the 'parseFloat' |
48 |
| - /// hook on the DialectAsmParser. |
49 |
| - virtual void printFloat(const APFloat &value) = 0; |
50 |
| - |
51 |
| - /// Print the given type to the stream. |
52 |
| - virtual void printType(Type type) = 0; |
53 |
| - |
54 |
| -private: |
55 |
| - DialectAsmPrinter(const DialectAsmPrinter &) = delete; |
56 |
| - void operator=(const DialectAsmPrinter &) = delete; |
| 30 | + using AsmPrinter::AsmPrinter; |
| 31 | + ~DialectAsmPrinter() override; |
57 | 32 | };
|
58 | 33 |
|
59 |
| -// Make the implementations convenient to use. |
60 |
| -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Attribute attr) { |
61 |
| - p.printAttribute(attr); |
62 |
| - return p; |
63 |
| -} |
64 |
| - |
65 |
| -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, |
66 |
| - const APFloat &value) { |
67 |
| - p.printFloat(value); |
68 |
| - return p; |
69 |
| -} |
70 |
| -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, float value) { |
71 |
| - return p << APFloat(value); |
72 |
| -} |
73 |
| -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, double value) { |
74 |
| - return p << APFloat(value); |
75 |
| -} |
76 |
| - |
77 |
| -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, Type type) { |
78 |
| - p.printType(type); |
79 |
| - return p; |
80 |
| -} |
81 |
| - |
82 |
| -// Support printing anything that isn't convertible to one of the above types, |
83 |
| -// even if it isn't exactly one of them. For example, we want to print |
84 |
| -// FunctionType with the Type version above, not have it match this. |
85 |
| -template <typename T, typename std::enable_if< |
86 |
| - !std::is_convertible<T &, Attribute &>::value && |
87 |
| - !std::is_convertible<T &, Type &>::value && |
88 |
| - !std::is_convertible<T &, APFloat &>::value && |
89 |
| - !llvm::is_one_of<T, double, float>::value, |
90 |
| - T>::type * = nullptr> |
91 |
| -inline DialectAsmPrinter &operator<<(DialectAsmPrinter &p, const T &other) { |
92 |
| - p.getStream() << other; |
93 |
| - return p; |
94 |
| -} |
95 |
| - |
96 | 34 | //===----------------------------------------------------------------------===//
|
97 | 35 | // DialectAsmParser
|
98 | 36 | //===----------------------------------------------------------------------===//
|
99 | 37 |
|
100 |
| -/// The DialectAsmParser has methods for interacting with the asm parser: |
101 |
| -/// parsing things from it, emitting errors etc. It has an intentionally |
102 |
| -/// high-level API that is designed to reduce/constrain syntax innovation in |
103 |
| -/// individual attributes or types. |
104 |
| -class DialectAsmParser { |
| 38 | +/// The DialectAsmParser has methods for interacting with the asm parser when |
| 39 | +/// parsing attributes and types. |
| 40 | +class DialectAsmParser : public AsmParser { |
105 | 41 | public:
|
106 |
| - virtual ~DialectAsmParser(); |
107 |
| - |
108 |
| - /// Emit a diagnostic at the specified location and return failure. |
109 |
| - virtual InFlightDiagnostic emitError(llvm::SMLoc loc, |
110 |
| - const Twine &message = {}) = 0; |
111 |
| - |
112 |
| - /// Return a builder which provides useful access to MLIRContext, global |
113 |
| - /// objects like types and attributes. |
114 |
| - virtual Builder &getBuilder() const = 0; |
115 |
| - |
116 |
| - /// Get the location of the next token and store it into the argument. This |
117 |
| - /// always succeeds. |
118 |
| - virtual llvm::SMLoc getCurrentLocation() = 0; |
119 |
| - ParseResult getCurrentLocation(llvm::SMLoc *loc) { |
120 |
| - *loc = getCurrentLocation(); |
121 |
| - return success(); |
122 |
| - } |
123 |
| - |
124 |
| - /// Return the location of the original name token. |
125 |
| - virtual llvm::SMLoc getNameLoc() const = 0; |
126 |
| - |
127 |
| - /// Re-encode the given source location as an MLIR location and return it. |
128 |
| - /// Note: This method should only be used when a `Location` is necessary, as |
129 |
| - /// the encoding process is not efficient. In other cases a more suitable |
130 |
| - /// alternative should be used, such as the `getChecked` methods defined |
131 |
| - /// below. |
132 |
| - virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0; |
| 42 | + using AsmParser::AsmParser; |
| 43 | + ~DialectAsmParser() override; |
133 | 44 |
|
134 | 45 | /// Returns the full specification of the symbol being parsed. This allows for
|
135 | 46 | /// using a separate parser if necessary.
|
136 | 47 | virtual StringRef getFullSymbolSpec() const = 0;
|
137 |
| - |
138 |
| - // These methods emit an error and return failure or success. This allows |
139 |
| - // these to be chained together into a linear sequence of || expressions in |
140 |
| - // many cases. |
141 |
| - |
142 |
| - /// Parse a floating point value from the stream. |
143 |
| - virtual ParseResult parseFloat(double &result) = 0; |
144 |
| - |
145 |
| - /// Parse an integer value from the stream. |
146 |
| - template <typename IntT> |
147 |
| - ParseResult parseInteger(IntT &result) { |
148 |
| - auto loc = getCurrentLocation(); |
149 |
| - OptionalParseResult parseResult = parseOptionalInteger(result); |
150 |
| - if (!parseResult.hasValue()) |
151 |
| - return emitError(loc, "expected integer value"); |
152 |
| - return *parseResult; |
153 |
| - } |
154 |
| - |
155 |
| - /// Parse an optional integer value from the stream. |
156 |
| - virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; |
157 |
| - |
158 |
| - template <typename IntT> |
159 |
| - OptionalParseResult parseOptionalInteger(IntT &result) { |
160 |
| - auto loc = getCurrentLocation(); |
161 |
| - |
162 |
| - // Parse the unsigned variant. |
163 |
| - APInt uintResult; |
164 |
| - OptionalParseResult parseResult = parseOptionalInteger(uintResult); |
165 |
| - if (!parseResult.hasValue() || failed(*parseResult)) |
166 |
| - return parseResult; |
167 |
| - |
168 |
| - // Try to convert to the provided integer type. sextOrTrunc is correct even |
169 |
| - // for unsigned types because parseOptionalInteger ensures the sign bit is |
170 |
| - // zero for non-negated integers. |
171 |
| - result = |
172 |
| - (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue(); |
173 |
| - if (APInt(uintResult.getBitWidth(), result) != uintResult) |
174 |
| - return emitError(loc, "integer value too large"); |
175 |
| - return success(); |
176 |
| - } |
177 |
| - |
178 |
| - /// Invoke the `getChecked` method of the given Attribute or Type class, using |
179 |
| - /// the provided location to emit errors in the case of failure. Note that |
180 |
| - /// unlike `OpBuilder::getType`, this method does not implicitly insert a |
181 |
| - /// context parameter. |
182 |
| - template <typename T, typename... ParamsT> |
183 |
| - T getChecked(llvm::SMLoc loc, ParamsT &&... params) { |
184 |
| - return T::getChecked([&] { return emitError(loc); }, |
185 |
| - std::forward<ParamsT>(params)...); |
186 |
| - } |
187 |
| - /// A variant of `getChecked` that uses the result of `getNameLoc` to emit |
188 |
| - /// errors. |
189 |
| - template <typename T, typename... ParamsT> |
190 |
| - T getChecked(ParamsT &&... params) { |
191 |
| - return T::getChecked([&] { return emitError(getNameLoc()); }, |
192 |
| - std::forward<ParamsT>(params)...); |
193 |
| - } |
194 |
| - |
195 |
| - //===--------------------------------------------------------------------===// |
196 |
| - // Token Parsing |
197 |
| - //===--------------------------------------------------------------------===// |
198 |
| - |
199 |
| - /// Parse a '->' token. |
200 |
| - virtual ParseResult parseArrow() = 0; |
201 |
| - |
202 |
| - /// Parse a '->' token if present |
203 |
| - virtual ParseResult parseOptionalArrow() = 0; |
204 |
| - |
205 |
| - /// Parse a '{' token. |
206 |
| - virtual ParseResult parseLBrace() = 0; |
207 |
| - |
208 |
| - /// Parse a '{' token if present |
209 |
| - virtual ParseResult parseOptionalLBrace() = 0; |
210 |
| - |
211 |
| - /// Parse a `}` token. |
212 |
| - virtual ParseResult parseRBrace() = 0; |
213 |
| - |
214 |
| - /// Parse a `}` token if present |
215 |
| - virtual ParseResult parseOptionalRBrace() = 0; |
216 |
| - |
217 |
| - /// Parse a `:` token. |
218 |
| - virtual ParseResult parseColon() = 0; |
219 |
| - |
220 |
| - /// Parse a `:` token if present. |
221 |
| - virtual ParseResult parseOptionalColon() = 0; |
222 |
| - |
223 |
| - /// Parse a `,` token. |
224 |
| - virtual ParseResult parseComma() = 0; |
225 |
| - |
226 |
| - /// Parse a `,` token if present. |
227 |
| - virtual ParseResult parseOptionalComma() = 0; |
228 |
| - |
229 |
| - /// Parse a `=` token. |
230 |
| - virtual ParseResult parseEqual() = 0; |
231 |
| - |
232 |
| - /// Parse a `=` token if present. |
233 |
| - virtual ParseResult parseOptionalEqual() = 0; |
234 |
| - |
235 |
| - /// Parse a quoted string token. |
236 |
| - ParseResult parseString(std::string *string) { |
237 |
| - auto loc = getCurrentLocation(); |
238 |
| - if (parseOptionalString(string)) |
239 |
| - return emitError(loc, "expected string"); |
240 |
| - return success(); |
241 |
| - } |
242 |
| - |
243 |
| - /// Parse a quoted string token if present. |
244 |
| - virtual ParseResult parseOptionalString(std::string *string) = 0; |
245 |
| - |
246 |
| - /// Parse a given keyword. |
247 |
| - ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { |
248 |
| - auto loc = getCurrentLocation(); |
249 |
| - if (parseOptionalKeyword(keyword)) |
250 |
| - return emitError(loc, "expected '") << keyword << "'" << msg; |
251 |
| - return success(); |
252 |
| - } |
253 |
| - |
254 |
| - /// Parse a keyword into 'keyword'. |
255 |
| - ParseResult parseKeyword(StringRef *keyword) { |
256 |
| - auto loc = getCurrentLocation(); |
257 |
| - if (parseOptionalKeyword(keyword)) |
258 |
| - return emitError(loc, "expected valid keyword"); |
259 |
| - return success(); |
260 |
| - } |
261 |
| - |
262 |
| - /// Parse the given keyword if present. |
263 |
| - virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
264 |
| - |
265 |
| - /// Parse a keyword, if present, into 'keyword'. |
266 |
| - virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
267 |
| - |
268 |
| - /// Parse a '<' token. |
269 |
| - virtual ParseResult parseLess() = 0; |
270 |
| - |
271 |
| - /// Parse a `<` token if present. |
272 |
| - virtual ParseResult parseOptionalLess() = 0; |
273 |
| - |
274 |
| - /// Parse a '>' token. |
275 |
| - virtual ParseResult parseGreater() = 0; |
276 |
| - |
277 |
| - /// Parse a `>` token if present. |
278 |
| - virtual ParseResult parseOptionalGreater() = 0; |
279 |
| - |
280 |
| - /// Parse a `(` token. |
281 |
| - virtual ParseResult parseLParen() = 0; |
282 |
| - |
283 |
| - /// Parse a `(` token if present. |
284 |
| - virtual ParseResult parseOptionalLParen() = 0; |
285 |
| - |
286 |
| - /// Parse a `)` token. |
287 |
| - virtual ParseResult parseRParen() = 0; |
288 |
| - |
289 |
| - /// Parse a `)` token if present. |
290 |
| - virtual ParseResult parseOptionalRParen() = 0; |
291 |
| - |
292 |
| - /// Parse a `[` token. |
293 |
| - virtual ParseResult parseLSquare() = 0; |
294 |
| - |
295 |
| - /// Parse a `[` token if present. |
296 |
| - virtual ParseResult parseOptionalLSquare() = 0; |
297 |
| - |
298 |
| - /// Parse a `]` token. |
299 |
| - virtual ParseResult parseRSquare() = 0; |
300 |
| - |
301 |
| - /// Parse a `]` token if present. |
302 |
| - virtual ParseResult parseOptionalRSquare() = 0; |
303 |
| - |
304 |
| - /// Parse a `...` token if present; |
305 |
| - virtual ParseResult parseOptionalEllipsis() = 0; |
306 |
| - |
307 |
| - /// Parse a `?` token. |
308 |
| - virtual ParseResult parseOptionalQuestion() = 0; |
309 |
| - |
310 |
| - /// Parse a `*` token. |
311 |
| - virtual ParseResult parseOptionalStar() = 0; |
312 |
| - |
313 |
| - //===--------------------------------------------------------------------===// |
314 |
| - // Attribute Parsing |
315 |
| - //===--------------------------------------------------------------------===// |
316 |
| - |
317 |
| - /// Parse an arbitrary attribute and return it in result. |
318 |
| - virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; |
319 |
| - |
320 |
| - /// Parse an attribute of a specific kind and type. |
321 |
| - template <typename AttrType> |
322 |
| - ParseResult parseAttribute(AttrType &result, Type type = {}) { |
323 |
| - llvm::SMLoc loc = getCurrentLocation(); |
324 |
| - |
325 |
| - // Parse any kind of attribute. |
326 |
| - Attribute attr; |
327 |
| - if (parseAttribute(attr, type)) |
328 |
| - return failure(); |
329 |
| - |
330 |
| - // Check for the right kind of attribute. |
331 |
| - result = attr.dyn_cast<AttrType>(); |
332 |
| - if (!result) |
333 |
| - return emitError(loc, "invalid kind of attribute specified"); |
334 |
| - return success(); |
335 |
| - } |
336 |
| - |
337 |
| - /// Parse an affine map instance into 'map'. |
338 |
| - virtual ParseResult parseAffineMap(AffineMap &map) = 0; |
339 |
| - |
340 |
| - /// Parse an integer set instance into 'set'. |
341 |
| - virtual ParseResult printIntegerSet(IntegerSet &set) = 0; |
342 |
| - |
343 |
| - //===--------------------------------------------------------------------===// |
344 |
| - // Type Parsing |
345 |
| - //===--------------------------------------------------------------------===// |
346 |
| - |
347 |
| - /// Parse a type. |
348 |
| - virtual ParseResult parseType(Type &result) = 0; |
349 |
| - |
350 |
| - /// Parse a type of a specific kind, e.g. a FunctionType. |
351 |
| - template <typename TypeType> |
352 |
| - ParseResult parseType(TypeType &result) { |
353 |
| - llvm::SMLoc loc = getCurrentLocation(); |
354 |
| - |
355 |
| - // Parse any kind of type. |
356 |
| - Type type; |
357 |
| - if (parseType(type)) |
358 |
| - return failure(); |
359 |
| - |
360 |
| - // Check for the right kind of attribute. |
361 |
| - result = type.dyn_cast<TypeType>(); |
362 |
| - if (!result) |
363 |
| - return emitError(loc, "invalid kind of type specified"); |
364 |
| - return success(); |
365 |
| - } |
366 |
| - |
367 |
| - /// Parse a type if present. |
368 |
| - virtual OptionalParseResult parseOptionalType(Type &result) = 0; |
369 |
| - |
370 |
| - /// Parse a 'x' separated dimension list. This populates the dimension list, |
371 |
| - /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on |
372 |
| - /// `?` otherwise. |
373 |
| - /// |
374 |
| - /// dimension-list ::= (dimension `x`)* |
375 |
| - /// dimension ::= `?` | integer |
376 |
| - /// |
377 |
| - /// When `allowDynamic` is not set, this is used to parse: |
378 |
| - /// |
379 |
| - /// static-dimension-list ::= (integer `x`)* |
380 |
| - virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, |
381 |
| - bool allowDynamic = true) = 0; |
382 |
| - |
383 |
| - /// Parse an 'x' token in a dimension list, handling the case where the x is |
384 |
| - /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the |
385 |
| - /// next token. |
386 |
| - virtual ParseResult parseXInDimensionList() = 0; |
387 | 48 | };
|
388 | 49 |
|
389 | 50 | } // end namespace mlir
|
|
0 commit comments