@@ -290,11 +290,16 @@ class OpEmitter {
290
290
// Generates the traits used by the object.
291
291
void genTraits ();
292
292
293
- // Generate the OpInterface methods.
293
+ // Generate the OpInterface methods for all interfaces .
294
294
void genOpInterfaceMethods ();
295
295
296
- // Generate op interface method.
297
- void genOpInterfaceMethod (const tblgen::InterfaceOpTrait *trait);
296
+ // Generate op interface methods for the given interface.
297
+ void genOpInterfaceMethods (const tblgen::InterfaceOpTrait *trait);
298
+
299
+ // Generate op interface method for the given interface method. If
300
+ // 'declaration' is true, generates a declaration, else a definition.
301
+ OpMethod *genOpInterfaceMethod (const tblgen::InterfaceMethod &method,
302
+ bool declaration = true );
298
303
299
304
// Generate the side effect interface methods.
300
305
void genSideEffectInterfaceMethods ();
@@ -1588,7 +1593,7 @@ void OpEmitter::genFolderDecls() {
1588
1593
}
1589
1594
}
1590
1595
1591
- void OpEmitter::genOpInterfaceMethod (const tblgen::InterfaceOpTrait *opTrait) {
1596
+ void OpEmitter::genOpInterfaceMethods (const tblgen::InterfaceOpTrait *opTrait) {
1592
1597
auto interface = opTrait->getOpInterface ();
1593
1598
1594
1599
// Get the set of methods that should always be declared.
@@ -1606,23 +1611,29 @@ void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
1606
1611
if (method.getDefaultImplementation () &&
1607
1612
!alwaysDeclaredMethods.count (method.getName ()))
1608
1613
continue ;
1609
-
1610
- SmallVector<OpMethodParameter, 4 > paramList;
1611
- for (const InterfaceMethod::Argument &arg : method.getArguments ())
1612
- paramList.emplace_back (arg.type , arg.name );
1613
-
1614
- auto properties = method.isStatic () ? OpMethod::MP_StaticDeclaration
1615
- : OpMethod::MP_Declaration;
1616
- opClass.addMethodAndPrune (method.getReturnType (), method.getName (),
1617
- properties, std::move (paramList));
1614
+ genOpInterfaceMethod (method);
1618
1615
}
1619
1616
}
1620
1617
1618
+ OpMethod *OpEmitter::genOpInterfaceMethod (const InterfaceMethod &method,
1619
+ bool declaration) {
1620
+ SmallVector<OpMethodParameter, 4 > paramList;
1621
+ for (const InterfaceMethod::Argument &arg : method.getArguments ())
1622
+ paramList.emplace_back (arg.type , arg.name );
1623
+
1624
+ auto properties = method.isStatic () ? OpMethod::MP_Static : OpMethod::MP_None;
1625
+ if (declaration)
1626
+ properties =
1627
+ static_cast <OpMethod::Property>(properties | OpMethod::MP_Declaration);
1628
+ return opClass.addMethodAndPrune (method.getReturnType (), method.getName (),
1629
+ properties, std::move (paramList));
1630
+ }
1631
+
1621
1632
void OpEmitter::genOpInterfaceMethods () {
1622
1633
for (const auto &trait : op.getTraits ()) {
1623
1634
if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
1624
1635
if (opTrait->shouldDeclareMethods ())
1625
- genOpInterfaceMethod (opTrait);
1636
+ genOpInterfaceMethods (opTrait);
1626
1637
}
1627
1638
}
1628
1639
@@ -1727,18 +1738,20 @@ void OpEmitter::genSideEffectInterfaceMethods() {
1727
1738
void OpEmitter::genTypeInterfaceMethods () {
1728
1739
if (!op.allResultTypesKnown ())
1729
1740
return ;
1730
-
1731
- SmallVector<OpMethodParameter, 4 > paramList;
1732
- paramList.emplace_back (" ::mlir::MLIRContext *" , " context" );
1733
- paramList.emplace_back (" ::llvm::Optional<::mlir::Location>" , " location" );
1734
- paramList.emplace_back (" ::mlir::ValueRange" , " operands" );
1735
- paramList.emplace_back (" ::mlir::DictionaryAttr" , " attributes" );
1736
- paramList.emplace_back (" ::mlir::RegionRange" , " regions" );
1737
- paramList.emplace_back (" ::llvm::SmallVectorImpl<::mlir::Type>&" ,
1738
- " inferredReturnTypes" );
1739
- auto *method =
1740
- opClass.addMethodAndPrune (" ::mlir::LogicalResult" , " inferReturnTypes" ,
1741
- OpMethod::MP_Static, std::move (paramList));
1741
+ // Generate 'inferReturnTypes' method declaration using the interface method
1742
+ // declared in 'InferTypeOpInterface' op interface.
1743
+ const auto *trait = dyn_cast<InterfaceOpTrait>(
1744
+ op.getTrait (" ::mlir::InferTypeOpInterface::Trait" ));
1745
+ auto interface = trait->getOpInterface ();
1746
+ OpMethod *method = [&]() -> OpMethod * {
1747
+ for (const InterfaceMethod &interfaceMethod : interface.getMethods ()) {
1748
+ if (interfaceMethod.getName () == " inferReturnTypes" ) {
1749
+ return genOpInterfaceMethod (interfaceMethod, /* declaration=*/ false );
1750
+ }
1751
+ }
1752
+ assert (0 && " unable to find inferReturnTypes interface method" );
1753
+ return nullptr ;
1754
+ }();
1742
1755
auto &body = method->body ();
1743
1756
body << " inferredReturnTypes.resize(" << op.getNumResults () << " );\n " ;
1744
1757
0 commit comments