Skip to content

Commit 96caf38

Browse files
committed
[mlir] Hoist out getRequestedOpDefinitions helper
Enables performing the same filtering in the op doc definition as in the op definition generator. Differential Revision: https://reviews.llvm.org/D99793
1 parent b09df24 commit 96caf38

File tree

5 files changed

+103
-51
lines changed

5 files changed

+103
-51
lines changed

mlir/tools/mlir-tblgen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_tablegen(mlir-tblgen MLIR
1515
OpDefinitionsGen.cpp
1616
OpDocGen.cpp
1717
OpFormatGen.cpp
18+
OpGenHelpers.cpp
1819
OpInterfacesGen.cpp
1920
OpPythonBindingGen.cpp
2021
PassCAPIGen.cpp

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 4 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "OpFormatGen.h"
15+
#include "OpGenHelpers.h"
1516
#include "mlir/TableGen/CodeGenHelpers.h"
1617
#include "mlir/TableGen/Format.h"
1718
#include "mlir/TableGen/GenInfo.h"
@@ -22,8 +23,7 @@
2223
#include "mlir/TableGen/SideEffects.h"
2324
#include "llvm/ADT/Sequence.h"
2425
#include "llvm/ADT/StringExtras.h"
25-
#include "llvm/Support/CommandLine.h"
26-
#include "llvm/Support/Regex.h"
26+
#include "llvm/Support/Path.h"
2727
#include "llvm/Support/Signals.h"
2828
#include "llvm/TableGen/Error.h"
2929
#include "llvm/TableGen/Record.h"
@@ -35,17 +35,6 @@ using namespace llvm;
3535
using namespace mlir;
3636
using namespace mlir::tblgen;
3737

38-
cl::OptionCategory opDefGenCat("Options for -gen-op-defs and -gen-op-decls");
39-
40-
static cl::opt<std::string> opIncFilter(
41-
"op-include-regex",
42-
cl::desc("Regex of name of op's to include (no filter if empty)"),
43-
cl::cat(opDefGenCat));
44-
static cl::opt<std::string> opExcFilter(
45-
"op-exclude-regex",
46-
cl::desc("Regex of name of op's to exclude (no filter if empty)"),
47-
cl::cat(opDefGenCat));
48-
4938
static const char *const tblgenNamePrefix = "tblgen_";
5039
static const char *const generatedArgName = "odsArg";
5140
static const char *const odsBuilder = "odsBuilder";
@@ -2472,44 +2461,10 @@ static void emitOpList(const std::vector<Record *> &defs, raw_ostream &os) {
24722461
[&os]() { os << ",\n"; });
24732462
}
24742463

2475-
static std::string getOperationName(const Record &def) {
2476-
auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
2477-
auto opName = def.getValueAsString("opName");
2478-
if (prefix.empty())
2479-
return std::string(opName);
2480-
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
2481-
}
2482-
2483-
static std::vector<Record *>
2484-
getAllDerivedDefinitions(const RecordKeeper &recordKeeper,
2485-
StringRef className) {
2486-
Record *classDef = recordKeeper.getClass(className);
2487-
if (!classDef)
2488-
PrintFatalError("ERROR: Couldn't find the `" + className + "' class!\n");
2489-
2490-
llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
2491-
std::vector<Record *> defs;
2492-
for (const auto &def : recordKeeper.getDefs()) {
2493-
if (!def.second->isSubClassOf(classDef))
2494-
continue;
2495-
// Include if no include filter or include filter matches.
2496-
if (!opIncFilter.empty() &&
2497-
!includeRegex.match(getOperationName(*def.second)))
2498-
continue;
2499-
// Unless there is an exclude filter and it matches.
2500-
if (!opExcFilter.empty() &&
2501-
excludeRegex.match(getOperationName(*def.second)))
2502-
continue;
2503-
defs.push_back(def.second.get());
2504-
}
2505-
2506-
return defs;
2507-
}
2508-
25092464
static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
25102465
emitSourceFileHeader("Op Declarations", os);
25112466

2512-
const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2467+
std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
25132468
emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
25142469

25152470
return false;
@@ -2518,7 +2473,7 @@ static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
25182473
static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
25192474
emitSourceFileHeader("Op Definitions", os);
25202475

2521-
const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2476+
std::vector<Record *> defs = getRequestedOpDefinitions(recordKeeper);
25222477
emitOpList(defs, os);
25232478
emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
25242479

