Skip to content

Commit f3798ad

Browse files
committed
Static verifier for type/attribute in DRR
Generate static function for matching the type/attribute to reduce the memory footprint. Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D110199
1 parent ca47447 commit f3798ad

File tree

6 files changed

+208
-74
lines changed

6 files changed

+208
-74
lines changed

mlir/include/mlir/TableGen/CodeGenHelpers.h

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Support/IndentedOstream.h"
1717
#include "mlir/TableGen/Dialect.h"
18+
#include "mlir/TableGen/Format.h"
1819
#include "llvm/ADT/DenseMap.h"
1920
#include "llvm/ADT/StringExtras.h"
2021
#include "llvm/ADT/StringRef.h"
@@ -91,8 +92,7 @@ class NamespaceEmitter {
9192
///
9293
class StaticVerifierFunctionEmitter {
9394
public:
94-
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
95-
raw_ostream &os);
95+
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records);
9696

9797
/// Emit the static verifier functions for `llvm::Record`s. The
9898
/// `signatureFormat` describes the required arguments and it must have a
@@ -112,30 +112,40 @@ class StaticVerifierFunctionEmitter {
112112
///
113113
/// `typeArgName` is used to identify the argument that needs to check its
114114
/// type. The constraint template will replace `$_self` with it.
115-
void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
116-
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
117-
bool emitDecl);
115+
116+
/// This is the helper to generate the constraint functions from op
117+
/// definitions.
118+
void emitConstraintMethodsInNamespace(StringRef signatureFormat,
119+
StringRef errorHandlerFormat,
120+
StringRef cppNamespace,
121+
ArrayRef<const void *> constraints,
122+
raw_ostream &rawOs, bool emitDecl);
123+
124+
/// Emit the static functions for the giving type constraints.
125+
void emitConstraintMethods(StringRef signatureFormat,
126+
StringRef errorHandlerFormat,
127+
ArrayRef<const void *> constraints,
128+
raw_ostream &rawOs, bool emitDecl);
118129

119130
/// Get the name of the local function used for the given type constraint.
120131
/// These functions are used for operand and result constraints and have the
121132
/// form:
122133
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
123134
/// unsigned valueGroupStartIndex);
124-
StringRef getTypeConstraintFn(const Constraint &constraint) const;
135+
StringRef getConstraintFn(const Constraint &constraint) const;
136+
137+
/// The setter to set `self` in format context.
138+
StaticVerifierFunctionEmitter &setSelf(StringRef str);
139+
140+
/// The setter to set `builder` in format context.
141+
StaticVerifierFunctionEmitter &setBuilder(StringRef str);
125142

126143
private:
127144
/// Returns a unique name to use when generating local methods.
128145
static std::string getUniqueName(const llvm::RecordKeeper &records);
129146

130-
/// Emit local methods for the type constraints used within the provided op
131-
/// definitions.
132-
void emitTypeConstraintMethods(StringRef signatureFormat,
133-
StringRef errorHandlerFormat,
134-
StringRef typeArgName,
135-
ArrayRef<llvm::Record *> opDefs,
136-
bool emitDecl);
137-
138-
raw_indented_ostream os;
147+
/// The format context used for building the verifier function.
148+
FmtContext fctx;
139149

140150
/// A unique label for the file currently being generated. This is used to
141151
/// ensure that the local functions have a unique name.

mlir/include/mlir/TableGen/Pattern.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ class DagLeaf {
113113
void print(raw_ostream &os) const;
114114

115115
private:
116+
friend llvm::DenseMapInfo<DagLeaf>;
117+
const void *getAsOpaquePointer() const { return def; }
118+
116119
// Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and
117120
// also a subclass of the given `superclass`.
118121
bool isSubClassOf(StringRef superclass) const;
@@ -523,6 +526,24 @@ struct DenseMapInfo<mlir::tblgen::DagNode> {
523526
return lhs.node == rhs.node;
524527
}
525528
};
529+
530+
template <>
531+
struct DenseMapInfo<mlir::tblgen::DagLeaf> {
532+
static mlir::tblgen::DagLeaf getEmptyKey() {
533+
return mlir::tblgen::DagLeaf(
534+
llvm::DenseMapInfo<llvm::Init *>::getEmptyKey());
535+
}
536+
static mlir::tblgen::DagLeaf getTombstoneKey() {
537+
return mlir::tblgen::DagLeaf(
538+
llvm::DenseMapInfo<llvm::Init *>::getTombstoneKey());
539+
}
540+
static unsigned getHashValue(mlir::tblgen::DagLeaf leaf) {
541+
return llvm::hash_value(leaf.getAsOpaquePointer());
542+
}
543+
static bool isEqual(mlir::tblgen::DagLeaf lhs, mlir::tblgen::DagLeaf rhs) {
544+
return lhs.def == rhs.def;
545+
}
546+
};
526547
} // end namespace llvm
527548

