Skip to content

Commit b299ec1

Browse files
committed
Expose callbacks for encoding of types/attributes
[mlir] Expose a mechanism to provide a callback for encoding types and attributes in MLIR bytecode. Two callbacks are exposed, respectively, to the BytecodeWriterConfig and to the ParserConfig. At bytecode parsing/printing, clients have the ability to specify a callback to be used to optionally read/write the encoding. On failure, fallback path will execute the default parsers and printers for the dialect. Testing shows how to leverage this functionality to support back-deployment and backward-compatibility usecases when roundtripping to bytecode a client dialect with type/attributes dependencies on upstream. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D153383
1 parent bb65caf commit b299ec1

20 files changed

+950
-152
lines changed

mlir/include/mlir/Bytecode/BytecodeImplementation.h

Lines changed: 23 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424
#include "llvm/ADT/Twine.h"
2525

2626
namespace mlir {
27+
//===--------------------------------------------------------------------===//
28+
// Dialect Version Interface.
29+
//===--------------------------------------------------------------------===//
30+
31+
/// This class is used to represent the version of a dialect, for the purpose
32+
/// of polymorphic destruction.
33+
class DialectVersion {
34+
public:
35+
virtual ~DialectVersion() = default;
36+
};
37+
2738
//===----------------------------------------------------------------------===//
2839
// DialectBytecodeReader
2940
//===----------------------------------------------------------------------===//
@@ -38,7 +49,14 @@ class DialectBytecodeReader {
3849
virtual ~DialectBytecodeReader() = default;
3950

4051
/// Emit an error to the reader.
41-
virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
52+
virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0;
53+
54+
/// Retrieve the dialect version by name if available.
55+
virtual FailureOr<const DialectVersion *>
56+
getDialectVersion(StringRef dialectName) const = 0;
57+
58+
/// Retrieve the context associated to the reader.
59+
virtual MLIRContext *getContext() const = 0;
4260

4361
/// Return the bytecode version being read.
4462
virtual uint64_t getBytecodeVersion() const = 0;
@@ -384,17 +402,6 @@ class DialectBytecodeWriter {
384402
virtual int64_t getBytecodeVersion() const = 0;
385403
};
386404

387-
//===--------------------------------------------------------------------===//
388-
// Dialect Version Interface.
389-
//===--------------------------------------------------------------------===//
390-
391-
/// This class is used to represent the version of a dialect, for the purpose
392-
/// of polymorphic destruction.
393-
class DialectVersion {
394-
public:
395-
virtual ~DialectVersion() = default;
396-
};
397-
398405
//===----------------------------------------------------------------------===//
399406
// BytecodeDialectInterface
400407
//===----------------------------------------------------------------------===//
@@ -409,47 +416,23 @@ class BytecodeDialectInterface
409416
//===--------------------------------------------------------------------===//
410417

411418
/// Read an attribute belonging to this dialect from the given reader. This
412-
/// method should return null in the case of failure.
419+
/// method should return null in the case of failure. Optionally, the dialect
420+
/// version can be accessed through the reader.
413421
virtual Attribute readAttribute(DialectBytecodeReader &reader) const {
414422
reader.emitError() << "dialect " << getDialect()->getNamespace()
415423
<< " does not support reading attributes from bytecode";
416424
return Attribute();
417425
}
418426

419-
/// Read a versioned attribute encoding belonging to this dialect from the
420-
/// given reader. This method should return null in the case of failure, and
421-
/// falls back to the non-versioned reader in case the dialect implements
422-
/// versioning but it does not support versioned custom encodings for the
423-
/// attributes.
424-
virtual Attribute readAttribute(DialectBytecodeReader &reader,
425-
const DialectVersion &version) const {
426-
reader.emitError()
427-
<< "dialect " << getDialect()->getNamespace()
428-
<< " does not support reading versioned attributes from bytecode";
429-
return Attribute();
430-
}
431-
432427
/// Read a type belonging to this dialect from the given reader. This method
433-
/// should return null in the case of failure.
428+
/// should return null in the case of failure. Optionally, the dialect version
429+
/// can be accessed thorugh the reader.
434430
virtual Type readType(DialectBytecodeReader &reader) const {
435431
reader.emitError() << "dialect " << getDialect()->getNamespace()
436432
<< " does not support reading types from bytecode";
437433
return Type();
438434
}
439435

440-
/// Read a versioned type encoding belonging to this dialect from the given
441-
/// reader. This method should return null in the case of failure, and
442-
/// falls back to the non-versioned reader in case the dialect implements
443-
/// versioning but it does not support versioned custom encodings for the
444-
/// types.
445-
virtual Type readType(DialectBytecodeReader &reader,
446-
const DialectVersion &version) const {
447-
reader.emitError()
448-
<< "dialect " << getDialect()->getNamespace()
449-
<< " does not support reading versioned types from bytecode";
450-
return Type();
451-
}
452-
453436
//===--------------------------------------------------------------------===//
454437
// Writing
455438
//===--------------------------------------------------------------------===//

mlir/include/mlir/Bytecode/BytecodeReader.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ class SourceMgr;
2525
} // namespace llvm
2626

2727
namespace mlir {
28-
2928
/// The BytecodeReader allows to load MLIR bytecode files, while keeping the
3029
/// state explicitly available in order to support lazy loading.
3130
/// The `finalize` method must be called before destruction.
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
//===- BytecodeReader.h - MLIR Bytecode Reader ------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header defines interfaces to read MLIR bytecode files/streams.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_BYTECODE_BYTECODEREADERCONFIG_H
14+
#define MLIR_BYTECODE_BYTECODEREADERCONFIG_H
15+
16+
#include "mlir/Support/LLVM.h"
17+
#include "mlir/Support/LogicalResult.h"
18+
#include "llvm/ADT/ArrayRef.h"
19+
#include "llvm/ADT/SmallVector.h"
20+
#include "llvm/ADT/StringRef.h"
21+
22+
namespace mlir {
23+
class Attribute;
24+
class DialectBytecodeReader;
25+
class Type;
26+
27+
/// A class to interact with the attributes and types parser when parsing MLIR
28+
/// bytecode.
29+
template <class T>
30+
class AttrTypeBytecodeReader {
31+
public:
32+
AttrTypeBytecodeReader() = default;
33+
virtual ~AttrTypeBytecodeReader() = default;
34+
35+
virtual LogicalResult read(DialectBytecodeReader &reader,
36+
StringRef dialectName, T &entry) = 0;
37+
38+
/// Return an Attribute/Type printer implemented via the given callable, whose
39+
/// form should match that of the `parse` function above.
40+
template <typename CallableT,
41+
std::enable_if_t<
42+
std::is_convertible_v<
43+
CallableT, std::function<LogicalResult(
44+
DialectBytecodeReader &, StringRef, T &)>>,
45+
bool> = true>
46+
static std::unique_ptr<AttrTypeBytecodeReader<T>>
47+
fromCallable(CallableT &&readFn) {
48+
struct Processor : public AttrTypeBytecodeReader<T> {
49+
Processor(CallableT &&readFn)
50+
: AttrTypeBytecodeReader(), readFn(std::move(readFn)) {}
51+
LogicalResult read(DialectBytecodeReader &reader, StringRef dialectName,
52+
T &entry) override {
53+
return readFn(reader, dialectName, entry);
54+
}
55+
56+
std::decay_t<CallableT> readFn;
57+
};
58+
return std::make_unique<Processor>(std::forward<CallableT>(readFn));
59+
}
60+
};
61+
62+
//===----------------------------------------------------------------------===//
63+
// BytecodeReaderConfig
64+
//===----------------------------------------------------------------------===//
65+
66+
/// A class containing bytecode-specific configurations of the `ParserConfig`.
67+
class BytecodeReaderConfig {
68+
public:
69+
BytecodeReaderConfig() = default;
70+
71+
/// Returns the callbacks available to the parser.
72+
ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
73+
getAttributeCallbacks() const {
74+
return attributeBytecodeParsers;
75+
}
76+
ArrayRef<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
77+
getTypeCallbacks() const {
78+
return typeBytecodeParsers;
79+
}
80+
81+
/// Attach a custom bytecode parser callback to the configuration for parsing
82+
/// of custom type/attributes encodings.
83+
void attachAttributeCallback(
84+
std::unique_ptr<AttrTypeBytecodeReader<Attribute>> parser) {
85+
attributeBytecodeParsers.emplace_back(std::move(parser));
86+
}
87+
void
88+
attachTypeCallback(std::unique_ptr<AttrTypeBytecodeReader<Type>> parser) {
89+
typeBytecodeParsers.emplace_back(std::move(parser));
90+
}
91+
92+
/// Attach a custom bytecode parser callback to the configuration for parsing
93+
/// of custom type/attributes encodings.
94+
template <typename CallableT>
95+
std::enable_if_t<std::is_convertible_v<
96+
CallableT, std::function<LogicalResult(DialectBytecodeReader &, StringRef,
97+
Attribute &)>>>
98+
attachAttributeCallback(CallableT &&parserFn) {
99+
attachAttributeCallback(AttrTypeBytecodeReader<Attribute>::fromCallable(
100+
std::forward<CallableT>(parserFn)));
101+
}
102+
template <typename CallableT>
103+
std::enable_if_t<std::is_convertible_v<
104+
CallableT,
105+
std::function<LogicalResult(DialectBytecodeReader &, StringRef, Type &)>>>
106+
attachTypeCallback(CallableT &&parserFn) {
107+
attachTypeCallback(AttrTypeBytecodeReader<Type>::fromCallable(
108+
std::forward<CallableT>(parserFn)));
109+
}
110+
111+
private:
112+
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Attribute>>>
113+
attributeBytecodeParsers;
114+
llvm::SmallVector<std::unique_ptr<AttrTypeBytecodeReader<Type>>>
115+
typeBytecodeParsers;
116+
};
117+
118+
} // namespace mlir
119+
120+
#endif // MLIR_BYTECODE_BYTECODEREADERCONFIG_H

mlir/include/mlir/Bytecode/BytecodeWriter.h

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,55 @@
1717

1818
namespace mlir {
1919
class Operation;
20+
class DialectBytecodeWriter;
21+
22+
/// A class to interact with the attributes and types printer when emitting MLIR
23+
/// bytecode.
24+
template <class T>
25+
class AttrTypeBytecodeWriter {
26+
public:
27+
AttrTypeBytecodeWriter() = default;
28+
virtual ~AttrTypeBytecodeWriter() = default;
29+
30+
/// Callback writer API used in IRNumbering, where groups are created and
31+
/// type/attribute components are numbered. At this stage, writer is expected
32+
/// to be a `NumberingDialectWriter`.
33+
virtual LogicalResult write(T entry, std::optional<StringRef> &name,
34+
DialectBytecodeWriter &writer) = 0;
35+
36+
/// Callback writer API used in BytecodeWriter, where groups are created and
37+
/// type/attribute components are numbered. Here, DialectBytecodeWriter is
38+
/// expected to be an actual writer. The optional stringref specified by
39+
/// the user is ignored, since the group was already specified when numbering
40+
/// the IR.
41+
LogicalResult write(T entry, DialectBytecodeWriter &writer) {
42+
std::optional<StringRef> dummy;
43+
return write(entry, dummy, writer);
44+
}
45+
46+
/// Return an Attribute/Type printer implemented via the given callable, whose
47+
/// form should match that of the `write` function above.
48+
template <typename CallableT,
49+
std::enable_if_t<std::is_convertible_v<
50+
CallableT, std::function<LogicalResult(
51+
T, std::optional<StringRef> &,
52+
DialectBytecodeWriter &)>>,
53+
bool> = true>
54+
static std::unique_ptr<AttrTypeBytecodeWriter<T>>
55+
fromCallable(CallableT &&writeFn) {
56+
struct Processor : public AttrTypeBytecodeWriter<T> {
57+
Processor(CallableT &&writeFn)
58+
: AttrTypeBytecodeWriter(), writeFn(std::move(writeFn)) {}
59+
LogicalResult write(T entry, std::optional<StringRef> &name,
60+
DialectBytecodeWriter &writer) override {
61+
return writeFn(entry, name, writer);
62+
}
63+
64+
std::decay_t<CallableT> writeFn;
65+
};
66+
return std::make_unique<Processor>(std::forward<CallableT>(writeFn));
67+
}
68+
};
2069

2170
/// This class contains the configuration used for the bytecode writer. It
2271
/// controls various aspects of bytecode generation, and contains all of the
@@ -48,6 +97,43 @@ class BytecodeWriterConfig {
4897
/// Get the set desired bytecode version to emit.
4998
int64_t getDesiredBytecodeVersion() const;
5099

100+
//===--------------------------------------------------------------------===//
101+
// Types and Attributes encoding
102+
//===--------------------------------------------------------------------===//
103+
104+
/// Retrieve the callbacks.
105+
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Attribute>>>
106+
getAttributeWriterCallbacks() const;
107+
ArrayRef<std::unique_ptr<AttrTypeBytecodeWriter<Type>>>
108+
getTypeWriterCallbacks() const;
109+
110+
/// Attach a custom bytecode printer callback to the configuration for the
111+
/// emission of custom type/attributes encodings.
112+
void attachAttributeCallback(
113+
std::unique_ptr<AttrTypeBytecodeWriter<Attribute>> callback);
114+
void
115+
attachTypeCallback(std::unique_ptr<AttrTypeBytecodeWriter<Type>> callback);
116+
117+
/// Attach a custom bytecode printer callback to the configuration for the
118+
/// emission of custom type/attributes encodings.
119+
template <typename CallableT>
120+
std::enable_if_t<std::is_convertible_v<
121+
CallableT,
122+
std::function<LogicalResult(Attribute, std::optional<StringRef> &,
123+
DialectBytecodeWriter &)>>>
124+
attachAttributeCallback(CallableT &&emitFn) {
125+
attachAttributeCallback(AttrTypeBytecodeWriter<Attribute>::fromCallable(
126+
std::forward<CallableT>(emitFn)));
127+
}
128+
template <typename CallableT>
129+
std::enable_if_t<std::is_convertible_v<
130+
CallableT, std::function<LogicalResult(Type, std::optional<StringRef> &,
131+
DialectBytecodeWriter &)>>>
132+
attachTypeCallback(CallableT &&emitFn) {
133+
attachTypeCallback(AttrTypeBytecodeWriter<Type>::fromCallable(
134+
std::forward<CallableT>(emitFn)));
135+
}
136+
51137
//===--------------------------------------------------------------------===//
52138
// Resources
53139
//===--------------------------------------------------------------------===//

mlir/include/mlir/IR/AsmState.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_IR_ASMSTATE_H_
1515
#define MLIR_IR_ASMSTATE_H_
1616

17+
#include "mlir/Bytecode/BytecodeReaderConfig.h"
1718
#include "mlir/IR/OperationSupport.h"
1819
#include "mlir/Support/LLVM.h"
1920
#include "llvm/ADT/MapVector.h"
@@ -475,6 +476,11 @@ class ParserConfig {
475476
/// Returns if the parser should verify the IR after parsing.
476477
bool shouldVerifyAfterParse() const { return verifyAfterParse; }
477478

479+
/// Returns the parsing configurations associated to the bytecode read.
480+
BytecodeReaderConfig &getBytecodeReaderConfig() const {
481+
return const_cast<BytecodeReaderConfig &>(bytecodeReaderConfig);
482+
}
483+
478484
/// Return the resource parser registered to the given name, or nullptr if no
479485
/// parser with `name` is registered.
480486
AsmResourceParser *getResourceParser(StringRef name) const {
@@ -509,6 +515,7 @@ class ParserConfig {
509515
bool verifyAfterParse;
510516
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
511517
FallbackAsmResourceMap *fallbackResourceMap;
518+
BytecodeReaderConfig bytecodeReaderConfig;
512519
};
513520

514521
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)