Skip to content

Commit 6edef13

Browse files
committed
[mlir:PassOption] Rework ListOption parsing and add support for std::vector/SmallVector options
ListOption currently uses llvm::cl::list under the hood, but the usages of ListOption are generally a tad different from llvm::cl::list. This commit codifies this by making ListOption implicitly comma separated, and removes the explicit flag set for all of the current list options. The new parsing for comma separation of ListOption also adds in support for skipping over delimited sub-ranges (i.e. {}, [], (), "", ''). This more easily supports nested options that use those as part of the format, and this constraint (balanced delimiters) is already codified in the syntax of pass pipelines. See https://discourse.llvm.org/t/list-of-lists-pass-option/5950 for related discussion Differential Revision: https://reviews.llvm.org/D122879
1 parent e06ca31 commit 6edef13

22 files changed

+237
-95
lines changed

mlir/docs/PassManagement.md

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -431,9 +431,12 @@ components are integrated with the dynamic pipeline being executed.
431431
MLIR provides a builtin mechanism for passes to specify options that configure
432432
its behavior. These options are parsed at pass construction time independently
433433
for each instance of the pass. Options are defined using the `Option<>` and
434-
`ListOption<>` classes, and follow the
434+
`ListOption<>` classes, and generally follow the
435435
[LLVM command line](https://llvm.org/docs/CommandLine.html) flag definition
436-
rules. See below for a few examples:
436+
rules. One major distinction from the LLVM command line functionality is that
437+
all `ListOption`s are comma-separated, and delimited sub-ranges within individual
438+
elements of the list may contain commas that are not treated as separators for the
439+
top-level list.
437440

438441
```c++
439442
struct MyPass ... {
@@ -445,8 +448,7 @@ struct MyPass ... {
445448
/// Any parameters after the description are forwarded to llvm::cl::list and
446449
/// llvm::cl::opt respectively.
447450
Option<int> exampleOption{*this, "flag-name", llvm::cl::desc("...")};
448-
ListOption<int> exampleListOption{*this, "list-flag-name",
449-
llvm::cl::desc("...")};
451+
ListOption<int> exampleListOption{*this, "list-flag-name", llvm::cl::desc("...")};
450452
};
451453
```
452454
@@ -705,8 +707,7 @@ struct MyPass : PassWrapper<MyPass, OperationPass<ModuleOp>> {
705707
llvm::cl::desc("An example option"), llvm::cl::init(true)};
706708
ListOption<int64_t> listOption{
707709
*this, "example-list",
708-
llvm::cl::desc("An example list option"), llvm::cl::ZeroOrMore,
709-
llvm::cl::MiscFlags::CommaSeparated};
710+
llvm::cl::desc("An example list option"), llvm::cl::ZeroOrMore};
710711

