Skip to content

Commit 398f04a

Browse files
jpienaartensorflower-gardener
authored andcommitted
Generate builder for ops that use InferTypeOpInterface trait in ODS
For ops with infer type op interface defined, generate version that calls the inferal method on build. This is intermediate step to removing special casing of SameOperandsAndResultType & FirstAttrDereivedResultType. After that would be generating the inference code, with the initial focus on shaped container types. In between I plan to refactor these a bit to reuse generated paths. The intention would not be to add the type inference trait in multiple places, but rather to take advantage of the current modelling in ODS where possible to emit it instead. Switch the `inferReturnTypes` method to be static. Skipping ops with regions here as I don't like the Region vs unique_ptr<Region> difference at the moment, and I want the infer return type trait to be useful for verification too. So instead, just skip it for now to avoid churn. PiperOrigin-RevId: 284217913
1 parent e216a72 commit 398f04a

File tree

4 files changed

+66
-16
lines changed

4 files changed

+66
-16
lines changed

mlir/include/mlir/Analysis/InferTypeOpInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
3737
}];
3838

3939
let methods = [
40-
InterfaceMethod<
40+
StaticInterfaceMethod<
4141
/*desc=*/[{Returns the return types that an op would generate.
4242

4343
The method takes an optional location which, if set, will be used to

mlir/test/lib/TestDialect/TestPatterns.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,28 @@ struct ReturnTypeOpMatch : public RewritePattern {
7373
PatternMatchResult matchAndRewrite(Operation *op,
7474
PatternRewriter &rewriter) const final {
7575
if (auto retTypeFn = dyn_cast<InferTypeOpInterface>(op)) {
76-
SmallVector<Value *, 4> values;
77-
values.reserve(op->getNumOperands());
78-
for (auto &operand : op->getOpOperands())
79-
values.push_back(operand.get());
76+
SmallVector<Value *, 4> values(op->getOperands());
8077
auto res = retTypeFn.inferReturnTypes(op->getLoc(), values,
8178
op->getAttrs(), op->getRegions());
8279
SmallVector<Type, 1> result_types(op->getResultTypes());
8380
if (!retTypeFn.isCompatibleReturnTypes(res, result_types))
8481
return op->emitOpError(
8582
"inferred type incompatible with return type of operation"),
8683
matchFailure();
84+
85+
// TODO(jpienaar): Split this out to make the test more focused.
86+
// Create new op with unknown location to verify building with
87+
// InferTypeOpInterface is triggered.
88+
auto fop = op->getParentOfType<FuncOp>();
89+
if (values[0] == fop.getArgument(0)) {
90+
// Use the 2nd function argument if the first function argument is used
91+
// when constructing the new op so that a new return type is inferred.
92+
values[0] = fop.getArgument(1);
93+
values[1] = fop.getArgument(1);
94+
// TODO(jpienaar): Expand to regions.
95+
rewriter.create<OpWithInferTypeInterfaceOp>(
96+
UnknownLoc::get(op->getContext()), values, op->getAttrs());
97+
}
8798
}
8899
return matchFailure();
89100
}

mlir/test/mlir-tblgen/return-types.mlir

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
// RUN: mlir-opt %s -test-return-type -split-input-file -verify-diagnostics | FileCheck %s --dump-input-on-failure
22

33
// CHECK-LABEL: testReturnTypeOpInterface
4-
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
4+
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>, %arg1 : tensor<20xi32>) {
55
%good = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
6+
// CHECK: test.op_with_infer_type_if
7+
// CHECK-SAME: tensor<20xi32>
8+
// CHECK: test.op_with_infer_type_if
9+
// CHECK-SAME: tensor<10xf32>
10+
return
11+
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: testReturnTypeOpInterface
16+
func @testReturnTypeOpInterface(%arg0 : tensor<10xf32>) {
617
// expected-error@+1 {{incompatible with return type}}
718
%bad = "test.op_with_infer_type_if"(%arg0, %arg0) : (tensor<10xf32>, tensor<10xf32>) -> tensor<*xf32>
819
return

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,11 @@ class OpEmitter {
541541
// operand's type as all results' types.
542542
void genUseOperandAsResultTypeCollectiveParamBuilder();
543543

544+
// Generates the build() method that takes aggregate operands/attributes
545+
// parameters. This build() method uses inferred types as result types.
546+
// Requires: The type needs to be inferable via InferTypeOpInterface.
547+
void genInferedTypeCollectiveParamBuilder();
548+
544549
// Generates the build() method that takes each operand/attribute as a
545550
// stand-alone parameter. The generated build() method uses first attribute's
546551
// type as all result's types.
@@ -968,11 +973,6 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
968973
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
969974
auto &body = m.body();
970975

971-
// Result types
972-
SmallVector<std::string, 2> resultTypes(numResults, "operands[0]->getType()");
973-
body << " " << builderOpState << ".addTypes({"
974-
<< llvm::join(resultTypes, ", ") << "});\n\n";
975-
976976
// Operands
977977
body << " " << builderOpState << ".addOperands(operands);\n\n";
978978

@@ -984,6 +984,27 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
984984
for (int i = 0; i < numRegions; ++i)
985985
m.body() << " (void)" << builderOpState << ".addRegion();\n";
986986
}
987+
988+
// Result types
989+
SmallVector<std::string, 2> resultTypes(numResults, "operands[0]->getType()");
990+
body << " " << builderOpState << ".addTypes({"
991+
<< llvm::join(resultTypes, ", ") << "});\n\n";
992+
}
993+
994+
void OpEmitter::genInferedTypeCollectiveParamBuilder() {
995+
// TODO(jpienaar): Expand to support regions.
996+
std::string params =
997+
(Twine("Builder *, OperationState &") + builderOpState +
998+
", ArrayRef<Value *> operands, ArrayRef<NamedAttribute> attributes")
999+
.str();
1000+
auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static);
1001+
auto &body = m.body();
1002+
1003+
body << " " << builderOpState << ".addOperands(operands);\n\n";
1004+
body << " " << builderOpState << ".addAttributes(attributes);\n";
1005+
body << " " << builderOpState << ".addTypes(" << opClass.getClassName()
1006+
<< "::inferReturnTypes(" << builderOpState
1007+
<< ".location, operands, attributes, /*regions=*/{}));\n";
9871008
}
9881009

9891010
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
@@ -1026,15 +1047,17 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() {
10261047
} else {
10271048
resultType = "attr.second.getType()";
10281049
}
1029-
SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1030-
body << " " << builderOpState << ".addTypes({"
1031-
<< llvm::join(resultTypes, ", ") << "});\n";
1032-
body << " }\n";
10331050

10341051
// Operands
10351052
body << " " << builderOpState << ".addOperands(operands);\n\n";
10361053
// Attributes
10371054
body << " " << builderOpState << ".addAttributes(attributes);\n";
1055+
1056+
// Result types
1057+
SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType);
1058+
body << " " << builderOpState << ".addTypes({"
1059+
<< llvm::join(resultTypes, ", ") << "});\n";
1060+
body << " }\n";
10381061
}
10391062

10401063
void OpEmitter::genBuilder() {
@@ -1082,7 +1105,7 @@ void OpEmitter::genBuilder() {
10821105
genCollectiveParamBuilder();
10831106
// 4. one having a stand-alone parameter for each operand and attribute,
10841107
// use the first operand or attribute's type as all result types
1085-
// to facilitate different call patterns.
1108+
// to facilitate different call patterns.
10861109
if (op.getNumVariadicResults() == 0) {
10871110
if (op.getTrait("OpTrait::SameOperandsAndResultType")) {
10881111
genUseOperandAsResultTypeSeparateParamBuilder();
@@ -1091,6 +1114,11 @@ void OpEmitter::genBuilder() {
10911114
if (op.getTrait("OpTrait::FirstAttrDerivedResultType"))
10921115
genUseAttrAsResultTypeBuilder();
10931116
}
1117+
// TODO(jpienaar): Subsume this with general checking if type can be infered
1118+
// automatically.
1119+
// TODO(jpienaar): Expand to handle regions.
1120+
if (op.getTrait("InferTypeOpInterface::Trait") && op.getNumRegions() == 0)
1121+
genInferedTypeCollectiveParamBuilder();
10941122
}
10951123

10961124
void OpEmitter::genCollectiveParamBuilder() {

0 commit comments

Comments
 (0)