Skip to content

Commit 397215c

Browse files
committed
[mlir][ods] Support dialect specific content emission via hooks
Thus far we can only generate the same set of methods even for operations in different dialects. This is problematic for dialects that want to generate additional operation class methods programmatically, e.g., a special builder method or attribute getter method. Apparently we cannot update the OpDefinitionsGen backend every time when such a need arises. So this CL introduces a hook into the OpDefinitionsGen backend to allow dialects to emit additional methods and traits to operation classes. Differential Revision: https://reviews.llvm.org/D72514
1 parent ca4a55f commit 397215c

File tree

3 files changed

+81
-7
lines changed

3 files changed

+81
-7
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
//===- ODSDialectHook.h - Dialect customization hooks into ODS --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines ODS customization hooks for dialects to programmatically
10+
// emit dialect specific contents in ODS C++ code emission.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_TABLEGEN_ODSDIALECTHOOK_H_
15+
#define MLIR_TABLEGEN_ODSDIALECTHOOK_H_
16+
17+
#include <functional>
18+
19+
namespace llvm {
20+
class StringRef;
21+
}
22+
23+
namespace mlir {
24+
namespace tblgen {
25+
class Operator;
26+
class OpClass;
27+
28+
// The emission function for dialect specific content. It takes in an Operator
29+
// and updates the OpClass accordingly.
30+
using DialectEmitFunction =
31+
std::function<void(const Operator &srcOp, OpClass &emitClass)>;
32+
33+
// ODSDialectHookRegistration provides a global initializer that registers a
34+
// dialect specific content emission function.
35+
struct ODSDialectHookRegistration {
36+
ODSDialectHookRegistration(llvm::StringRef dialectName,
37+
DialectEmitFunction emitFn);
38+
};
39+
} // namespace tblgen
40+
} // namespace mlir
41+
42+
#endif // MLIR_TABLEGEN_ODSDIALECTHOOK_H_

mlir/include/mlir/TableGen/Operator.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class Operator {
4646
// Returns this op's dialect name.
4747
StringRef getDialectName() const;
4848

49+
// Returns the dialect of the op.
50+
const Dialect &getDialect() const { return dialect; }
51+
4952
// Returns the operation name. The name will follow the "<dialect>.<op-name>"
5053
// format if its dialect name is not empty.
5154
std::string getOperationName() const;
@@ -156,14 +159,8 @@ class Operator {
156159
StringRef getExtraClassDeclaration() const;
157160

158161
// Returns the Tablegen definition this operator was constructed from.
159-
// TODO(antiagainst,zinenko): do not expose the TableGen record, this is a
160-
// temporary solution to OpEmitter requiring a Record because Operator does
161-
// not provide enough methods.
162162
const llvm::Record &getDef() const;
163163

164-
// Returns the dialect of the op.
165-
const Dialect &getDialect() const { return dialect; }
166-
167164
// Prints the contents in this operator to the given `os`. This is used for
168165
// debugging purposes.
169166
void print(llvm::raw_ostream &os) const;

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,49 @@
1414
#include "mlir/Support/STLExtras.h"
1515
#include "mlir/TableGen/Format.h"
1616
#include "mlir/TableGen/GenInfo.h"
17+
#include "mlir/TableGen/ODSDialectHook.h"
1718
#include "mlir/TableGen/OpClass.h"
1819
#include "mlir/TableGen/OpInterfaces.h"
1920
#include "mlir/TableGen/OpTrait.h"
2021
#include "mlir/TableGen/Operator.h"
2122
#include "llvm/ADT/StringExtras.h"
23+
#include "llvm/Support/ManagedStatic.h"
2224
#include "llvm/Support/Signals.h"
2325
#include "llvm/TableGen/Error.h"
2426
#include "llvm/TableGen/Record.h"
2527
#include "llvm/TableGen/TableGenBackend.h"
2628

2729
#define DEBUG_TYPE "mlir-tblgen-opdefgen"
2830

29-
using namespace llvm;
3031
using namespace mlir;
3132
using namespace mlir::tblgen;
3233

34+
using llvm::CodeInit;
35+
using llvm::DefInit;
36+
using llvm::formatv;
37+
using llvm::Init;
38+
using llvm::ListInit;
39+
using llvm::Record;
40+
using llvm::RecordKeeper;
41+
using llvm::StringInit;
42+
43+
//===----------------------------------------------------------------------===//
44+
// Dialect hook registration
45+
//===----------------------------------------------------------------------===//
46+
47+
static llvm::ManagedStatic<llvm::StringMap<DialectEmitFunction>> dialectHooks;
48+
49+
ODSDialectHookRegistration::ODSDialectHookRegistration(
50+
StringRef dialectName, DialectEmitFunction emitFn) {
51+
bool inserted = dialectHooks->try_emplace(dialectName, emitFn).second;
52+
assert(inserted && "Multiple ODS hooks for the same dialect!");
53+
(void)inserted;
54+
}
55+
56+
//===----------------------------------------------------------------------===//
57+
// Static string definitions
58+
//===----------------------------------------------------------------------===//
59+
3360
static const char *const tblgenNamePrefix = "tblgen_";
3461
static const char *const generatedArgName = "tblgen_arg";
3562
static const char *const builderOpState = "tblgen_state";
@@ -279,6 +306,7 @@ OpEmitter::OpEmitter(const Operator &op)
279306
verifyCtx.withOp("(*this->getOperation())");
280307

281308
genTraits();
309+
282310
// Generate C++ code for various op methods. The order here determines the
283311
// methods in the generated file.
284312
genOpAsmInterface();
@@ -294,6 +322,13 @@ OpEmitter::OpEmitter(const Operator &op)
294322
genCanonicalizerDecls();
295323
genFolderDecls();
296324
genOpInterfaceMethods();
325+
326+
// If a dialect hook is registered for this op's dialect, emit dialect
327+
// specific content.
328+
auto dialectHookIt = dialectHooks->find(op.getDialectName());
329+
if (dialectHookIt != dialectHooks->end()) {
330+
dialectHookIt->second(op, opClass);
331+
}
297332
}
298333

299334
void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {

0 commit comments

Comments
 (0)