Skip to content

Commit 2452334

Browse files
committed
[MLIR] Generate inferReturnTypes declaration using InferTypeOpInterface trait.
- Instead of hardcoding the parameters and return types of 'inferReturnTypes', use the InferTypeOpInterface trait to generate the method declaration. - Fix InferTypeOfInterface to use fully qualified type for inferReturnTypes results. Differential Revision: https://reviews.llvm.org/D92585
1 parent 7f6f9f4 commit 2452334

File tree

2 files changed

+40
-27
lines changed

2 files changed

+40
-27
lines changed

mlir/include/mlir/Interfaces/InferTypeOpInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
3636
which an Operation would be created (e.g., as used in Operation::create)
3737
and the regions of the op.
3838
}],
39-
/*retTy=*/"LogicalResult",
39+
/*retTy=*/"::mlir::LogicalResult",
4040
/*methodName=*/"inferReturnTypes",
4141
/*args=*/(ins "::mlir::MLIRContext *":$context,
4242
"::llvm::Optional<::mlir::Location>":$location,

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,16 @@ class OpEmitter {
290290
// Generates the traits used by the object.
291291
void genTraits();
292292

293-
// Generate the OpInterface methods.
293+
// Generate the OpInterface methods for all interfaces.
294294
void genOpInterfaceMethods();
295295

296-
// Generate op interface method.
297-
void genOpInterfaceMethod(const tblgen::InterfaceOpTrait *trait);
296+
// Generate op interface methods for the given interface.
297+
void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
298+
299+
// Generate op interface method for the given interface method. If
300+
// 'declaration' is true, generates a declaration, else a definition.
301+
OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
302+
bool declaration = true);
298303

299304
// Generate the side effect interface methods.
300305
void genSideEffectInterfaceMethods();
@@ -1588,7 +1593,7 @@ void OpEmitter::genFolderDecls() {
15881593
}
15891594
}
15901595

1591-
void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
1596+
void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
15921597
auto interface = opTrait->getOpInterface();
15931598

15941599
// Get the set of methods that should always be declared.
@@ -1606,23 +1611,29 @@ void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
16061611
if (method.getDefaultImplementation() &&
16071612
!alwaysDeclaredMethods.count(method.getName()))
16081613
continue;
1609-
1610-
SmallVector<OpMethodParameter, 4> paramList;
1611-
for (const InterfaceMethod::Argument &arg : method.getArguments())
1612-
paramList.emplace_back(arg.type, arg.name);
1613-
1614-
auto properties = method.isStatic() ? OpMethod::MP_StaticDeclaration
1615-
: OpMethod::MP_Declaration;
1616-
opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
1617-
properties, std::move(paramList));
1614+
genOpInterfaceMethod(method);
16181615
}
16191616
}
16201617

1618+
OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
1619+
bool declaration) {
1620+
SmallVector<OpMethodParameter, 4> paramList;
1621+
for (const InterfaceMethod::Argument &arg : method.getArguments())
1622+
paramList.emplace_back(arg.type, arg.name);
1623+
1624+
auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
1625+
if (declaration)
1626+
properties =
1627+
static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
1628+
return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
1629+
properties, std::move(paramList));
1630+
}
1631+
16211632
void OpEmitter::genOpInterfaceMethods() {
16221633
for (const auto &trait : op.getTraits()) {
16231634
if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
16241635
if (opTrait->shouldDeclareMethods())
1625-
genOpInterfaceMethod(opTrait);
1636+
genOpInterfaceMethods(opTrait);
16261637
}
16271638
}
16281639

@@ -1727,18 +1738,20 @@ void OpEmitter::genSideEffectInterfaceMethods() {
17271738
void OpEmitter::genTypeInterfaceMethods() {
17281739
if (!op.allResultTypesKnown())
17291740
return;
1730-
1731-
SmallVector<OpMethodParameter, 4> paramList;
1732-
paramList.emplace_back("::mlir::MLIRContext *", "context");
1733-
paramList.emplace_back("::llvm::Optional<::mlir::Location>", "location");
1734-
paramList.emplace_back("::mlir::ValueRange", "operands");
1735-
paramList.emplace_back("::mlir::DictionaryAttr", "attributes");
1736-
paramList.emplace_back("::mlir::RegionRange", "regions");
1737-
paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::Type>&",
1738-
"inferredReturnTypes");
1739-
auto *method =
1740-
opClass.addMethodAndPrune("::mlir::LogicalResult", "inferReturnTypes",
1741-
OpMethod::MP_Static, std::move(paramList));
1741+
// Generate 'inferReturnTypes' method declaration using the interface method
1742+
// declared in 'InferTypeOpInterface' op interface.
1743+
const auto *trait = dyn_cast<InterfaceOpTrait>(
1744+
op.getTrait("::mlir::InferTypeOpInterface::Trait"));
1745+
auto interface = trait->getOpInterface();
1746+
OpMethod *method = [&]() -> OpMethod * {
1747+
for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
1748+
if (interfaceMethod.getName() == "inferReturnTypes") {
1749+
return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
1750+
}
1751+
}
1752+
assert(0 && "unable to find inferReturnTypes interface method");
1753+
return nullptr;
1754+
}();
17421755
auto &body = method->body();
17431756
body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";
17441757

0 commit comments

Comments
 (0)