Skip to content

Commit 92a836d

Browse files
committed
[mlir] Attach InferTypeOpInterface on SameOperandsAndResultType operations when possible
This allows for inferring the result types of operations in certain situations by using the type of an operand. This commit allowed for automatically supporting type inference for many more operations with no additional effort, e.g. nearly all Arithmetic operations now support result type inferrence with no additional changes. Differential Revision: https://reviews.llvm.org/D124581
1 parent 1bd1eda commit 92a836d

File tree

13 files changed

+47
-30
lines changed

13 files changed

+47
-30
lines changed

mlir/examples/standalone/include/Standalone/StandaloneOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/IR/BuiltinTypes.h"
1313
#include "mlir/IR/Dialect.h"
1414
#include "mlir/IR/OpDefinition.h"
15+
#include "mlir/Interfaces/InferTypeOpInterface.h"
1516
#include "mlir/Interfaces/SideEffectInterfaces.h"
1617

1718
#define GET_OP_CLASSES

mlir/examples/standalone/lib/Standalone/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@ add_mlir_dialect_library(MLIRStandalone
1010

1111
LINK_LIBS PUBLIC
1212
MLIRIR
13+
MLIRInferTypeOpInterface
1314
)

mlir/include/mlir/Dialect/Math/IR/Math.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/IR/Dialect.h"
1414
#include "mlir/IR/OpDefinition.h"
1515
#include "mlir/IR/OpImplementation.h"
16+
#include "mlir/Interfaces/InferTypeOpInterface.h"
1617
#include "mlir/Interfaces/SideEffectInterfaces.h"
1718
#include "mlir/Interfaces/VectorInterfaces.h"
1819

mlir/include/mlir/Dialect/Quant/QuantOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/Dialect.h"
1616
#include "mlir/IR/OpDefinition.h"
1717
#include "mlir/IR/Types.h"
18+
#include "mlir/Interfaces/InferTypeOpInterface.h"
1819
#include "mlir/Interfaces/SideEffectInterfaces.h"
1920
#include "llvm/Support/MathExtras.h"
2021

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/OpDefinition.h"
1515
#include "mlir/IR/OpImplementation.h"
1616
#include "mlir/IR/TensorEncoding.h"
17+
#include "mlir/Interfaces/InferTypeOpInterface.h"
1718
#include "mlir/Interfaces/SideEffectInterfaces.h"
1819

1920
#define GET_ATTRDEF_CLASSES

mlir/lib/Dialect/Quant/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ add_mlir_dialect_library(MLIRQuant
1212

1313
LINK_LIBS PUBLIC
1414
MLIRIR
15+
MLIRInferTypeOpInterface
1516
MLIRSideEffectInterfaces
1617
MLIRSupport
1718
)

mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ add_mlir_dialect_library(MLIRSparseTensor
1111
LINK_LIBS PUBLIC
1212
MLIRDialect
1313
MLIRIR
14+
MLIRInferTypeOpInterface
1415
MLIRSupport
1516
)

mlir/lib/TableGen/Operator.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,25 @@ void Operator::populateTypeInferenceInfo(
333333

334334
// Skip cases currently being custom generated.
335335
// TODO: Remove special cases.
336-
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
336+
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
337+
// Check for a non-variable length operand to use as the type anchor.
338+
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
339+
NamedTypeConstraint *operand = arg.dyn_cast<NamedTypeConstraint *>();
340+
return operand && !operand->isVariableLength();
341+
});
342+
if (operandI == arguments.end())
343+
return;
344+
345+
// Map each of the result types to the anchor operation.
346+
int operandIdx = operandI - arguments.begin();
347+
resultTypeMapping.resize(getNumResults());
348+
for (int i = 0; i < getNumResults(); ++i)
349+
resultTypeMapping[i].emplace_back(operandIdx);
350+
351+
allResultsHaveKnownTypes = true;
352+
traits.push_back(Trait::create(inferTrait->getDefInit()));
337353
return;
354+
}
338355

