Skip to content

Commit 1bd1eda

Browse files
committed
[mlir:ODS] Support using attributes in AllTypesMatch to automatically add InferTypeOpInterface
This allows for using attribute types in result type inference for use with InferTypeOpInterface. This was a TODO before, but it isn't much additional work to properly support this. After this commit, arith::ConstantOp can now have its InferTypeOpInterface implementation automatically generated. Differential Revision: https://reviews.llvm.org/D124580
1 parent 53f775b commit 1bd1eda

File tree

11 files changed

+125
-80
lines changed

11 files changed

+125
-80
lines changed

mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,7 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
124124
def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
125125
[ConstantLike, NoSideEffect,
126126
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
127-
TypesMatchWith<
128-
"result and attribute have the same type",
129-
"value", "result", "$_self">]> {
127+
AllTypesMatch<["value", "result"]>]> {
130128
let summary = "integer or floating point constant";
131129
let description = [{
132130
The `constant` operation produces an SSA value equal to some integer or
@@ -154,8 +152,6 @@ def Arith_ConstantOp : Op<Arithmetic_Dialect, "constant",
154152
let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result);
155153

156154
let builders = [
157-
OpBuilder<(ins "Attribute":$value),
158-
[{ build($_builder, $_state, value.getType(), value); }]>,
159155
OpBuilder<(ins "Attribute":$value, "Type":$type),
160156
[{ build($_builder, $_state, type, value); }]>,
161157
];

mlir/include/mlir/TableGen/CodeGenHelpers.h

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -187,19 +187,9 @@ class StaticVerifierFunctionEmitter {
187187
/// ensure that the static functions have a unique name.
188188
std::string uniqueOutputLabel;
189189

190-
/// Unique constraints by their predicate and summary. Constraints that share
191-
/// the same predicate may have different descriptions; ensure that the
192-
/// correct error message is reported when verification fails.
193-
struct ConstraintUniquer {
194-
static Constraint getEmptyKey();
195-
static Constraint getTombstoneKey();
196-
static unsigned getHashValue(Constraint constraint);
197-
static bool isEqual(Constraint lhs, Constraint rhs);
198-
};
199190
/// Use a MapVector to ensure that functions are generated deterministically.
200-
using ConstraintMap =
201-
llvm::MapVector<Constraint, std::string,
202-
llvm::DenseMap<Constraint, unsigned, ConstraintUniquer>>;
191+
using ConstraintMap = llvm::MapVector<Constraint, std::string,
192+
llvm::DenseMap<Constraint, unsigned>>;
203193

204194
/// A generic function to emit constraints
205195
void emitConstraints(const ConstraintMap &constraints, StringRef selfName,

mlir/include/mlir/TableGen/Constraint.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,20 @@ struct AppliedConstraint {
9494
} // namespace tblgen
9595
} // namespace mlir
9696

97+
namespace llvm {
98+
/// Unique constraints by their predicate and summary. Constraints that share
99+
/// the same predicate may have different descriptions; ensure that the
100+
/// correct error message is reported when verification fails.
101+
template <>
102+
struct DenseMapInfo<mlir::tblgen::Constraint> {
103+
using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
104+
105+
static mlir::tblgen::Constraint getEmptyKey();
106+
static mlir::tblgen::Constraint getTombstoneKey();
107+
static unsigned getHashValue(mlir::tblgen::Constraint constraint);
108+
static bool isEqual(mlir::tblgen::Constraint lhs,
109+
mlir::tblgen::Constraint rhs);
110+
};
111+
} // namespace llvm
112+
97113
#endif // MLIR_TABLEGEN_CONSTRAINT_H_

mlir/lib/TableGen/Constraint.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,34 @@ AppliedConstraint::AppliedConstraint(Constraint &&constraint,
108108
std::vector<std::string> &&entities)
109109
: constraint(constraint), self(std::string(self)),
110110
entities(std::move(entities)) {}
111+
112+
Constraint DenseMapInfo<Constraint>::getEmptyKey() {
113+
return Constraint(RecordDenseMapInfo::getEmptyKey(),
114+
Constraint::CK_Uncategorized);
115+
}
116+
117+
Constraint DenseMapInfo<Constraint>::getTombstoneKey() {
118+
return Constraint(RecordDenseMapInfo::getTombstoneKey(),
119+
Constraint::CK_Uncategorized);
120+
}
121+
122+
unsigned DenseMapInfo<Constraint>::getHashValue(Constraint constraint) {
123+
if (constraint == getEmptyKey())
124+
return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
125+
if (constraint == getTombstoneKey()) {
126+
return RecordDenseMapInfo::getHashValue(
127+
RecordDenseMapInfo::getTombstoneKey());
128+
}
129+
return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
130+
}
131+
132+
bool DenseMapInfo<Constraint>::isEqual(Constraint lhs, Constraint rhs) {
133+
if (lhs == rhs)
134+
return true;
135+
if (lhs == getEmptyKey() || lhs == getTombstoneKey())
136+
return false;
137+
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
138+
return false;
139+
return lhs.getPredicate() == rhs.getPredicate() &&
140+
lhs.getSummary() == rhs.getSummary();
141+
}

mlir/lib/TableGen/Operator.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,10 +357,6 @@ void Operator::populateTypeInferenceInfo(
357357
continue;
358358
}
359359

360-
if (getArg(*mi).is<NamedAttribute *>()) {
361-
// TODO: Handle attributes.
362-
continue;
363-
}
364360
resultTypeMapping[i].emplace_back(*mi);
365361
found = true;
366362
}

mlir/python/mlir/dialects/_arith_ops_ext.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ def __init__(self,
4141
loc=None,
4242
ip=None):
4343
if isinstance(value, int):
44-
super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip)
44+
super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
4545
elif isinstance(value, float):
46-
super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip)
46+
super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
4747
else:
48-
super().__init__(result, value, loc=loc, ip=ip)
48+
super().__init__(value, loc=loc, ip=ip)
4949

