Skip to content

Commit db9df43

Browse files
committed
[mlir-tblgen] Avoid ODS verifier duplication
Different constraints may share the same predicate, in this case, we will generate duplicate ODS verification function. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D104369
1 parent a15adbc commit db9df43

File tree

3 files changed

+63
-10
lines changed

3 files changed

+63
-10
lines changed

mlir/include/mlir/TableGen/Predicate.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_TABLEGEN_PREDICATE_H_
1515

1616
#include "mlir/Support/LLVM.h"
17+
#include "llvm/ADT/Hashing.h"
1718

1819
#include <string>
1920
#include <vector>
@@ -59,6 +60,8 @@ class Pred {
5960
ArrayRef<llvm::SMLoc> getLoc() const;
6061

6162
protected:
63+
friend llvm::DenseMapInfo<Pred>;
64+
6265
// The TableGen definition of this predicate.
6366
const llvm::Record *def;
6467
};
@@ -116,4 +119,18 @@ class ConcatPred : public CombinedPred {
116119
} // end namespace tblgen
117120
} // end namespace mlir
118121

122+
namespace llvm {
123+
template <>
124+
struct DenseMapInfo<mlir::tblgen::Pred> {
125+
static mlir::tblgen::Pred getEmptyKey() { return mlir::tblgen::Pred(); }
126+
static mlir::tblgen::Pred getTombstoneKey() { return mlir::tblgen::Pred(); }
127+
static unsigned getHashValue(mlir::tblgen::Pred pred) {
128+
return llvm::hash_value(pred.def);
129+
}
130+
static bool isEqual(mlir::tblgen::Pred lhs, mlir::tblgen::Pred rhs) {
131+
return lhs == rhs;
132+
}
133+
};
134+
} // end namespace llvm
135+
119136
#endif // MLIR_TABLEGEN_PREDICATE_H_

mlir/test/mlir-tblgen/predicate.td

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,24 @@ def I32OrF32 : Type<CPred<"$_self.isInteger(32) || $_self.isF32()">,
1313

1414
def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
1515
let arguments = (ins I32OrF32:$x);
16+
let results = (outs Variadic<I32OrF32>:$y);
1617
}
1718

1819
// CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
19-
// CHECK: if (!((type.isInteger(32) || type.isF32()))) {
20-
// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
20+
// CHECK-NEXT: if (!((type.isInteger(32) || type.isF32()))) {
21+
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
22+
23+
// Check there is no verifier with same predicate generated.
24+
// CHECK-NOT: if (!((type.isInteger(32) || type.isF32()))) {
25+
// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
2126

2227
// CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
23-
// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
24-
// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
28+
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
29+
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
2530

2631
// CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
27-
// CHECK: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
28-
// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
32+
// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
33+
// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
2934

3035
// CHECK-LABEL: OpA::verify
3136
// CHECK: auto valueGroup0 = getODSOperands(0);

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -216,19 +216,50 @@ void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
216216
typeConstraints.insert(result.constraint.getAsOpaquePointer());
217217
}
218218

219+
// Record the mapping from predicate to constraint. If two constraints has the
220+
// same predicate and constraint summary, they can share the same verification
221+
// function.
222+
llvm::DenseMap<Pred, const void *> predToConstraint;
219223
FmtContext fctx;
220224
for (auto it : llvm::enumerate(typeConstraints)) {
225+
std::string name;
226+
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
227+
Pred pred = constraint.getPredicate();
228+
auto iter = predToConstraint.find(pred);
229+
if (iter != predToConstraint.end()) {
230+
do {
231+
Constraint built = Constraint::getFromOpaquePointer(iter->second);
232+
// We may have the different constraints but have the same predicate,
233+
// for example, ConstraintA and Variadic<ConstraintA>, note that
234+
// Variadic<> doesn't introduce new predicate. In this case, we can
235+
// share the same predicate function if they also have consistent
236+
// summary, otherwise we may report the wrong message while verification
237+
// fails.
238+
if (constraint.getSummary() == built.getSummary()) {
239+
name = getTypeConstraintFn(built).str();
240+
break;
241+
}
242+
++iter;
243+
} while (iter != predToConstraint.end() && iter->first == pred);
244+
}
245+
246+
if (!name.empty()) {
247+
localTypeConstraints.try_emplace(it.value(), name);
248+
continue;
249+
}
250+
221251
// Generate an obscure and unique name for this type constraint.
222-
std::string name = (Twine("__mlir_ods_local_type_constraint_") +
223-
uniqueOutputLabel + Twine(it.index()))
224-
.str();
252+
name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
253+
Twine(it.index()))
254+
.str();
255+
predToConstraint.insert(
256+
std::make_pair(constraint.getPredicate(), it.value()));
225257
localTypeConstraints.try_emplace(it.value(), name);
226258

227259
// Only generate the methods if we are generating definitions.
228260
if (emitDecl)
229261
continue;
230262

231-
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
232263
os << "static ::mlir::LogicalResult " << name
233264
<< "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
234265
"valueKind, unsigned valueGroupStartIndex) {\n";

0 commit comments

Comments
 (0)