Skip to content

Commit b3ee7f1

Browse files
committed
[mlir][OpDefGen] Add support for generating local functions for shared utilities
This revision adds a new `StaticVerifierFunctionEmitter` class that emits local static functions in the .cpp file for shared operation verification. This class deduplicates shared operation verification code by emitting static functions alongside the op definitions. These methods are local to the definition file, and are invoked within the operation verify methods. The first bit of shared verification is for the type constraints used when verifying operands and results. An example is shown below: ``` static LogicalResult localVerify(...) { ... } LogicalResult OpA::verify(...) { if (failed(localVerify(...))) return failure(); ... } LogicalResult OpB::verify(...) { if (failed(localVerify(...))) return failure(); ... } ``` This allowed for saving >400kb of code size from a downstream TensorFlow project (~15% of MLIR code size). Differential Revision: https://reviews.llvm.org/D91381
1 parent cf5845d commit b3ee7f1

File tree

3 files changed

+201
-30
lines changed

3 files changed

+201
-30
lines changed

mlir/include/mlir/TableGen/Constraint.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ class Constraint {
5252

5353
Kind getKind() const { return kind; }
5454

55+
/// Get an opaque pointer to the constraint.
56+
const void *getAsOpaquePointer() const { return def; }
57+
/// Construct a constraint from the opaque pointer representation.
58+
static Constraint getFromOpaquePointer(const void *ptr) {
59+
return Constraint(reinterpret_cast<const llvm::Record *>(ptr));
60+
}
61+
5562
protected:
5663
Constraint(Kind kind, const llvm::Record *record);
5764

mlir/test/mlir-tblgen/predicate.td

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,22 @@ def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
1515
let arguments = (ins I32OrF32:$x);
1616
}
1717

18+
// CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
19+
// CHECK: if (!((type.isInteger(32) || type.isF32()))) {
20+
// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
21+
22+
// CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
23+
// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
24+
// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
25+
26+
// CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
27+
// CHECK: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
28+
// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
29+
1830
// CHECK-LABEL: OpA::verify
1931
// CHECK: auto valueGroup0 = getODSOperands(0);
2032
// CHECK: for (::mlir::Value v : valueGroup0) {
21-
// CHECK: if (!((v.getType().isInteger(32) || v.getType().isF32())))
33+
// CHECK: if (::mlir::failed([[$INTEGER_FLOAT_CONSTRAINT]]
2234