5050
@classmethod
5151
def create_index(cls, value: int, *, loc=None, ip=None):

mlir/test/Dialect/Arithmetic/invalid.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func.func @non_signless_constant() {
2525
// -----
2626

2727
func.func @complex_constant_wrong_attribute_type() {
28-
// expected-error @+1 {{'arith.constant' op failed to verify that result and attribute have the same type}}
28+
// expected-error @+1 {{'arith.constant' op failed to verify that all of {value, result} have same type}}
2929
%0 = "arith.constant" () {value = 1.0 : f32} : () -> complex<f32>
3030
return
3131
}
@@ -50,23 +50,23 @@ func.func @bitcast_different_bit_widths(%arg : f16) -> f32 {
5050

5151
func.func @constant() {
5252
^bb:
53-
%x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
53+
%x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
5454
return
5555
}
5656

5757
// -----
5858

5959
func.func @constant_out_of_range() {
6060
^bb:
61-
%x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
61+
%x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
6262
return
6363
}
6464

6565
// -----
6666

6767
func.func @constant_wrong_type() {
6868
^bb:
69-
%x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}}
69+
%x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}}
7070
return
7171
}
7272

mlir/test/IR/diagnostic-handler.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
// Emit the first available call stack in the fused location.
77
func.func @constant_out_of_range() {
8-
// CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that result and attribute have the same type
8+
// CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that all of {value, result} have same type
99
// CHECK-NEXT: mysource2:1:0: note: called from
1010
// CHECK-NEXT: mysource3:2:0: note: called from
1111
%x = "arith.constant"() {value = 100} : () -> i1 loc(fused["bar", callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0))])

mlir/test/mlir-tblgen/op-result.td

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
123123

124124
// CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
125125
// CHECK-NOT: }
126-
// CHECK: inferredReturnTypes[0] = operands[0].getType();
126+
// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
127+
// CHECK: inferredReturnTypes[0] = odsInferredType0;
127128

