Skip to content

Commit 83e7d34

Browse files
[mlir][ODS] Optionally generate public C++ functions for type constraints
Add `gen-type-constraint-decls` and `gen-type-constraint-defs`, which generate public C++ functions for type constraints. The name of the C++ function is specified in the `cppFunctionName` field. Type constraints are typically used for op/type/attribute verification. They are also sometimes called from builders and transformations. Until now, this required duplicating the check in C++. Note: This commit just adds the option for type constraints, but attribute constraints could be supported in the same way. Alternatives considered: 1. The C++ functions could also be generated as part of `gen-typedef-decls/defs`, but that can be confusing because type constraints may rely on type definitions from multiple `.td` files. 2. The C++ functions could also be generated as static member functions of dialects, but they don't really belong to a dialect. (Because they may rely on type definitions from multiple dialects.)
1 parent 846f790 commit 83e7d34

File tree

10 files changed

+118
-11
lines changed

10 files changed

+118
-11
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
198198
#include "mlir/IR/BuiltinTypes.h.inc"
199199

200200
namespace mlir {
201+
#include "mlir/IR/BuiltinTypeConstraints.h.inc"
201202

202203
//===----------------------------------------------------------------------===//
203204
// MemRefType

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,10 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
10971097
// VectorType
10981098
//===----------------------------------------------------------------------===//
10991099

1100+
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
1101+
let cppFunctionName = "isValidVectorTypeElementType";
1102+
}
1103+
11001104
def Builtin_Vector : Builtin_Type<"Vector", "vector",
11011105
[ShapedTypeInterface, ValueSemantics], "Type"> {
11021106
let summary = "Multi-dimensional SIMD vector type";
@@ -1147,7 +1151,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
11471151
}];
11481152
let parameters = (ins
11491153
ArrayRefParameter<"int64_t">:$shape,
1150-
AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
1154+
Builtin_VectorTypeElementType:$elementType,
11511155
ArrayRefParameter<"bool">:$scalableDims
11521156
);
11531157
let builders = [
@@ -1171,12 +1175,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
11711175
class Builder;
11721176

11731177
/// Returns true if the given type can be used as an element of a vector
1174-
/// type. In particular, vectors can consist of integer, index, or float
1175-
/// primitives.
1176-
static bool isValidElementType(Type t) {
1177-
// TODO: Auto-generate this function from $elementType.
1178-
return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
1179-
}
1178+
/// type. See "Builtin_VectorTypeElementType" for allowed types.
1179+
static bool isValidElementType(Type t);
11801180

11811181
/// Returns true if the vector contains scalable dimensions.
11821182
bool isScalable() const {

mlir/include/mlir/IR/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
3535
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
3636
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
3737
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
38+
mlir_tablegen(BuiltinTypeConstraints.h.inc -gen-type-constraint-decls)
39+
mlir_tablegen(BuiltinTypeConstraints.cpp.inc -gen-type-constraint-defs)
40+
add_public_tablegen_target(MLIRBuiltinTypeConstraintsIncGen)
3841

3942
set(LLVM_TARGET_DEFINITIONS BuiltinTypeInterfaces.td)
4043
mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)

mlir/include/mlir/IR/Constraints.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,14 @@ class Constraint<Pred pred, string desc = ""> {
149149

150150
// Subclass for constraints on a type.
151151
class TypeConstraint<Pred predicate, string summary = "",
152-
string cppTypeParam = "::mlir::Type"> :
152+
string cppTypeParam = "::mlir::Type",
153+
string cppFunctionNameParam = ""> :
153154
Constraint<predicate, summary> {
154155
// The name of the C++ Type class if known, or Type if not.
155156
string cppType = cppTypeParam;
157+
// The name of the C++ function that is generated for this type constraint.
158+
// If empty, no C++ function is generated.
159+
string cppFunctionName = cppFunctionNameParam;
156160
}
157161

158162
// Subclass for constraints on an attribute.

mlir/include/mlir/TableGen/Constraint.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ class Constraint {
6969
/// context on the def).
7070
std::string getUniqueDefName() const;
7171

72+
/// Returns the name of the C++ function that should be generated for this
73+
/// constraint, or std::nullopt if no C++ function should be generated.
74+
std::optional<StringRef> getCppFunctionName() const;
75+
7276
Kind getKind() const { return kind; }
7377

7478
/// Return the underlying def.

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ using namespace mlir::detail;
3232
#define GET_TYPEDEF_CLASSES
3333
#include "mlir/IR/BuiltinTypes.cpp.inc"
3434

35+
namespace mlir {
36+
#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
37+
} // namespace mlir
38+
3539
//===----------------------------------------------------------------------===//
3640
// BuiltinDialect
3741
//===----------------------------------------------------------------------===//
@@ -230,6 +234,10 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
230234
// VectorType
231235
//===----------------------------------------------------------------------===//
232236

237+
bool VectorType::isValidElementType(Type t) {
238+
return succeeded(isValidVectorTypeElementType(t));
239+
}
240+
233241
LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
234242
ArrayRef<int64_t> shape, Type elementType,
235243
ArrayRef<bool> scalableDims) {
@@ -278,7 +286,9 @@ Type TensorType::getElementType() const {
278286
[](auto type) { return type.getElementType(); });
279287
}
280288

281-
bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
289+
bool TensorType::hasRank() const {
290+
return !llvm::isa<UnrankedTensorType>(*this);
291+
}
282292

283293
ArrayRef<int64_t> TensorType::getShape() const {
284294
return llvm::cast<RankedTensorType>(*this).getShape();
@@ -365,7 +375,9 @@ Type BaseMemRefType::getElementType() const {
365375
[](auto type) { return type.getElementType(); });
366376
}
367377

368-
bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
378+
bool BaseMemRefType::hasRank() const {
379+
return !llvm::isa<UnrankedMemRefType>(*this);
380+
}
369381

370382
ArrayRef<int64_t> BaseMemRefType::getShape() const {
371383
return llvm::cast<MemRefType>(*this).getShape();

mlir/lib/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ add_mlir_library(MLIRIR
5555
MLIRBuiltinLocationAttributesIncGen
5656
MLIRBuiltinOpsIncGen
5757
MLIRBuiltinTypesIncGen
58+
MLIRBuiltinTypeConstraintsIncGen
5859
MLIRBuiltinTypeInterfacesIncGen
5960
MLIRCallInterfacesIncGen
6061
MLIRCastInterfacesIncGen

mlir/lib/TableGen/Constraint.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Constraint::Constraint(const llvm::Record *record)
3030
kind = CK_Region;
3131
} else if (def->isSubClassOf("SuccessorConstraint")) {
3232
kind = CK_Successor;
33-
} else if(!def->isSubClassOf("Constraint")) {
33+
} else if (!def->isSubClassOf("Constraint")) {
3434
llvm::errs() << "Expected a constraint but got: \n" << *def << "\n";
3535
llvm::report_fatal_error("Abort");
3636
}
@@ -109,6 +109,14 @@ std::optional<StringRef> Constraint::getBaseDefName() const {
109109
}
110110
}
111111

112+
std::optional<StringRef> Constraint::getCppFunctionName() const {
113+
std::optional<StringRef> name =
114+
def->getValueAsOptionalString("cppFunctionName");
115+
if (!name || *name == "")
116+
return std::nullopt;
117+
return name;
118+
}
119+
112120
AppliedConstraint::AppliedConstraint(Constraint &&constraint,
113121
llvm::StringRef self,
114122
std::vector<std::string> &&entities)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-tblgen -gen-type-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
2+
// RUN: mlir-tblgen -gen-type-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
3+
4+
include "mlir/IR/CommonTypeConstraints.td"
5+
6+
def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
7+
let cppFunctionName = "isValidDummy";
8+
}
9+
10+
// DECL: ::llvm::LogicalResult isValidDummy(::mlir::Type type);
11+
12+
// DEF: ::llvm::LogicalResult isValidDummy(::mlir::Type type) {
13+
// DEF: return ::llvm::success((((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type)))));
14+
// DEF: }

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,51 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
10231023
return false;
10241024
}
10251025