339356
// We create equivalence classes of argument/result types where arguments
340357
// and results are mapped into the same index space and indices corresponding

mlir/test/Analysis/test-shape-fn-report.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ module attributes {shape.lib = [@shape_lib]} {
55
// expected-remark@+1 {{associated shape function: same_result_shape}}
66
func.func @tanh(%arg: tensor<10x20xf32>) -> tensor<10x20xf32>
77
attributes {shape.function = @shape_lib::@same_result_shape} {
8-
// expected-remark@+1 {{no associated way}}
8+
// expected-remark@+1 {{implements InferType op interface}}
99
%0 = math.tanh %arg : tensor<10x20xf32>
10-
// expected-remark@+1 {{associated shape function: same_result_shape}}
10+
// expected-remark@+1 {{implements InferType op interface}}
1111
%1 = "test.same_operand_result_type"(%0) : (tensor<10x20xf32>) -> tensor<10x20xf32>
1212
return %1 : tensor<10x20xf32>
1313
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2608,15 +2608,9 @@ class TableGenBuildInferReturnTypeBaseOp<string mnemonic,
26082608
}];
26092609
}
26102610

2611-
// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
2612-
// Tests suppression of ambiguous build methods for operations with
2613-
// SameOperandsAndResultType and InferTypeOpInterface.
2614-
def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
2615-
"tblgen_build_5", [SameOperandsAndResultType]>;
2616-
26172611
// Op with InferTypeOpInterface and regions.
2618-
def TableGenBuildOp6 : TableGenBuildInferReturnTypeBaseOp<
2619-
"tblgen_build_6", [InferTypeOpInterface]> {
2612+
def TableGenBuildOp5 : TableGenBuildInferReturnTypeBaseOp<
2613+
"tblgen_build_5", [InferTypeOpInterface]> {
26202614
let regions = (region AnyRegion:$body);
26212615
}
26222616

mlir/test/mlir-tblgen/op-decl-and-defs.td

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> {
199199
let results = (outs AnyType:$b);
200200
}
201201

202-
// CHECK_LABEL: class NS_HCollectiveParamsOp :
202+
// CHECK_LABEL: class HCollectiveParamsOp :
203203
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a);
204204
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a);
205205
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {})
@@ -212,7 +212,7 @@ def NS_HCollectiveParamsSuppress0Op : NS_Op<"op_collective_suppress0", []> {
212212
let results = (outs Variadic<I32>:$b);
213213
}
214214

215-
// CHECK_LABEL: class NS_HCollectiveParamsSuppress0Op :
215+
// CHECK_LABEL: class HCollectiveParamsSuppress0Op :
216216
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
217217
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
218218

@@ -224,7 +224,7 @@ def NS_HCollectiveParamsSuppress1Op : NS_Op<"op_collective_suppress1", []> {
224224
let results = (outs I32:$b);
225225
}
226226

227-
// CHECK_LABEL: class NS_HCollectiveParamsSuppress1Op :
227+
// CHECK_LABEL: class HCollectiveParamsSuppress1Op :
228228
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
229229
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
230230

@@ -237,7 +237,7 @@ def NS_HCollectiveParamsSuppress2Op : NS_Op<"op_collective_suppress2", [SameVari
237237
let arguments = (ins Variadic<I32>:$a);
238238
let results = (outs Variadic<I32>:$b, Variadic<F32>:$c);
239239
}
240-
// CHECK_LABEL: class NS_HCollectiveParamsSuppress2Op :
240+
// CHECK_LABEL: class HCollectiveParamsSuppress2Op :
241241
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::TypeRange c, ::mlir::ValueRange a);
242242
// CHECK-NOT: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange b, ::mlir::ValueRange a);
243243
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
@@ -247,19 +247,19 @@ def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperands
247247
let arguments = (ins AnyType:$a, AnyType:$b);
248248
let results = (outs AnyType:$r);
249249
}
250-
// CHECK_LABEL: class NS_IOp :
250+
// CHECK_LABEL: class IOp :
251251
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
252+
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
252253
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
253254
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
254-
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
255255
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
256256