mlir/tools/mlir-tblgen/OpDocGen.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "DocGenUtilities.h"
15+
#include "OpGenHelpers.h"
1516
#include "mlir/Support/IndentedOstream.h"
1617
#include "mlir/TableGen/AttrOrTypeDef.h"
1718
#include "mlir/TableGen/GenInfo.h"
@@ -141,7 +142,7 @@ static void emitOpDoc(Operator op, raw_ostream &os) {
141142
}
142143

143144
static void emitOpDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
144-
auto opDefs = recordKeeper.getAllDerivedDefinitions("Op");
145+
auto opDefs = getRequestedOpDefinitions(recordKeeper);
145146

146147
os << "<!-- Autogenerated by mlir-tblgen; don't manually edit -->\n";
147148
for (const llvm::Record *opDef : opDefs)
@@ -269,7 +270,7 @@ static void emitDialectDoc(const Dialect &dialect, ArrayRef<AttrDef> attrDefs,
269270
}
270271

271272
static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
272-
std::vector<Record *> opDefs = recordKeeper.getAllDerivedDefinitions("Op");
273+
std::vector<Record *> opDefs = getRequestedOpDefinitions(recordKeeper);
273274
std::vector<Record *> typeDefs =
274275
recordKeeper.getAllDerivedDefinitions("DialectType");
275276
std::vector<Record *> typeDefDefs =
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines helpers used in the op generators.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "OpGenHelpers.h"
14+
#include "llvm/Support/CommandLine.h"
15+
#include "llvm/Support/FormatVariadic.h"
16+
#include "llvm/Support/Regex.h"
17+
#include "llvm/TableGen/Error.h"
18+
19+
using namespace llvm;
20+
using namespace mlir;
21+
using namespace mlir::tblgen;
22+
23+
cl::OptionCategory opDefGenCat("Options for op definition generators");
24+
25+
static cl::opt<std::string> opIncFilter(
26+
"op-include-regex",
27+
cl::desc("Regex of name of op's to include (no filter if empty)"),
28+
cl::cat(opDefGenCat));
29+
static cl::opt<std::string> opExcFilter(
30+
"op-exclude-regex",
31+
cl::desc("Regex of name of op's to exclude (no filter if empty)"),
32+
cl::cat(opDefGenCat));
33+
34+
static std::string getOperationName(const Record &def) {
35+
auto prefix = def.getValueAsDef("opDialect")->getValueAsString("name");
36+
auto opName = def.getValueAsString("opName");
37+
if (prefix.empty())
38+
return std::string(opName);
39+
return std::string(llvm::formatv("{0}.{1}", prefix, opName));
40+
}
41+
42+
std::vector<Record *>
43+
mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper &recordKeeper) {
44+
Record *classDef = recordKeeper.getClass("Op");
45+
if (!classDef)
46+
PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
47+
48+
llvm::Regex includeRegex(opIncFilter), excludeRegex(opExcFilter);
49+
std::vector<Record *> defs;
50+
for (const auto &def : recordKeeper.getDefs()) {
51+
if (!def.second->isSubClassOf(classDef))
52+
continue;
53+
// Include if no include filter or include filter matches.
54+
if (!opIncFilter.empty() &&
55+
!includeRegex.match(getOperationName(*def.second)))
56+
continue;
57+
// Unless there is an exclude filter and it matches.
58+
if (!opExcFilter.empty() &&
59+
excludeRegex.match(getOperationName(*def.second)))
60+
continue;
61+
defs.push_back(def.second.get());
62+
}
63+
64+
return defs;
65+
}

mlir/tools/mlir-tblgen/OpGenHelpers.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- OpGenHelpers.h - MLIR operation generator helpers --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines helpers used in the op generators.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
14+
#define MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_
15+
16+
#include "llvm/TableGen/Record.h"
17+
#include <vector>
18+
19+
namespace mlir {
20+
namespace tblgen {
21+
22+
/// Returns all the op definitions filtered by the user. The filtering is via
23+
/// command-line option "op-include-regex" and "op-exclude-regex".
24+
std::vector<llvm::Record *>
25+
getRequestedOpDefinitions(const llvm::RecordKeeper &recordKeeper);
26+
27+
} // end namespace tblgen
28+
} // end namespace mlir
29+
30+
#endif // MLIR_TOOLS_MLIRTBLGEN_OPGENHELPERS_H_

0 commit comments

Comments
 (0)