16
16
17
17
#include " mlir/Support/LLVM.h"
18
18
#include " mlir/Support/LogicalResult.h"
19
+ #include " llvm/ADT/FunctionExtras.h"
19
20
#include " llvm/ADT/StringRef.h"
20
21
#include " llvm/Support/CommandLine.h"
21
22
#include " llvm/Support/Compiler.h"
22
23
#include < memory>
23
24
24
25
namespace mlir {
25
26
namespace detail {
27
+ namespace pass_options {
28
+ // / Parse a string containing a list of comma-delimited elements, invoking the
29
+ // / given parser for each sub-element and passing them to the provided
30
+ // / element-append functor.
31
+ LogicalResult
32
+ parseCommaSeparatedList (llvm::cl::Option &opt, StringRef argName,
33
+ StringRef optionStr,
34
+ function_ref<LogicalResult(StringRef)> elementParseFn);
35
+ template <typename ElementParser, typename ElementAppendFn>
36
+ LogicalResult parseCommaSeparatedList (llvm::cl::Option &opt, StringRef argName,
37
+ StringRef optionStr,
38
+ ElementParser &elementParser,
39
+ ElementAppendFn &&appendFn) {
40
+ return parseCommaSeparatedList (
41
+ opt, argName, optionStr, [&](StringRef valueStr) {
42
+ typename ElementParser::parser_data_type value = {};
43
+ if (elementParser.parse (opt, argName, valueStr, value))
44
+ return failure ();
45
+ appendFn (value);
46
+ return success ();
47
+ });
48
+ }
49
+
50
+ // / Trait used to detect if a type has a operator<< method.
51
+ template <typename T>
52
+ using has_stream_operator_trait =
53
+ decltype (std::declval<raw_ostream &>() << std::declval<T>());
54
+ template <typename T>
55
+ using has_stream_operator = llvm::is_detected<has_stream_operator_trait, T>;
56
+
57
+ // / Utility methods for printing option values.
58
+ template <typename ParserT>
59
+ static void printOptionValue (raw_ostream &os, const bool &value) {
60
+ os << (value ? StringRef (" true" ) : StringRef (" false" ));
61
+ }
62
+ template <typename ParserT, typename DataT>
63
+ static std::enable_if_t <has_stream_operator<DataT>::value>
64
+ printOptionValue (raw_ostream &os, const DataT &value) {
65
+ os << value;
66
+ }
67
+ template <typename ParserT, typename DataT>
68
+ static std::enable_if_t <!has_stream_operator<DataT>::value>
69
+ printOptionValue (raw_ostream &os, const DataT &value) {
70
+ // If the value can't be streamed, fallback to checking for a print in the
71
+ // parser.
72
+ ParserT::print (os, value);
73
+ }
74
+ } // namespace pass_options
75
+
26
76
// / Base container class and manager for all pass options.
27
77
class PassOptions : protected llvm ::cl::SubCommand {
28
78
private:
@@ -85,11 +135,7 @@ class PassOptions : protected llvm::cl::SubCommand {
85
135
}
86
136
template <typename DataT, typename ParserT>
87
137
static void printValue (raw_ostream &os, ParserT &parser, const DataT &value) {
88
- os << value;
89
- }
90
- template <typename ParserT>
91
- static void printValue (raw_ostream &os, ParserT &parser, const bool &value) {
92
- os << (value ? StringRef (" true" ) : StringRef (" false" ));
138
+ detail::pass_options::printOptionValue<ParserT>(os, value);
93
139
}
94
140
95
141
public:
@@ -149,22 +195,27 @@ class PassOptions : protected llvm::cl::SubCommand {
149
195
};
150
196
151
197
// / This class represents a specific pass option that contains a list of
152
- // / values of the provided data type.
198
+ // / values of the provided data type. The elements within the textual form of
199
+ // / this option are parsed assuming they are comma-separated. Delimited
200
+ // / sub-ranges within individual elements of the list may contain commas that
201
+ // / are not treated as separators for the top-level list.
153
202
template <typename DataType, typename OptionParser = OptionParser<DataType>>
154
203
class ListOption
155
204
: public llvm::cl::list<DataType, /* StorageClass=*/ bool , OptionParser>,
156
205
public OptionBase {
157
206
public:
158
207
template <typename ... Args>
159
- ListOption (PassOptions &parent, StringRef arg, Args &&... args)
208
+ ListOption (PassOptions &parent, StringRef arg, Args &&...args)
160
209
: llvm::cl::list<DataType, /* StorageClass=*/ bool , OptionParser>(
161
- arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
210
+ arg, llvm::cl::sub(parent), std::forward<Args>(args)...),
211
+ elementParser (*this ) {
162
212
assert (!this ->isPositional () && !this ->isSink () &&
163
213
" sink and positional options are not supported" );
214
+ assert (!(this ->getMiscFlags () & llvm::cl::MiscFlags::CommaSeparated) &&
215
+ " ListOption is implicitly comma separated, specifying "
216
+ " CommaSeparated is extraneous" );
164
217
parent.options .push_back (this );
165
-
166
- // Set a callback to track if this option has a value.
167
- this ->setCallback ([this ](const auto &) { this ->optHasValue = true ; });
218
+ elementParser.initialize ();
168
219
}
169
220
~ListOption () override = default ;
170
221
ListOption<DataType, OptionParser> &
@@ -174,6 +225,14 @@ class PassOptions : protected llvm::cl::SubCommand {
174
225
return *this ;
175
226
}
176
227
228
+ bool handleOccurrence (unsigned pos, StringRef argName,
229
+ StringRef arg) override {
230
+ this ->optHasValue = true ;
231
+ return failed (detail::pass_options::parseCommaSeparatedList (
232
+ *this , argName, arg, elementParser,
233
+ [&](const DataType &value) { this ->addValue (value); }));
234
+ }
235
+
177
236
// / Allow assigning from an ArrayRef.
178
237
ListOption<DataType, OptionParser> &operator =(ArrayRef<DataType> values) {
179
238
((std::vector<DataType> &)*this ).assign (values.begin (), values.end ());
@@ -211,6 +270,9 @@ class PassOptions : protected llvm::cl::SubCommand {
211
270
void copyValueFrom (const OptionBase &other) final {
212
271
*this = static_cast <const ListOption<DataType, OptionParser> &>(other);
213
272
}
273
+
274
+ // / The parser to use for parsing the list elements.
275
+ OptionParser elementParser;
214
276
};
215
277
216
278
PassOptions () = default ;
@@ -255,9 +317,7 @@ class PassOptions : protected llvm::cl::SubCommand {
255
317
// / Usage:
256
318
// /
257
319
// / struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
258
- // / ListOption<int> someListFlag{
259
- // / *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated,
260
- // / llvm::cl::desc("...")};
320
+ // / ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
261
321
// / };
262
322
template <typename T> class PassPipelineOptions : public detail ::PassOptions {
263
323
public:
@@ -278,5 +338,77 @@ struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
278
338
279
339
} // namespace mlir
280
340
341
+ // ===----------------------------------------------------------------------===//
342
+ // MLIR Options
343
+ // ===----------------------------------------------------------------------===//
344
+
345
+ namespace llvm {
346
+ namespace cl {
347
+ // ===----------------------------------------------------------------------===//
348
+ // std::vector+SmallVector
349
+
350
+ namespace detail {
351
+ template <typename VectorT, typename ElementT>
352
+ class VectorParserBase : public basic_parser_impl {
353
+ public:
354
+ VectorParserBase (Option &opt) : basic_parser_impl(opt), elementParser(opt) {}
355
+
356
+ using parser_data_type = VectorT;
357
+
358
+ bool parse (Option &opt, StringRef argName, StringRef arg,
359
+ parser_data_type &vector) {
360
+ if (!arg.consume_front (" [" ) || !arg.consume_back (" ]" )) {
361
+ return opt.error (" expected vector option to be wrapped with '[]'" ,
362
+ argName);
363
+ }
364
+
365
+ return failed (mlir::detail::pass_options::parseCommaSeparatedList (
366
+ opt, argName, arg, elementParser,
367
+ [&](const ElementT &value) { vector.push_back (value); }));
368
+ }
369
+
370
+ static void print (raw_ostream &os, const VectorT &vector) {
371
+ llvm::interleave (
372
+ vector, os,
373
+ [&](const ElementT &value) {
374
+ mlir::detail::pass_options::printOptionValue<
375
+ llvm::cl::parser<ElementT>>(os, value);
376
+ },
377
+ " ," );
378
+ }
379
+
380
+ void printOptionInfo (const Option &opt, size_t globalWidth) const {
381
+ // Add the `vector<>` qualifier to the option info.
382
+ outs () << " --" << opt.ArgStr ;
383
+ outs () << " =<vector<" << elementParser.getValueName () << " >>" ;
384
+ Option::printHelpStr (opt.HelpStr , globalWidth, getOptionWidth (opt));
385
+ }
386
+
387
+ size_t getOptionWidth (const Option &opt) const {
388
+ // Add the `vector<>` qualifier to the option width.
389
+ StringRef vectorExt (" vector<>" );
390
+ return elementParser.getOptionWidth (opt) + vectorExt.size ();
391
+ }
392
+
393
+ private:
394
+ llvm::cl::parser<ElementT> elementParser;
395
+ };
396
+ } // namespace detail
397
+
398
+ template <typename T>
399
+ class parser <std::vector<T>>
400
+ : public detail::VectorParserBase<std::vector<T>, T> {
401
+ public:
402
+ parser (Option &opt) : detail::VectorParserBase<std::vector<T>, T>(opt) {}
403
+ };
404
+ template <typename T, unsigned N>
405
+ class parser <SmallVector<T, N>>
406
+ : public detail::VectorParserBase<SmallVector<T, N>, T> {
407
+ public:
408
+ parser (Option &opt) : detail::VectorParserBase<SmallVector<T, N>, T>(opt) {}
409
+ };
410
+ } // end namespace cl
411
+ } // end namespace llvm
412
+
281
413
#endif // MLIR_PASS_PASSOPTIONS_H_
282
414
0 commit comments