Skip to content

Commit 69b6454

Browse files
committed
[mlir] Plumb through default attribute populate for extensible dialect.
1 parent fd6dae9 commit 69b6454

File tree

2 files changed

+41
-24
lines changed

2 files changed

+41
-24
lines changed

mlir/include/mlir/IR/ExtensibleDialect.h

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,8 @@ class DynamicOpDefinition {
362362
OperationName::PrintAssemblyFn &&printFn,
363363
OperationName::FoldHookFn &&foldHookFn,
364364
OperationName::GetCanonicalizationPatternsFn
365-
&&getCanonicalizationPatternsFn);
365+
&&getCanonicalizationPatternsFn,
366+
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
366367

367368
/// Returns the op typeID.
368369
TypeID getTypeID() { return typeID; }
@@ -405,15 +406,23 @@ class DynamicOpDefinition {
405406
getCanonicalizationPatternsFn = std::move(getCanonicalizationPatterns);
406407
}
407408

409+
/// Set the hook populating default attributes.
410+
void setPopulateDefaultAttrsFn(
411+
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrs) {
412+
populateDefaultAttrsFn = std::move(populateDefaultAttrs);
413+
}
414+
408415
private:
409-
DynamicOpDefinition(StringRef name, ExtensibleDialect *dialect,
410-
OperationName::VerifyInvariantsFn &&verifyFn,
411-
OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
412-
OperationName::ParseAssemblyFn &&parseFn,
413-
OperationName::PrintAssemblyFn &&printFn,
414-
OperationName::FoldHookFn &&foldHookFn,
415-
OperationName::GetCanonicalizationPatternsFn
416-
&&getCanonicalizationPatternsFn);
416+
DynamicOpDefinition(
417+
StringRef name, ExtensibleDialect *dialect,
418+
OperationName::VerifyInvariantsFn &&verifyFn,
419+
OperationName::VerifyRegionInvariantsFn &&verifyRegionFn,
420+
OperationName::ParseAssemblyFn &&parseFn,
421+
OperationName::PrintAssemblyFn &&printFn,
422+
OperationName::FoldHookFn &&foldHookFn,
423+
OperationName::GetCanonicalizationPatternsFn
424+
&&getCanonicalizationPatternsFn,
425+
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn);
417426

418427
/// Unique identifier for this operation.
419428
TypeID typeID;
@@ -431,7 +440,7 @@ class DynamicOpDefinition {
431440
OperationName::PrintAssemblyFn printFn;
432441
OperationName::FoldHookFn foldHookFn;
433442
OperationName::GetCanonicalizationPatternsFn getCanonicalizationPatternsFn;
434-
OperationName::PopulateDefaultAttrsFn getPopulateDefaultAttrsFn;
443+
OperationName::PopulateDefaultAttrsFn populateDefaultAttrsFn;
435444

436445
friend ExtensibleDialect;
437446
};

mlir/lib/IR/ExtensibleDialect.cpp

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -295,13 +295,15 @@ DynamicOpDefinition::DynamicOpDefinition(
295295
OperationName::PrintAssemblyFn &&printFn,
296296
OperationName::FoldHookFn &&foldHookFn,
297297
OperationName::GetCanonicalizationPatternsFn
298-
&&getCanonicalizationPatternsFn)
298+
&&getCanonicalizationPatternsFn,
299+
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn)
299300
: typeID(dialect->allocateTypeID()),
300301
name((dialect->getNamespace() + "." + name).str()), dialect(dialect),
301302
verifyFn(std::move(verifyFn)), verifyRegionFn(std::move(verifyRegionFn)),
302303
parseFn(std::move(parseFn)), printFn(std::move(printFn)),
303304
foldHookFn(std::move(foldHookFn)),
304-
getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)) {}
305+
getCanonicalizationPatternsFn(std::move(getCanonicalizationPatternsFn)),
306+
populateDefaultAttrsFn(std::move(populateDefaultAttrsFn)) {}
305307

306308
std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
307309
StringRef name, ExtensibleDialect *dialect,
@@ -336,25 +338,31 @@ std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
336338
auto getCanonicalizationPatternsFn = [](RewritePatternSet &, MLIRContext *) {
337339
};
338340

341+
auto populateDefaultAttrsFn = [](const RegisteredOperationName &,
342+
NamedAttrList &) {};
343+
339344
return DynamicOpDefinition::get(name, dialect, std::move(verifyFn),
340345
std::move(verifyRegionFn), std::move(parseFn),
341346
std::move(printFn), std::move(foldHookFn),
342-
std::move(getCanonicalizationPatternsFn));
347+
std::move(getCanonicalizationPatternsFn),
348+
std::move(populateDefaultAttrsFn));
343349
}
344350

345-
std::unique_ptr<DynamicOpDefinition>
346-
DynamicOpDefinition::get(StringRef name, ExtensibleDialect *dialect,
347-
OperationName::VerifyInvariantsFn &&verifyFn,
348-
OperationName::VerifyInvariantsFn &&verifyRegionFn,
349-
OperationName::ParseAssemblyFn &&parseFn,
350-
OperationName::PrintAssemblyFn &&printFn,
351-
OperationName::FoldHookFn &&foldHookFn,
352-
OperationName::GetCanonicalizationPatternsFn
353-
&&getCanonicalizationPatternsFn) {
351+
std::unique_ptr<DynamicOpDefinition> DynamicOpDefinition::get(
352+
StringRef name, ExtensibleDialect *dialect,
353+
OperationName::VerifyInvariantsFn &&verifyFn,
354+
OperationName::VerifyInvariantsFn &&verifyRegionFn,
355+
OperationName::ParseAssemblyFn &&parseFn,
356+
OperationName::PrintAssemblyFn &&printFn,
357+
OperationName::FoldHookFn &&foldHookFn,
358+
OperationName::GetCanonicalizationPatternsFn
359+
&&getCanonicalizationPatternsFn,
360+
OperationName::PopulateDefaultAttrsFn &&populateDefaultAttrsFn) {
354361
return std::unique_ptr<DynamicOpDefinition>(new DynamicOpDefinition(
355362
name, dialect, std::move(verifyFn), std::move(verifyRegionFn),
356363
std::move(parseFn), std::move(printFn), std::move(foldHookFn),
357-
std::move(getCanonicalizationPatternsFn)));
364+
std::move(getCanonicalizationPatternsFn),
365+
std::move(populateDefaultAttrsFn)));
358366
}
359367

360368
//===----------------------------------------------------------------------===//
@@ -448,7 +456,7 @@ void ExtensibleDialect::registerDynamicOp(
448456
std::move(op->verifyRegionFn), std::move(op->foldHookFn),
449457
std::move(op->getCanonicalizationPatternsFn),
450458
detail::InterfaceMap::get<>(), std::move(hasTraitFn), {},
451-
std::move(op->getPopulateDefaultAttrsFn));
459+
std::move(op->populateDefaultAttrsFn));
452460
}
453461

454462
bool ExtensibleDialect::classof(const Dialect *dialect) {

0 commit comments

Comments
 (0)