711712
// Specify any statistics.
712713
Statistic statistic{this, "example-statistic", "An example statistic"};
@@ -742,8 +743,7 @@ def MyPass : Pass<"my-pass", "ModuleOp"> {
742743
Option<"option", "example-option", "bool", /*default=*/"true",
743744
"An example option">,
744745
ListOption<"listOption", "example-list", "int64_t",
745-
"An example list option",
746-
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
746+
"An example list option", "llvm::cl::ZeroOrMore">
747747
];
748748
749749
// Specify any statistics.
@@ -879,8 +879,7 @@ The `ListOption` class takes the following fields:
879879
def MyPass : Pass<"my-pass"> {
880880
let options = [
881881
ListOption<"listOption", "example-list", "int64_t",
882-
"An example list option",
883-
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
882+
"An example list option", "llvm::cl::ZeroOrMore">
884883
];
885884
}
886885
```

mlir/docs/PatternRewriter.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,12 +439,10 @@ below:
439439
440440
```tablegen
441441
ListOption<"disabledPatterns", "disable-patterns", "std::string",
442-
"Labels of patterns that should be filtered out during application",
443-
"llvm::cl::MiscFlags::CommaSeparated">,
442+
"Labels of patterns that should be filtered out during application">,
444443
ListOption<"enabledPatterns", "enable-patterns", "std::string",
445444
"Labels of patterns that should be used during application, all "
446-
"other patterns are filtered out",
447-
"llvm::cl::MiscFlags::CommaSeparated">,
445+
"other patterns are filtered out">,
448446
```
449447
450448
These options may be used to provide filtering behavior when constructing any

mlir/include/mlir/Dialect/Affine/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ def AffineVectorize : Pass<"affine-super-vectorize", "FuncOp"> {
348348
let options = [
349349
ListOption<"vectorSizes", "virtual-vector-size", "int64_t",
350350
"Specify an n-D virtual vector size for vectorization",
351-
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
351+
"llvm::cl::ZeroOrMore">,
352352
// Optionally, the fixed mapping from loop to fastest varying MemRef
353353
// dimension for all the MemRefs within a loop pattern:
354354
// the index represents the loop depth, the value represents the k^th
@@ -359,7 +359,7 @@ def AffineVectorize : Pass<"affine-super-vectorize", "FuncOp"> {
359359
"Specify a 1-D, 2-D or 3-D pattern of fastest varying memory "
360360
"dimensions to match. See defaultPatterns in Vectorize.cpp for "
361361
"a description and examples. This is used for testing purposes",
362-
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
362+
"llvm::cl::ZeroOrMore">,
363363
Option<"vectorizeReductions", "vectorize-reductions", "bool",
364364
/*default=*/"false",
365365
"Vectorize known reductions expressed via iter_args. "

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,7 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
215215
"Specify if buffers should be deallocated. For compatibility with "
216216
"core bufferization passes.">,
217217
ListOption<"dialectFilter", "dialect-filter", "std::string",
218-
"Restrict bufferization to ops from these dialects.",
219-
"llvm::cl::MiscFlags::CommaSeparated">,
218+
"Restrict bufferization to ops from these dialects.">,
220219
Option<"fullyDynamicLayoutMaps", "fully-dynamic-layout-maps", "bool",
221220
/*default=*/"true",
222221
"Generate MemRef types with dynamic offset+strides by default.">,

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def LinalgTiling : Pass<"linalg-tile", "FuncOp"> {
194194
];
195195
let options = [
196196
ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes",
197-
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
197+
"llvm::cl::ZeroOrMore">,
198198
Option<"loopType", "loop-type", "std::string", /*default=*/"\"for\"",
199199
"Specify the type of loops to generate: for, parallel">
200200
];

mlir/include/mlir/Dialect/SCF/Passes.td

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,11 @@ def SCFParallelLoopCollapsing : Pass<"scf-parallel-loop-collapsing"> {
5555
let constructor = "mlir::createParallelLoopCollapsingPass()";
5656
let options = [
5757
ListOption<"clCollapsedIndices0", "collapsed-indices-0", "unsigned",
58-
"Which loop indices to combine 0th loop index",
59-
"llvm::cl::MiscFlags::CommaSeparated">,
58+
"Which loop indices to combine 0th loop index">,
6059
ListOption<"clCollapsedIndices1", "collapsed-indices-1", "unsigned",
61-
"Which loop indices to combine into the position 1 loop index",
62-
"llvm::cl::MiscFlags::CommaSeparated">,
60+
"Which loop indices to combine into the position 1 loop index">,
6361
ListOption<"clCollapsedIndices2", "collapsed-indices-2", "unsigned",
64-
"Which loop indices to combine into the position 2 loop index",
65-
"llvm::cl::MiscFlags::CommaSeparated">,
62+
"Which loop indices to combine into the position 2 loop index">,
6663
];
6764
}
6865

@@ -77,8 +74,7 @@ def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling", "FuncOp"> {
7774
let constructor = "mlir::createParallelLoopTilingPass()";
7875
let options = [
7976
ListOption<"tileSizes", "parallel-loop-tile-sizes", "int64_t",
80-
"Factors to tile parallel loops by",
81-
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
77+
"Factors to tile parallel loops by", "llvm::cl::ZeroOrMore">,
8278
Option<"noMinMaxBounds", "no-min-max-bounds", "bool",
8379
/*default=*/"false",
8480
"Perform tiling with fixed upper bound with inbound check "

mlir/include/mlir/Pass/PassOptions.h

Lines changed: 146 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,63 @@
1616

1717
#include "mlir/Support/LLVM.h"
1818
#include "mlir/Support/LogicalResult.h"
19+
#include "llvm/ADT/FunctionExtras.h"
1920
#include "llvm/ADT/StringRef.h"
2021
#include "llvm/Support/CommandLine.h"
2122
#include "llvm/Support/Compiler.h"
2223
#include <memory>
2324

2425
namespace mlir {
2526
namespace detail {
27+
namespace pass_options {
28+
/// Parse a string containing a list of comma-delimited elements, invoking the
29+
/// given parser for each sub-element and passing them to the provided
30+
/// element-append functor.
31+
LogicalResult
32+
parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
33+
StringRef optionStr,
34+
function_ref<LogicalResult(StringRef)> elementParseFn);
35+
template <typename ElementParser, typename ElementAppendFn>
36+
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName,
37+
StringRef optionStr,
38+
ElementParser &elementParser,
39+
ElementAppendFn &&appendFn) {
40+
return parseCommaSeparatedList(
41+
opt, argName, optionStr, [&](StringRef valueStr) {
42+
typename ElementParser::parser_data_type value = {};
43+
if (elementParser.parse(opt, argName, valueStr, value))
44+
return failure();
45+
appendFn(value);
46+
return success();
47+
});
48+
}
49+
50+
/// Trait used to detect if a type has a operator<< method.
51+
template <typename T>
52+
using has_stream_operator_trait =
53+
decltype(std::declval<raw_ostream &>() << std::declval<T>());
54+
template <typename T>
55+
using has_stream_operator = llvm::is_detected<has_stream_operator_trait, T>;
56+
57+
/// Utility methods for printing option values.
58+
template <typename ParserT>
59+
static void printOptionValue(raw_ostream &os, const bool &value) {
60+
os << (value ? StringRef("true") : StringRef("false"));
61+
}
62+
template <typename ParserT, typename DataT>
63+
static std::enable_if_t<has_stream_operator<DataT>::value>
64+
printOptionValue(raw_ostream &os, const DataT &value) {
65+
os << value;
66+
}
67+
template <typename ParserT, typename DataT>
68+
static std::enable_if_t<!has_stream_operator<DataT>::value>
69+
printOptionValue(raw_ostream &os, const DataT &value) {
70+
// If the value can't be streamed, fallback to checking for a print in the
71+
// parser.
72+
ParserT::print(os, value);
73+
}
74+
} // namespace pass_options
75+
2676
/// Base container class and manager for all pass options.
2777
class PassOptions : protected llvm::cl::SubCommand {
2878
private:
@@ -85,11 +135,7 @@ class PassOptions : protected llvm::cl::SubCommand {
85135
}
86136
template <typename DataT, typename ParserT>
87137
static void printValue(raw_ostream &os, ParserT &parser, const DataT &value) {
88-
os << value;
89-
}
90-
template <typename ParserT>
91-
static void printValue(raw_ostream &os, ParserT &parser, const bool &value) {
92-
os << (value ? StringRef("true") : StringRef("false"));
138+
detail::pass_options::printOptionValue<ParserT>(os, value);
93139
}
94140

95141
public:
@@ -149,22 +195,27 @@ class PassOptions : protected llvm::cl::SubCommand {
149195
};
150196

151197
/// This class represents a specific pass option that contains a list of
152-
/// values of the provided data type.
198+
/// values of the provided data type. The elements within the textual form of
199+
/// this option are parsed assuming they are comma-separated. Delimited
200+
/// sub-ranges within individual elements of the list may contain commas that
201+
/// are not treated as separators for the top-level list.
153202
template <typename DataType, typename OptionParser = OptionParser<DataType>>
154203
class ListOption
155204
: public llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>,
156205
public OptionBase {
157206
public:
158207
template <typename... Args>
159-
ListOption(PassOptions &parent, StringRef arg, Args &&... args)
208+
ListOption(PassOptions &parent, StringRef arg, Args &&...args)
160209
: llvm::cl::list<DataType, /*StorageClass=*/bool, OptionParser>(
161-
arg, llvm::cl::sub(parent), std::forward<Args>(args)...) {
210+
arg, llvm::cl::sub(parent), std::forward<Args>(args)...),
211+
elementParser(*this) {
162212
assert(!this->isPositional() && !this->isSink() &&
163213
"sink and positional options are not supported");
214+
assert(!(this->getMiscFlags() & llvm::cl::MiscFlags::CommaSeparated) &&
215+
"ListOption is implicitly comma separated, specifying "
216+
"CommaSeparated is extraneous");
164217
parent.options.push_back(this);
165-
166-
// Set a callback to track if this option has a value.
167-
this->setCallback([this](const auto &) { this->optHasValue = true; });
218+
elementParser.initialize();
168219
}
169220
~ListOption() override = default;
170221
ListOption<DataType, OptionParser> &
@@ -174,6 +225,14 @@ class PassOptions : protected llvm::cl::SubCommand {
174225
return *this;
175226
}
176227

228+
bool handleOccurrence(unsigned pos, StringRef argName,
229+
StringRef arg) override {
230+
this->optHasValue = true;
231+
return failed(detail::pass_options::parseCommaSeparatedList(
232+
*this, argName, arg, elementParser,
233+
[&](const DataType &value) { this->addValue(value); }));
234+
}
235+
177236
/// Allow assigning from an ArrayRef.
178237
ListOption<DataType, OptionParser> &operator=(ArrayRef<DataType> values) {
179238
((std::vector<DataType> &)*this).assign(values.begin(), values.end());
@@ -211,6 +270,9 @@ class PassOptions : protected llvm::cl::SubCommand {
211270
void copyValueFrom(const OptionBase &other) final {
212271
*this = static_cast<const ListOption<DataType, OptionParser> &>(other);
213272
}
273+
274+
/// The parser to use for parsing the list elements.
275+
OptionParser elementParser;
214276
};
215277

216278
PassOptions() = default;
@@ -255,9 +317,7 @@ class PassOptions : protected llvm::cl::SubCommand {
255317
/// Usage:
256318
///
257319
/// struct MyPipelineOptions : PassPipelineOptions<MyPassOptions> {
258-
/// ListOption<int> someListFlag{
259-
/// *this, "flag-name", llvm::cl::MiscFlags::CommaSeparated,
260-
/// llvm::cl::desc("...")};
320+
/// ListOption<int> someListFlag{*this, "flag-name", llvm::cl::desc("...")};
261321
/// };
262322
template <typename T> class PassPipelineOptions : public detail::PassOptions {
263323
public:
@@ -278,5 +338,77 @@ struct EmptyPipelineOptions : public PassPipelineOptions<EmptyPipelineOptions> {
278338

279339
} // namespace mlir
280340

341+
//===----------------------------------------------------------------------===//
342+
// MLIR Options
343+
//===----------------------------------------------------------------------===//
344+
345+
namespace llvm {
346+
namespace cl {
347+
//===----------------------------------------------------------------------===//
348+
// std::vector+SmallVector
349+
350+
namespace detail {
351+
template <typename VectorT, typename ElementT>
352+
class VectorParserBase : public basic_parser_impl {
353+
public:
354+
VectorParserBase(Option &opt) : basic_parser_impl(opt), elementParser(opt) {}
355+
356+
using parser_data_type = VectorT;
357+
358+
bool parse(Option &opt, StringRef argName, StringRef arg,
359+
parser_data_type &vector) {
360+
if (!arg.consume_front("[") || !arg.consume_back("]")) {
361+
return opt.error("expected vector option to be wrapped with '[]'",
362+
argName);
363+
}
364+
365+
return failed(mlir::detail::pass_options::parseCommaSeparatedList(
366+
opt, argName, arg, elementParser,
367+
[&](const ElementT &value) { vector.push_back(value); }));
368+
}
369+
370+
static void print(raw_ostream &os, const VectorT &vector) {
371+
llvm::interleave(
372+
vector, os,
373+
[&](const ElementT &value) {
374+
mlir::detail::pass_options::printOptionValue<
375+
llvm::cl::parser<ElementT>>(os, value);
376+
},
377+
",");
378+
}
379+
380+
void printOptionInfo(const Option &opt, size_t globalWidth) const {
381+
// Add the `vector<>` qualifier to the option info.
382+
outs() << " --" << opt.ArgStr;
383+
outs() << "=<vector<" << elementParser.getValueName() << ">>";
384+
Option::printHelpStr(opt.HelpStr, globalWidth, getOptionWidth(opt));
385+
}
386+
387+
size_t getOptionWidth(const Option &opt) const {
388+
// Add the `vector<>` qualifier to the option width.
389+
StringRef vectorExt("vector<>");
390+
return elementParser.getOptionWidth(opt) + vectorExt.size();
391+
}
392+
393+
private:
394+
llvm::cl::parser<ElementT> elementParser;
395+
};
396+
} // namespace detail
397+
398+
template <typename T>
399+
class parser<std::vector<T>>
400+
: public detail::VectorParserBase<std::vector<T>, T> {
401+
public:
402+
parser(Option &opt) : detail::VectorParserBase<std::vector<T>, T>(opt) {}
403+
};
404+
template <typename T, unsigned N>
405+
class parser<SmallVector<T, N>>
406+
: public detail::VectorParserBase<SmallVector<T, N>, T> {
407+
public:
408+
parser(Option &opt) : detail::VectorParserBase<SmallVector<T, N>, T>(opt) {}
409+
};
410+
} // end namespace cl
411+
} // end namespace llvm
412+
281413
#endif // MLIR_PASS_PASSOPTIONS_H_
282414

mlir/include/mlir/Reducer/Passes.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ def CommonReductionPassOptions {
2020
Option<"testerName", "test", "std::string", /* default */"",
2121
"The location of the tester which tests the file interestingness">,
2222
ListOption<"testerArgs", "test-arg", "std::string",
23-
"arguments of the tester",
24-
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
23+
"arguments of the tester", "llvm::cl::ZeroOrMore">,
2524
];
2625
}
2726

mlir/include/mlir/Rewrite/PassUtil.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,10 @@ def RewritePassUtils {
2424
// created.
2525
ListOption<"disabledPatterns", "disable-patterns", "std::string",
2626
"Labels of patterns that should be filtered out during"
27-
" application",
28-
"llvm::cl::MiscFlags::CommaSeparated">,
27+
" application">,
2928
ListOption<"enabledPatterns", "enable-patterns", "std::string",
3029
"Labels of patterns that should be used during"
31-
" application, all other patterns are filtered out",
32-
"llvm::cl::MiscFlags::CommaSeparated">,
30+
" application, all other patterns are filtered out">,
3331
];
3432
}
3533

0 commit comments

Comments
 (0)