528549
#endif // MLIR_TABLEGEN_PATTERN_H_

mlir/test/mlir-tblgen/rewriter-static-matcher.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,16 @@ def COp : NS_Op<"c_op", []> {
3737
// Test static matcher for duplicate DagNode
3838
// ---
3939

40-
// CHECK: static ::mlir::LogicalResult static_dag_matcher_0
40+
// CHECK-DAG: static ::mlir::LogicalResult [[$TYPE_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Type typeOrAttr}}
41+
// CHECK-DAG: static ::mlir::LogicalResult [[$ATTR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]({{.*::mlir::Attribute}}
42+
// CHECK-DAG: static ::mlir::LogicalResult [[$DAG_MATCHER:static_dag_matcher.*]](
43+
// CHECK: if(failed([[$TYPE_CONSTRAINT]]
44+
// CHECK: if(failed([[$ATTR_CONSTRAINT]]
4145

42-
// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
46+
// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
4347
def : Pat<(AOp (BOp I32Attr:$attr, I32:$int)),
4448
(AOp $int)>;
4549

46-
// CHECK: if(failed(static_dag_matcher_0(rewriter, op1, tblgen_ops
50+
// CHECK: if(failed([[$DAG_MATCHER]](rewriter, op1, tblgen_ops
4751
def : Pat<(COp $_, (BOp I32Attr:$attr, I32:$int)),
4852
(COp $attr, $int)>;

mlir/tools/mlir-tblgen/CodeGenHelpers.cpp

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/TableGen/CodeGenHelpers.h"
15-
#include "mlir/TableGen/Format.h"
1615
#include "mlir/TableGen/Operator.h"
1716
#include "llvm/ADT/SetVector.h"
1817
#include "llvm/Support/FormatVariadic.h"
@@ -24,21 +23,34 @@ using namespace mlir;
2423
using namespace mlir::tblgen;
2524

2625
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
27-
const llvm::RecordKeeper &records, raw_ostream &os)
28-
: os(os), uniqueOutputLabel(getUniqueName(records)) {}
26+
const llvm::RecordKeeper &records)
27+
: uniqueOutputLabel(getUniqueName(records)) {}
2928

30-
void StaticVerifierFunctionEmitter::emitFunctionsFor(
29+
StaticVerifierFunctionEmitter &
30+
StaticVerifierFunctionEmitter::setSelf(StringRef str) {
31+
fctx.withSelf(str);
32+
return *this;
33+
}
34+
35+
StaticVerifierFunctionEmitter &
36+
StaticVerifierFunctionEmitter::setBuilder(StringRef str) {
37+
fctx.withBuilder(str);
38+
return *this;
39+
}
40+
41+
void StaticVerifierFunctionEmitter::emitConstraintMethodsInNamespace(
3142
StringRef signatureFormat, StringRef errorHandlerFormat,
32-
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
43+
StringRef cppNamespace, ArrayRef<const void *> constraints, raw_ostream &os,
44+
bool emitDecl) {
3345
llvm::Optional<NamespaceEmitter> namespaceEmitter;
3446
if (!emitDecl)
35-
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
47+
namespaceEmitter.emplace(os, cppNamespace);
3648

37-
emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
38-
opDefs, emitDecl);
49+
emitConstraintMethods(signatureFormat, errorHandlerFormat, constraints, os,
50+
emitDecl);
3951
}
4052