257257
// Check default value of `attributes` for the `genInferredTypeCollectiveParamBuilder` builder
258258
def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
259259
let arguments = (ins AnyType:$a, AnyType:$b);
260260
let results = (outs AnyType:$r);
261261
}
262-
// CHECK_LABEL: class NS_JOp :
262+
// CHECK_LABEL: class JOp :
263263
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b);
264264
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b);
265265
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b);
@@ -292,14 +292,14 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam
292292
let arguments = (ins AnyType:$a, AnyType:$b, I32Attr:$attr1);
293293
let results = (outs AnyType:$r);
294294
}
295-
// CHECK_LABEL: class NS_LOp :
295+
// CHECK_LABEL: class LOp :
296296
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
297+
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
297298
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
298299
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
300+
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
299301
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
300302
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
301-
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, ::mlir::IntegerAttr attr1);
302-
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value a, ::mlir::Value b, uint32_t attr1);
303303
// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
304304

305305

mlir/test/mlir-tblgen/op-result.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ def OpB : NS_Op<"same_input_output_type_op", [SameOperandsAndResultType]> {
2727
// CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type y, ::mlir::Value x)
2828
// CHECK: odsState.addTypes(y);
2929
// CHECK: void OpB::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value x)
30-
// CHECK: odsState.addTypes({x.getType()});
30+
// CHECK: ::llvm::SmallVector<::mlir::Type, 2> inferredReturnTypes;
31+
// CHECK: if (::mlir::succeeded(OpB::inferReturnTypes(odsBuilder.getContext(),
32+
// CHECK: odsState.location, odsState.operands,
33+
// CHECK: odsState.attributes.getDictionary(odsState.getContext()),
34+
// CHECK: /*regions=*/{}, inferredReturnTypes)))
35+
// CHECK: odsState.addTypes(inferredReturnTypes);
3136

3237
def OpC : NS_Op<"three_normal_result_op", []> {
3338
let results = (outs I32:$x, /*unnamed*/I32, I32:$z);

mlir/unittests/TableGen/OpBuildGen.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ TEST_F(OpBuildGenTest,
204204
verifyOp(op, {i32Ty, f32Ty}, {*cstI32}, attrs);
205205
}
206206

207-
// The next 2 tests test supression of ambiguous build methods for ops that
207+
// The next test checks supression of ambiguous build methods for ops that
208208
// have a single variadic input, and single non-variadic result, and which
209209
// support the SameOperandsAndResultType trait and and optionally the
210210
// InferOpTypeInterface interface. For such ops, the ODS framework generates
@@ -213,14 +213,8 @@ TEST_F(OpBuildGenTest, BuildMethodsSameOperandsAndResultTypeSuppression) {
213213
testSingleVariadicInputInferredType<test::TableGenBuildOp4>();
214214
}
215215

216-
TEST_F(
217-
OpBuildGenTest,
218-
BuildMethodsSameOperandsAndResultTypeAndInferOpTypeInterfaceSuppression) {
219-
testSingleVariadicInputInferredType<test::TableGenBuildOp5>();
220-
}
221-
222216
TEST_F(OpBuildGenTest, BuildMethodsRegionsAndInferredType) {
223-
auto op = builder.create<test::TableGenBuildOp6>(
217+
auto op = builder.create<test::TableGenBuildOp5>(
224218
loc, ValueRange{*cstI32, *cstF32}, /*attributes=*/noAttrs);
225219
ASSERT_EQ(op->getNumRegions(), 1u);
226220
verifyOp(op, {i32Ty}, {*cstI32, *cstF32}, noAttrs);

0 commit comments

Comments
 (0)