2335
def OpB : NS_Op<"op_for_And_PredOpTrait", [
2436
PredOpTrait<"both first and second holds",
@@ -93,4 +105,4 @@ def OpK : NS_Op<"op_for_AnyTensorOf", []> {
93105
// CHECK-LABEL: OpK::verify
94106
// CHECK: auto valueGroup0 = getODSOperands(0);
95107
// CHECK: for (::mlir::Value v : valueGroup0) {
96-
// CHECK: if (!(((v.getType().isa<::mlir::TensorType>())) && (((v.getType().cast<::mlir::ShapedType>().getElementType().isF32())) || ((v.getType().cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32))))))
108+
// CHECK: if (::mlir::failed([[$TENSOR_INTEGER_FLOAT_CONSTRAINT]]

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 180 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,144 @@ static const char *const opCommentHeader = R"(
117117
118118
)";
119119

120+
//===----------------------------------------------------------------------===//
121+
// StaticVerifierFunctionEmitter
122+
//===----------------------------------------------------------------------===//
123+
124+
namespace {
125+
/// This class deduplicates shared operation verification code by emitting
126+
/// static functions alongside the op definitions. These methods are local to
127+
/// the definition file, and are invoked within the operation verify methods.
128+
/// An example is shown below:
129+
///
130+
/// static LogicalResult localVerify(...)
131+
///
132+
/// LogicalResult OpA::verify(...) {
133+
/// if (failed(localVerify(...)))
134+
/// return failure();
135+
/// ...
136+
/// }
137+
///
138+
/// LogicalResult OpB::verify(...) {
139+
/// if (failed(localVerify(...)))
140+
/// return failure();
141+
/// ...
142+
/// }
143+
///
144+
class StaticVerifierFunctionEmitter {
145+
public:
146+
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
147+
ArrayRef<llvm::Record *> opDefs,
148+
raw_ostream &os, bool emitDecl);
149+
150+
/// Get the name of the local function used for the given type constraint.
151+
/// These functions are used for operand and result constraints and have the
152+
/// form:
153+
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
154+
/// unsigned valueGroupStartIndex);
155+
StringRef getTypeConstraintFn(const Constraint &constraint) const {
156+
auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
157+
assert(it != localTypeConstraints.end() && "expected valid constraint fn");
158+
return it->second;
159+
}
160+
161+
private:
162+
/// Returns a unique name to use when generating local methods.
163+
static std::string getUniqueName(const llvm::RecordKeeper &records);
164+
165+
/// Emit local methods for the type constraints used within the provided op
166+
/// definitions.
167+
void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
168+
raw_ostream &os, bool emitDecl);
169+
170+
/// A unique label for the file currently being generated. This is used to
171+
/// ensure that the local functions have a unique name.
172+
std::string uniqueOutputLabel;
173+
174+
/// A set of functions implementing type constraints, used for operand and
175+
/// result verification.
176+
llvm::DenseMap<const void *, std::string> localTypeConstraints;
177+
};
178+
} // namespace
179+
180+
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
181+
const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
182+
raw_ostream &os, bool emitDecl)
183+
: uniqueOutputLabel(getUniqueName(records)) {
184+
llvm::Optional<NamespaceEmitter> namespaceEmitter;
185+
if (!emitDecl) {
186+
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
187+
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getDialect());
188+
}
189+
190+
emitTypeConstraintMethods(opDefs, os, emitDecl);
191+
}
192+
193+
std::string StaticVerifierFunctionEmitter::getUniqueName(
194+
const llvm::RecordKeeper &records) {
195+
// Use the input file name when generating a unique name.
196+
std::string inputFilename = records.getInputFilename();
197+
198+
// Drop all but the base filename.
199+
StringRef nameRef = llvm::sys::path::filename(inputFilename);
200+
nameRef.consume_back(".td");
201+
202+
// Sanitize any invalid characters.
203+
std::string uniqueName;
204+
for (char c : nameRef) {
205+
if (llvm::isAlnum(c) || c == '_')
206+
uniqueName.push_back(c);
207+
else
208+
uniqueName.append(llvm::utohexstr((unsigned char)c));
209+
}
210+
return uniqueName;
211+
}
212+
213+
void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
214+
ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
215+
// Collect a set of all of the used type constraints within the operation
216+
// definitions.
217+
llvm::SetVector<const void *> typeConstraints;
218+
for (Record *def : opDefs) {
219+
Operator op(*def);
220+
for (NamedTypeConstraint &operand : op.getOperands())
221+
if (operand.hasPredicate())
222+
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
223+
for (NamedTypeConstraint &result : op.getResults())
224+
if (result.hasPredicate())
225+
typeConstraints.insert(result.constraint.getAsOpaquePointer());
226+
}
227+
228+
FmtContext fctx;
229+
for (auto it : llvm::enumerate(typeConstraints)) {
230+
// Generate an obscure and unique name for this type constraint.
231+
std::string name = (Twine("__mlir_ods_local_type_constraint_") +
232+
uniqueOutputLabel + Twine(it.index()))
233+
.str();
234+
localTypeConstraints.try_emplace(it.value(), name);
235+
236+
// Only generate the methods if we are generating definitions.
237+
if (emitDecl)
238+
continue;
239+
240+
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
241+
os << "static ::mlir::LogicalResult " << name
242+
<< "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
243+
"valueKind, unsigned valueGroupStartIndex) {\n";
244+
245+
os << " if (!("
246+
<< tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
247+
<< ")) {\n"
248+
<< formatv(
249+
" return op->emitOpError(valueKind) << \" #\" << "
250+
"valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
251+
constraint.getDescription())
252+
<< " }\n"
253+
<< " return ::mlir::success();\n"
254+
<< "}\n\n";
255+
}
256+
}
257+
120258
//===----------------------------------------------------------------------===//
121259
// Utility structs and functions
122260
//===----------------------------------------------------------------------===//
@@ -164,11 +302,16 @@ namespace {
164302
// Helper class to emit a record into the given output stream.
165303
class OpEmitter {
166304
public:
167-
static void emitDecl(const Operator &op, raw_ostream &os);
168-
static void emitDef(const Operator &op, raw_ostream &os);
305+
static void
306+
emitDecl(const Operator &op, raw_ostream &os,
307+
const StaticVerifierFunctionEmitter &staticVerifierEmitter);
308+
static void
309+
emitDef(const Operator &op, raw_ostream &os,
310+
const StaticVerifierFunctionEmitter &staticVerifierEmitter);
169311

170312
private:
171-
OpEmitter(const Operator &op);
313+
OpEmitter(const Operator &op,
314+
const StaticVerifierFunctionEmitter &staticVerifierEmitter);
172315

173316
void emitDecl(raw_ostream &os);
174317
void emitDef(raw_ostream &os);
@@ -321,6 +464,9 @@ class OpEmitter {
321464

322465
// The format context for verification code generation.
323466
FmtContext verifyCtx;
467+
468+
// The emitter containing all of the locally emitted verification functions.
469+
const StaticVerifierFunctionEmitter &staticVerifierEmitter;
324470
};
325471
} // end anonymous namespace
326472

@@ -434,9 +580,11 @@ static void genAttributeVerifier(const Operator &op, const char *attrGet,
434580
}
435581
}
436582

