@@ -117,6 +117,144 @@ static const char *const opCommentHeader = R"(
117
117
118
118
)" ;
119
119
120
+ // ===----------------------------------------------------------------------===//
121
+ // StaticVerifierFunctionEmitter
122
+ // ===----------------------------------------------------------------------===//
123
+
124
+ namespace {
125
+ // / This class deduplicates shared operation verification code by emitting
126
+ // / static functions alongside the op definitions. These methods are local to
127
+ // / the definition file, and are invoked within the operation verify methods.
128
+ // / An example is shown below:
129
+ // /
130
+ // / static LogicalResult localVerify(...)
131
+ // /
132
+ // / LogicalResult OpA::verify(...) {
133
+ // / if (failed(localVerify(...)))
134
+ // / return failure();
135
+ // / ...
136
+ // / }
137
+ // /
138
+ // / LogicalResult OpB::verify(...) {
139
+ // / if (failed(localVerify(...)))
140
+ // / return failure();
141
+ // / ...
142
+ // / }
143
+ // /
144
+ class StaticVerifierFunctionEmitter {
145
+ public:
146
+ StaticVerifierFunctionEmitter (const llvm::RecordKeeper &records,
147
+ ArrayRef<llvm::Record *> opDefs,
148
+ raw_ostream &os, bool emitDecl);
149
+
150
+ // / Get the name of the local function used for the given type constraint.
151
+ // / These functions are used for operand and result constraints and have the
152
+ // / form:
153
+ // / LogicalResult(Operation *op, Type type, StringRef valueKind,
154
+ // / unsigned valueGroupStartIndex);
155
+ StringRef getTypeConstraintFn (const Constraint &constraint) const {
156
+ auto it = localTypeConstraints.find (constraint.getAsOpaquePointer ());
157
+ assert (it != localTypeConstraints.end () && " expected valid constraint fn" );
158
+ return it->second ;
159
+ }
160
+
161
+ private:
162
+ // / Returns a unique name to use when generating local methods.
163
+ static std::string getUniqueName (const llvm::RecordKeeper &records);
164
+
165
+ // / Emit local methods for the type constraints used within the provided op
166
+ // / definitions.
167
+ void emitTypeConstraintMethods (ArrayRef<llvm::Record *> opDefs,
168
+ raw_ostream &os, bool emitDecl);
169
+
170
+ // / A unique label for the file currently being generated. This is used to
171
+ // / ensure that the local functions have a unique name.
172
+ std::string uniqueOutputLabel;
173
+
174
+ // / A set of functions implementing type constraints, used for operand and
175
+ // / result verification.
176
+ llvm::DenseMap<const void *, std::string> localTypeConstraints;
177
+ };
178
+ } // namespace
179
+
180
+ StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter (
181
+ const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
182
+ raw_ostream &os, bool emitDecl)
183
+ : uniqueOutputLabel(getUniqueName(records)) {
184
+ llvm::Optional<NamespaceEmitter> namespaceEmitter;
185
+ if (!emitDecl) {
186
+ os << formatv (opCommentHeader, " Local Utility Method" , " Definitions" );
187
+ namespaceEmitter.emplace (os, Operator (*opDefs[0 ]).getDialect ());
188
+ }
189
+
190
+ emitTypeConstraintMethods (opDefs, os, emitDecl);
191
+ }
192
+
193
+ std::string StaticVerifierFunctionEmitter::getUniqueName (
194
+ const llvm::RecordKeeper &records) {
195
+ // Use the input file name when generating a unique name.
196
+ std::string inputFilename = records.getInputFilename ();
197
+
198
+ // Drop all but the base filename.
199
+ StringRef nameRef = llvm::sys::path::filename (inputFilename);
200
+ nameRef.consume_back (" .td" );
201
+
202
+ // Sanitize any invalid characters.
203
+ std::string uniqueName;
204
+ for (char c : nameRef) {
205
+ if (llvm::isAlnum (c) || c == ' _' )
206
+ uniqueName.push_back (c);
207
+ else
208
+ uniqueName.append (llvm::utohexstr ((unsigned char )c));
209
+ }
210
+ return uniqueName;
211
+ }
212
+
213
+ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods (
214
+ ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
215
+ // Collect a set of all of the used type constraints within the operation
216
+ // definitions.
217
+ llvm::SetVector<const void *> typeConstraints;
218
+ for (Record *def : opDefs) {
219
+ Operator op (*def);
220
+ for (NamedTypeConstraint &operand : op.getOperands ())
221
+ if (operand.hasPredicate ())
222
+ typeConstraints.insert (operand.constraint .getAsOpaquePointer ());
223
+ for (NamedTypeConstraint &result : op.getResults ())
224
+ if (result.hasPredicate ())
225
+ typeConstraints.insert (result.constraint .getAsOpaquePointer ());
226
+ }
227
+
228
+ FmtContext fctx;
229
+ for (auto it : llvm::enumerate (typeConstraints)) {
230
+ // Generate an obscure and unique name for this type constraint.
231
+ std::string name = (Twine (" __mlir_ods_local_type_constraint_" ) +
232
+ uniqueOutputLabel + Twine (it.index ()))
233
+ .str ();
234
+ localTypeConstraints.try_emplace (it.value (), name);
235
+
236
+ // Only generate the methods if we are generating definitions.
237
+ if (emitDecl)
238
+ continue ;
239
+
240
+ Constraint constraint = Constraint::getFromOpaquePointer (it.value ());
241
+ os << " static ::mlir::LogicalResult " << name
242
+ << " (::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
243
+ " valueKind, unsigned valueGroupStartIndex) {\n " ;
244
+
245
+ os << " if (!("
246
+ << tgfmt (constraint.getConditionTemplate (), &fctx.withSelf (" type" ))
247
+ << " )) {\n "
248
+ << formatv (
249
+ " return op->emitOpError(valueKind) << \" #\" << "
250
+ " valueGroupStartIndex << \" must be {0}, but got \" << type;\n " ,
251
+ constraint.getDescription ())
252
+ << " }\n "
253
+ << " return ::mlir::success();\n "
254
+ << " }\n\n " ;
255
+ }
256
+ }
257
+
120
258
// ===----------------------------------------------------------------------===//
121
259
// Utility structs and functions
122
260
// ===----------------------------------------------------------------------===//
@@ -164,11 +302,16 @@ namespace {
164
302
// Helper class to emit a record into the given output stream.
165
303
class OpEmitter {
166
304
public:
167
- static void emitDecl (const Operator &op, raw_ostream &os);
168
- static void emitDef (const Operator &op, raw_ostream &os);
305
+ static void
306
+ emitDecl (const Operator &op, raw_ostream &os,
307
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter);
308
+ static void
309
+ emitDef (const Operator &op, raw_ostream &os,
310
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter);
169
311
170
312
private:
171
- OpEmitter (const Operator &op);
313
+ OpEmitter (const Operator &op,
314
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter);
172
315
173
316
void emitDecl (raw_ostream &os);
174
317
void emitDef (raw_ostream &os);
@@ -321,6 +464,9 @@ class OpEmitter {
321
464
322
465
// The format context for verification code generation.
323
466
FmtContext verifyCtx;
467
+
468
+ // The emitter containing all of the locally emitted verification functions.
469
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter;
324
470
};
325
471
} // end anonymous namespace
326
472
@@ -434,9 +580,11 @@ static void genAttributeVerifier(const Operator &op, const char *attrGet,
434
580
}
435
581
}
436
582
437
- OpEmitter::OpEmitter (const Operator &op)
583
+ OpEmitter::OpEmitter (const Operator &op,
584
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter)
438
585
: def(op.getDef()), op(op),
439
- opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
586
+ opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
587
+ staticVerifierEmitter(staticVerifierEmitter) {
440
588
verifyCtx.withOp (" (*this->getOperation())" );
441
589
442
590
genTraits ();
@@ -464,12 +612,16 @@ OpEmitter::OpEmitter(const Operator &op)
464
612
genSideEffectInterfaceMethods ();
465
613
}
466
614
467
- void OpEmitter::emitDecl (const Operator &op, raw_ostream &os) {
468
- OpEmitter (op).emitDecl (os);
615
+ void OpEmitter::emitDecl (
616
+ const Operator &op, raw_ostream &os,
617
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
618
+ OpEmitter (op, staticVerifierEmitter).emitDecl (os);
469
619
}
470
620
471
- void OpEmitter::emitDef (const Operator &op, raw_ostream &os) {
472
- OpEmitter (op).emitDef (os);
621
+ void OpEmitter::emitDef (
622
+ const Operator &op, raw_ostream &os,
623
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
624
+ OpEmitter (op, staticVerifierEmitter).emitDef (os);
473
625
}
474
626
475
627
void OpEmitter::emitDecl (raw_ostream &os) { opClass.writeDeclTo (os); }
@@ -1891,23 +2043,16 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
1891
2043
// Otherwise, if there is no predicate there is nothing left to do.
1892
2044
if (!hasPredicate)
1893
2045
continue ;
1894
-
1895
2046
// Emit a loop to check all the dynamic values in the pack.
2047
+ StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn (
2048
+ staticValue.value ().constraint );
1896
2049
body << " for (::mlir::Value v : valueGroup" << staticValue.index ()
1897
- << " ) {\n " ;
1898
-
1899
- auto constraint = staticValue.value ().constraint ;
1900
- body << " (void)v;\n "
1901
- << " if (!("
1902
- << tgfmt (constraint.getConditionTemplate (),
1903
- &fctx.withSelf (" v.getType()" ))
1904
- << " )) {\n "
1905
- << formatv (" return emitOpError(\" {0} #\" ) << index "
1906
- " << \" must be {1}, but got \" << v.getType();\n " ,
1907
- valueKind, constraint.getDescription ())
1908
- << " }\n " // if
2050
+ << " ) {\n "
2051
+ << " if (::mlir::failed(" << constraintFn
2052
+ << " (getOperation(), v.getType(), \" " << valueKind << " \" , index)))\n "
2053
+ << " return ::mlir::failure();\n "
1909
2054
<< " ++index;\n "
1910
- << " }\n " ; // for
2055
+ << " }\n " ;
1911
2056
}
1912
2057
1913
2058
body << " }\n " ;
@@ -2248,7 +2393,8 @@ void OpOperandAdaptorEmitter::emitDef(const Operator &op, raw_ostream &os) {
2248
2393
}
2249
2394
2250
2395
// Emits the opcode enum and op classes.
2251
- static void emitOpClasses (const std::vector<Record *> &defs, raw_ostream &os,
2396
+ static void emitOpClasses (const RecordKeeper &recordKeeper,
2397
+ const std::vector<Record *> &defs, raw_ostream &os,
2252
2398
bool emitDecl) {
2253
2399
// First emit forward declaration for each class, this allows them to refer
2254
2400
// to each others in traits for example.
@@ -2264,17 +2410,23 @@ static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os,
2264
2410
}
2265
2411
2266
2412
IfDefScope scope (" GET_OP_CLASSES" , os);
2413
+ if (defs.empty ())
2414
+ return ;
2415
+
2416
+ // Generate all of the locally instantiated methods first.
2417
+ StaticVerifierFunctionEmitter staticVerifierEmitter (recordKeeper, defs, os,
2418
+ emitDecl);
2267
2419
for (auto *def : defs) {
2268
2420
Operator op (*def);
2269
2421
NamespaceEmitter emitter (os, op.getDialect ());
2270
2422
if (emitDecl) {
2271
2423
os << formatv (opCommentHeader, op.getQualCppClassName (), " declarations" );
2272
2424
OpOperandAdaptorEmitter::emitDecl (op, os);
2273
- OpEmitter::emitDecl (op, os);
2425
+ OpEmitter::emitDecl (op, os, staticVerifierEmitter );
2274
2426
} else {
2275
2427
os << formatv (opCommentHeader, op.getQualCppClassName (), " definitions" );
2276
2428
OpOperandAdaptorEmitter::emitDef (op, os);
2277
- OpEmitter::emitDef (op, os);
2429
+ OpEmitter::emitDef (op, os, staticVerifierEmitter );
2278
2430
}
2279
2431
}
2280
2432
}
@@ -2329,7 +2481,7 @@ static bool emitOpDecls(const RecordKeeper &recordKeeper, raw_ostream &os) {
2329
2481
emitSourceFileHeader (" Op Declarations" , os);
2330
2482
2331
2483
const auto &defs = getAllDerivedDefinitions (recordKeeper, " Op" );
2332
- emitOpClasses (defs, os, /* emitDecl=*/ true );
2484
+ emitOpClasses (recordKeeper, defs, os, /* emitDecl=*/ true );
2333
2485
2334
2486
return false ;
2335
2487
}
@@ -2339,7 +2491,7 @@ static bool emitOpDefs(const RecordKeeper &recordKeeper, raw_ostream &os) {
2339
2491
2340
2492
const auto &defs = getAllDerivedDefinitions (recordKeeper, " Op" );
2341
2493
emitOpList (defs, os);
2342
- emitOpClasses (defs, os, /* emitDecl=*/ false );
2494
+ emitOpClasses (recordKeeper, defs, os, /* emitDecl=*/ false );
2343
2495
2344
2496
return false ;
2345
2497
}
0 commit comments