1026+
//===----------------------------------------------------------------------===//
1027+
// Type Constraints
1028+
//===----------------------------------------------------------------------===//
1029+
1030+
static const char *const typeConstraintDecl = R"(
1031+
::llvm::LogicalResult {0}(::mlir::Type type);
1032+
)";
1033+
1034+
static const char *const typeConstraintDef = R"(
1035+
::llvm::LogicalResult {0}(::mlir::Type type) {
1036+
return ::llvm::success(({1}));
1037+
}
1038+
)";
1039+
1040+
/// Find all type constraints for which a C++ function should be generated.
1041+
static std::vector<Constraint>
1042+
getAllTypeConstraints(const llvm::RecordKeeper &records) {
1043+
std::vector<Constraint> result;
1044+
for (llvm::Record *def :
1045+
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
1046+
Constraint constr(def);
1047+
// Generate C++ function only if "cppFunctionName" is set.
1048+
if (!constr.getCppFunctionName())
1049+
continue;
1050+
result.push_back(constr);
1051+
}
1052+
return result;
1053+
}
1054+
1055+
static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
1056+
raw_ostream &os) {
1057+
for (Constraint constr : getAllTypeConstraints(records))
1058+
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
1059+
}
1060+
1061+
static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
1062+
raw_ostream &os) {
1063+
for (Constraint constr : getAllTypeConstraints(records)) {
1064+
FmtContext ctx;
1065+
ctx.withSelf("type");
1066+
std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
1067+
os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
1068+
}
1069+
}
1070+
10261071
//===----------------------------------------------------------------------===//
10271072
// GEN: Registration hooks
10281073
//===----------------------------------------------------------------------===//
@@ -1070,3 +1115,18 @@ static mlir::GenRegistration
10701115
TypeDefGenerator generator(records, os);
10711116
return generator.emitDecls(typeDialect);
10721117
});
1118+
1119+
static mlir::GenRegistration
1120+
genTypeConstrDefs("gen-type-constraint-defs",
1121+
"Generate type constraint definitions",
1122+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1123+
emitTypeConstraintDefs(records, os);
1124+
return false;
1125+
});
1126+
static mlir::GenRegistration
1127+
genTypeConstrDecls("gen-type-constraint-decls",
1128+
"Generate type constraint declarations",
1129+
[](const llvm::RecordKeeper &records, raw_ostream &os) {
1130+
emitTypeConstraintDecls(records, os);
1131+
return false;
1132+
});

0 commit comments

Comments
 (0)