437-
OpEmitter::OpEmitter(const Operator &op)
583+
OpEmitter::OpEmitter(const Operator &op,
584+
const StaticVerifierFunctionEmitter &staticVerifierEmitter)
438585
: def(op.getDef()), op(op),
439-
opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
586+
opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
587+
staticVerifierEmitter(staticVerifierEmitter) {
440588
verifyCtx.withOp("(*this->getOperation())");
441589

442590
genTraits();
@@ -464,12 +612,16 @@ OpEmitter::OpEmitter(const Operator &op)
464612
genSideEffectInterfaceMethods();
465613
}
466614

467-
void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) {
468-
OpEmitter(op).emitDecl(os);
615+
void OpEmitter::emitDecl(
616+
const Operator &op, raw_ostream &os,
617+
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
618+
OpEmitter(op, staticVerifierEmitter).emitDecl(os);
469619
}
470620

471-
void OpEmitter::emitDef(const Operator &op, raw_ostream &os) {
472-
OpEmitter(op).emitDef(os);
621+
void OpEmitter::emitDef(
622+
const Operator &op, raw_ostream &os,
623+
const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
624+
OpEmitter(op, staticVerifierEmitter).emitDef(os);
473625
}
474626

475627
void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); }
@@ -1891,23 +2043,16 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
18912043
// Otherwise, if there is no predicate there is nothing left to do.
18922044
if (!hasPredicate)
18932045
continue;
1894-
18952046
// Emit a loop to check all the dynamic values in the pack.
2047+
StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
2048+
staticValue.value().constraint);
18962049
body << " for (::mlir::Value v : valueGroup" << staticValue.index()
1897-
<< ") {\n";
1898-
1899-
auto constraint = staticValue.value().constraint;
1900-
body << " (void)v;\n"
1901-
<< " if (!("
1902-
<< tgfmt(constraint.getConditionTemplate(),
1903-
&fctx.withSelf("v.getType()"))
1904-
<< ")) {\n"
1905-
<< formatv(" return emitOpError(\"{0} #\") << index "
1906-
"<< \" must be {1}, but got \" << v.getType();\n",
1907-
valueKind, constraint.getDescription())
1908-
<< " }\n" // if
2050+
<< ") {\n"
2051+
<< " if (::mlir::failed(" << constraintFn
2052+
<< "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n"
2053+
<< " return ::mlir::failure();\n"
19092054
<< " ++index;\n"
1910-
<< " }\n"; // for
2055+
<< " }\n";
19112056
}
19122057

19132058
body << " }\n";
@@ -2248,7 +2393,8 @@ void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
22482393
}
22492394

22502395
// Emits the opcode enum and op classes.
2251-
static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
2396+
static void emitOpClasses(const RecordKeeper &recordKeeper,
2397+
const std::vector<Record *> &defs, raw_ostream &os,
22522398
bool emitDecl) {
22532399
// First emit forward declaration for each class, this allows them to refer
22542400
// to each others in traits for example.
@@ -2264,17 +2410,23 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
22642410
}
22652411

22662412
IfDefScope scope("GET_OP_CLASSES", os);
2413+
if (defs.empty())
2414+
return;
2415+
2416+
// Generate all of the locally instantiated methods first.
2417+
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
2418+
emitDecl);
22672419
for (auto *def : defs) {
22682420
Operator op(*def);
22692421
NamespaceEmitter emitter(os, op.getDialect());
22702422
if (emitDecl) {
22712423
os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
22722424
OpOperandAdaptorEmitter::emitDecl(op, os);
2273-
OpEmitter::emitDecl(op, os);
2425+
OpEmitter::emitDecl(op, os, staticVerifierEmitter);
22742426
} else {
22752427
os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions");
22762428
OpOperandAdaptorEmitter::emitDef(op, os);
2277-
OpEmitter::emitDef(op, os);
2429+
OpEmitter::emitDef(op, os, staticVerifierEmitter);
22782430
}
22792431
}
22802432
}
@@ -2329,7 +2481,7 @@ static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
23292481
emitSourceFileHeader("Op Declarations", os);
23302482

23312483
const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
2332-
emitOpClasses(defs, os, /*emitDecl=*/true);
2484+
emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true);
23332485

23342486
return false;
23352487
}
@@ -2339,7 +2491,7 @@ static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
23392491

23402492
const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op");
23412493
emitOpList(defs, os);
2342-
emitOpClasses(defs, os, /*emitDecl=*/false);
2494+
emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false);
23432495

23442496
return false;
23452497
}

0 commit comments

Comments
 (0)