41-
StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
53+
StringRef StaticVerifierFunctionEmitter::getConstraintFn(
4254
const Constraint &constraint) const {
4355
auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
4456
assert(it != localTypeConstraints.end() && "expected valid constraint fn");
@@ -65,28 +77,16 @@ std::string StaticVerifierFunctionEmitter::getUniqueName(
6577
return uniqueName;
6678
}
6779

68-
void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
80+
void StaticVerifierFunctionEmitter::emitConstraintMethods(
6981
StringRef signatureFormat, StringRef errorHandlerFormat,
70-
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
71-
// Collect a set of all of the used type constraints within the operation
72-
// definitions.
73-
llvm::SetVector<const void *> typeConstraints;
74-
for (Record *def : opDefs) {
75-
Operator op(*def);
76-
for (NamedTypeConstraint &operand : op.getOperands())
77-
if (operand.hasPredicate())
78-
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
79-
for (NamedTypeConstraint &result : op.getResults())
80-
if (result.hasPredicate())
81-
typeConstraints.insert(result.constraint.getAsOpaquePointer());
82-
}
82+
ArrayRef<const void *> constraints, raw_ostream &rawOs, bool emitDecl) {
83+
raw_indented_ostream os(rawOs);
8384

8485
// Record the mapping from predicate to constraint. If two constraints has the
8586
// same predicate and constraint summary, they can share the same verification
8687
// function.
8788
llvm::DenseMap<Pred, const void *> predToConstraint;
88-
FmtContext fctx;
89-
for (auto it : llvm::enumerate(typeConstraints)) {
89+
for (auto it : llvm::enumerate(constraints)) {
9090
std::string name;
9191
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
9292
Pred pred = constraint.getPredicate();
@@ -101,7 +101,7 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
101101
// summary, otherwise we may report the wrong message while verification
102102
// fails.
103103
if (constraint.getSummary() == built.getSummary()) {
104-
name = getTypeConstraintFn(built).str();
104+
name = getConstraintFn(built).str();
105105
break;
106106
}
107107
++iter;
@@ -126,12 +126,11 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
126126
continue;
127127

128128
os << formatv(signatureFormat.data(), name) << " {\n";
129-
os.indent() << "if (!("
130-
<< tgfmt(constraint.getConditionTemplate(),
131-
&fctx.withSelf(typeArgName))
129+
os.indent() << "if (!(" << tgfmt(constraint.getConditionTemplate(), &fctx)
132130
<< ")) {\n";
133131
os.indent() << "return "
134-
<< formatv(errorHandlerFormat.data(), constraint.getSummary())
132+
<< formatv(errorHandlerFormat.data(),
133+
escapeString(constraint.getSummary()))
135134
<< ";\n";
136135
os.unindent() << "}\nreturn ::mlir::success();\n";
137136
os.unindent() << "}\n\n";

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2233,7 +2233,7 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
22332233
continue;
22342234
// Emit a loop to check all the dynamic values in the pack.
22352235
StringRef constraintFn =
2236-
staticVerifierEmitter.getTypeConstraintFn(value.constraint);
2236+
staticVerifierEmitter.getConstraintFn(value.constraint);
22372237
body << " for (::mlir::Value v : valueGroup" << staticValue.index()
22382238
<< ") {\n"
22392239
<< " if (::mlir::failed(" << constraintFn
@@ -2639,11 +2639,27 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
26392639
return;
26402640

26412641
// Generate all of the locally instantiated methods first.
2642-
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
2642+
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper);
26432643
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
2644-
staticVerifierEmitter.emitFunctionsFor(
2645-
typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
2646-
defs, emitDecl);
2644+
staticVerifierEmitter.setSelf("type");
2645+
2646+
// Collect a set of all of the used type constraints within the operation
2647+
// definitions.
2648+
llvm::SetVector<const void *> typeConstraints;
2649+
for (Record *def : defs) {
2650+
Operator op(*def);
2651+
for (NamedTypeConstraint &operand : op.getOperands())
2652+
if (operand.hasPredicate())
2653+
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
2654+
for (NamedTypeConstraint &result : op.getResults())
2655+
if (result.hasPredicate())
2656+
typeConstraints.insert(result.constraint.getAsOpaquePointer());
2657+
}
2658+
2659+
staticVerifierEmitter.emitConstraintMethodsInNamespace(
2660+
typeVerifierSignature, typeVerifierErrorHandler,
2661+
Operator(*defs[0]).getCppNamespace(), typeConstraints.getArrayRef(), os,
2662+
emitDecl);
26472663

26482664
for (auto *def : defs) {
26492665
Operator op(*def);

0 commit comments

Comments
 (0)