128129
def OpL2 : NS_Op<"op_with_all_types_constraint",
129130
[AllTypesMatch<["c", "b"]>, AllTypesMatch<["a", "d"]>]> {
@@ -133,5 +134,18 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
133134

134135
// CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
135136
// CHECK-NOT: }
136-
// CHECK: inferredReturnTypes[0] = operands[2].getType();
137-
// CHECK: inferredReturnTypes[1] = operands[0].getType();
137+
// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
138+
// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
139+
// CHECK: inferredReturnTypes[0] = odsInferredType0;
140+
// CHECK: inferredReturnTypes[1] = odsInferredType1;
141+
142+
def OpL3 : NS_Op<"op_with_all_types_constraint",
143+
[AllTypesMatch<["a", "b"]>]> {
144+
let arguments = (ins I32Attr:$a);
145+
let results = (outs AnyType:$b);
146+
}
147+
148+
// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes
149+
// CHECK-NOT: }
150+
// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType();
151+
// CHECK: inferredReturnTypes[0] = odsInferredType0;

mlir/tools/mlir-tblgen/CodeGenHelpers.cpp

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -234,41 +234,6 @@ void StaticVerifierFunctionEmitter::emitPatternConstraints() {
234234
//===----------------------------------------------------------------------===//
235235
// Constraint Uniquing
236236

237-
using RecordDenseMapInfo = llvm::DenseMapInfo<const llvm::Record *>;
238-
239-
Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() {
240-
return Constraint(RecordDenseMapInfo::getEmptyKey(),
241-
Constraint::CK_Uncategorized);
242-
}
243-
244-
Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() {
245-
return Constraint(RecordDenseMapInfo::getTombstoneKey(),
246-
Constraint::CK_Uncategorized);
247-
}
248-
249-
unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue(
250-
Constraint constraint) {
251-
if (constraint == getEmptyKey())
252-
return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey());
253-
if (constraint == getTombstoneKey()) {
254-
return RecordDenseMapInfo::getHashValue(
255-
RecordDenseMapInfo::getTombstoneKey());
256-
}
257-
return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary());
258-
}
259-
260-
bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs,
261-
Constraint rhs) {
262-
if (lhs == rhs)
263-
return true;
264-
if (lhs == getEmptyKey() || lhs == getTombstoneKey())
265-
return false;
266-
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
267-
return false;
268-
return lhs.getPredicate() == rhs.getPredicate() &&
269-
lhs.getSummary() == rhs.getSummary();
270-
}
271-
272237
/// An attribute constraint that references anything other than itself and the
273238
/// current op cannot be generically extracted into a function. Most
274239
/// prohibitive are operands and results, which require calls to

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2336,23 +2336,60 @@ void OpEmitter::genTypeInterfaceMethods() {
23362336
fctx.withBuilder("odsBuilder");
23372337
body << " ::mlir::Builder odsBuilder(context);\n";
23382338

2339-
auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & {
2340-
if (!type.isArg())
2341-
return body << tgfmt(*type.getType().getBuilderCall(), &fctx);
2342-
auto argIndex = type.getArg();
2343-
assert(!op.getArg(argIndex).is<NamedAttribute *>());
2339+
// Preprocess the result types and build all of the types used during
2340+
// inferrence. This limits the amount of duplicated work when a type is used
2341+
// to infer multiple others.
2342+
llvm::DenseMap<Constraint, int> constraintsTypes;
2343+
llvm::DenseMap<int, int> argumentsTypes;
2344+
int inferredTypeIdx = 0;
2345+
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2346+
auto type = op.getSameTypeAsResult(i).front();
2347+
2348+
// If the type isn't an argument, it refers to a buildable type.
2349+
if (!type.isArg()) {
2350+
auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx);
2351+
if (!it.second)
2352+
continue;
2353+
2354+
// If we haven't seen this constraint, generate a variable for it.
2355+
body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
2356+
<< tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n";
2357+
continue;
2358+
}
2359+
2360+
// Otherwise, this is an argument.
2361+
int argIndex = type.getArg();
2362+
auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx);
2363+
if (!it.second)
2364+
continue;
2365+
body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = ";
2366+
2367+
// If this is an operand, just index into operand list to access the type.
23442368
auto arg = op.getArgToOperandOrAttribute(argIndex);
2345-
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand)
2346-
return body << "operands[" << arg.operandOrAttributeIndex()
2347-
<< "].getType()";
2348-
return body << "attributes[" << arg.operandOrAttributeIndex()
2349-
<< "].getType()";
2350-
};
2369+
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
2370+
body << "operands[" << arg.operandOrAttributeIndex() << "].getType()";
2371+
2372+
// If this is an attribute, index into the attribute dictionary.
2373+
} else {
2374+
auto *attr =
2375+
op.getArg(arg.operandOrAttributeIndex()).get<NamedAttribute *>();
2376+
body << "attributes.get(\"" << attr->name << "\").getType()";
2377+
}
2378+
body << ";\n";
2379+
}
23512380

2381+
// Perform a second pass that handles assigning the inferred types to the
2382+
// results.
23522383
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
2353-
body << " inferredReturnTypes[" << i << "] = ";
23542384
auto types = op.getSameTypeAsResult(i);
2355-
emitType(types[0]) << ";\n";
2385+
2386+
// Append the inferred type.
2387+
auto type = types.front();
2388+
body << " inferredReturnTypes[" << i << "] = odsInferredType"
2389+
<< (type.isArg() ? argumentsTypes[type.getArg()]
2390+
: constraintsTypes[type.getType()])
2391+
<< ";\n";
2392+
23562393
if (types.size() == 1)
23572394
continue;
23582395
// TODO: We could verify equality here, but skipping that for verification.

0 commit comments

Comments
 (0)