Skip to content

Commit 9fcd14d

Browse files
authored
[MLIR][ODS] Optionally generate public C++ functions for attribute constraints (#144275)
Add `gen-attr-constraint-decls` and `gen-attr-constraint-defs`, which generate public C++ functions for attribute constraints. The name of the C++ function is specified in the `cppFunctionName` field. This generalize `cppFunctionName` from `TypeConstraint` introduced in #104577 to be usable also in `AttrConstraint`.
1 parent 222ab28 commit 9fcd14d

File tree

4 files changed

+114
-33
lines changed

4 files changed

+114
-33
lines changed

mlir/docs/DefiningDialects/Constraints.md

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ code is generated for type/attribute constraints. Type constraints can not only
2424
be used when defining operation arguments, but also when defining type
2525
parameters.
2626

27-
Optionally, C++ functions can be generated, so that type constraints can be
28-
checked from C++. The name of the C++ function must be specified in the
27+
Optionally, C++ functions can be generated, so that type/attribute constraints
28+
can be checked from C++. The name of the C++ function must be specified in the
2929
`cppFunctionName` field. If no function name is specified, no C++ function is
3030
emitted.
3131

@@ -43,17 +43,20 @@ bool isValidVectorTypeElementType(::mlir::Type type) {
4343
}
4444
```
4545
46-
An extra TableGen rule is needed to emit C++ code for type constraints. This
47-
will generate only the declarations/definitions of the type constaraints that
48-
are defined in the specified `.td` file, but not those that are in included
49-
`.td` files.
46+
An extra TableGen rule is needed to emit C++ code for type/attribute
47+
constraints. This will generate only the declarations/definitions of the
48+
type/attribute constaraints that are defined in the specified `.td` file, but
49+
not those that are in included `.td` files.
5050
5151
```cmake
5252
mlir_tablegen(<Your Dialect>TypeConstraints.h.inc -gen-type-constraint-decls)
5353
mlir_tablegen(<Your Dialect>TypeConstraints.cpp.inc -gen-type-constraint-defs)
54+
mlir_tablegen(<Your Dialect>AttrConstraints.h.inc -gen-attr-constraint-decls)
55+
mlir_tablegen(<Your Dialect>AttrConstraints.cpp.inc -gen-attr-constraint-defs)
5456
```
5557

56-
The generated `<Your Dialect>TypeConstraints.h.inc` will need to be included
57-
whereever you are referencing the type constraint in C++. Note that no C++
58-
namespace will be emitted by the code generator. The `#include` statements of
59-
the `.h.inc`/`.cpp.inc` files should be wrapped in C++ namespaces by the user.
58+
The generated `<Your Dialect>TypeConstraints.h.inc` respectivelly
59+
`<Your Dialect>AttrConstraints.h.inc` will need to be included whereever you are
60+
referencing the type/attributes constraint in C++. Note that no C++ namespace
61+
will be emitted by the code generator. The `#include` statements of the
62+
`.h.inc`/`.cpp.inc` files should be wrapped in C++ namespaces by the user.

mlir/include/mlir/IR/Constraints.td

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,15 @@ class Constraint<Pred pred, string desc = ""> {
148148
string summary = desc;
149149
}
150150

151+
// Base class for constraints on types and attributes.
152+
class AttrTypeConstraint<Pred pred, string summary = "",
153+
string cppFunctionNameParam = ""> :
154+
Constraint<pred, summary> {
155+
// The name of the C++ function that is generated for this constraint.
156+
// If empty, no C++ function is generated.
157+
string cppFunctionName = cppFunctionNameParam;
158+
}
159+
151160
// Subclasses used to differentiate different constraint kinds. These are used
152161
// as markers for the TableGen backend to handle different constraint kinds
153162
// differently if needed. Constraints not deriving from the following subclasses
@@ -157,17 +166,15 @@ class Constraint<Pred pred, string desc = ""> {
157166
class TypeConstraint<Pred predicate, string summary = "",
158167
string cppTypeParam = "::mlir::Type",
159168
string cppFunctionNameParam = ""> :
160-
Constraint<predicate, summary> {
169+
AttrTypeConstraint<predicate, summary, cppFunctionNameParam> {
161170
// The name of the C++ Type class if known, or Type if not.
162171
string cppType = cppTypeParam;
163-
// The name of the C++ function that is generated for this type constraint.
164-
// If empty, no C++ function is generated.
165-
string cppFunctionName = cppFunctionNameParam;
166172
}
167173

168174
// Subclass for constraints on an attribute.
169-
class AttrConstraint<Pred predicate, string summary = ""> :
170-
Constraint<predicate, summary>;
175+
class AttrConstraint<Pred predicate, string summary = "",
176+
string cppFunctionNameParam = ""> :
177+
AttrTypeConstraint<predicate, summary, cppFunctionNameParam>;
171178

172179
// Subclass for constraints on a property.
173180
class PropConstraint<Pred predicate, string summary = "", string interfaceTypeParam = ""> :
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-tblgen -gen-attr-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
2+
// RUN: mlir-tblgen -gen-attr-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
3+
4+
include "mlir/IR/CommonAttrConstraints.td"
5+
6+
def DummyConstraint : AnyAttrOf<[APIntAttr, ArrayAttr, UnitAttr]> {
7+
let cppFunctionName = "isValidDummy";
8+
}
9+
10+
// DECL: bool isValidDummy(::mlir::Attribute attr);
11+
12+
// DEF: bool isValidDummy(::mlir::Attribute attr) {
13+
// DEF: return (((::llvm::isa<::mlir::IntegerAttr>(attr))) || ((::llvm::isa<::mlir::ArrayAttr>(attr))) || ((::llvm::isa<::mlir::UnitAttr>(attr))));
14+
// DEF: }

mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp

Lines changed: 74 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,15 +1083,15 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
10831083
}
10841084

10851085
//===----------------------------------------------------------------------===//
1086-
// Type Constraints
1086+
// Constraints
10871087
//===----------------------------------------------------------------------===//
10881088

10891089
/// Find all type constraints for which a C++ function should be generated.
1090-
static std::vector<Constraint>
1091-
getAllTypeConstraints(const RecordKeeper &records) {
1090+
static std::vector<Constraint> getAllCppConstraints(const RecordKeeper &records,
1091+
StringRef constraintKind) {
10921092
std::vector<Constraint> result;
10931093
for (const Record *def :
1094-
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
1094+
records.getAllDerivedDefinitionsIfDefined(constraintKind)) {
10951095
// Ignore constraints defined outside of the top-level file.
10961096
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
10971097
llvm::SrcMgr.getMainFileID())
@@ -1105,32 +1105,74 @@ getAllTypeConstraints(const RecordKeeper &records) {
11051105
return result;
11061106
}
11071107

1108+
static std::vector<Constraint>
1109+
getAllCppTypeConstraints(const RecordKeeper &records) {
1110+
return getAllCppConstraints(records, "TypeConstraint");
1111+
}
1112+
1113+
static std::vector<Constraint>
1114+
getAllCppAttrConstraints(const RecordKeeper &records) {
1115+
return getAllCppConstraints(records, "AttrConstraint");
1116+
}
1117+
1118+
/// Emit the declarations for the given constraints, of the form:
1119+
/// `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>);`
1120+
static void emitConstraintDecls(const std::vector<Constraint> &constraints,
1121+
raw_ostream &os, StringRef parameterTypeName,
1122+
StringRef parameterName) {
1123+
static const char *const constraintDecl = "bool {0}({1} {2});\n";
1124+
for (Constraint constr : constraints)
1125+
os << strfmt(constraintDecl, *constr.getCppFunctionName(),
1126+
parameterTypeName, parameterName);
1127+
}
1128+
11081129
static void emitTypeConstraintDecls(const RecordKeeper &records,
11091130
raw_ostream &os) {
1110-
static const char *const typeConstraintDecl = R"(
1111-
bool {0}(::mlir::Type type);
1112-
)";
1131+
emitConstraintDecls(getAllCppTypeConstraints(records), os, "::mlir::Type",
1132+
"type");
1133+
}
11131134

1114-
for (Constraint constr : getAllTypeConstraints(records))
1115-
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
1135+
static void emitAttrConstraintDecls(const RecordKeeper &records,
1136+
raw_ostream &os) {
1137+
emitConstraintDecls(getAllCppAttrConstraints(records), os,
1138+
"::mlir::Attribute", "attr");
11161139
}
11171140

1118-
static void emitTypeConstraintDefs(const RecordKeeper &records,
1119-
raw_ostream &os) {
1120-
static const char *const typeConstraintDef = R"(
1121-
bool {0}(::mlir::Type type) {
1122-
return ({1});
1141+
/// Emit the definitions for the given constraints, of the form:
1142+
/// `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>) {
1143+
/// return (<condition>); }`
1144+
/// where `<condition>` is the condition template with the `self` variable
1145+
/// replaced with the `selfName` parameter.
1146+
static void emitConstraintDefs(const std::vector<Constraint> &constraints,
1147+
raw_ostream &os, StringRef parameterTypeName,
1148+
StringRef selfName) {
1149+
static const char *const constraintDef = R"(
1150+
bool {0}({1} {2}) {
1151+
return ({3});
11231152
}
11241153
)";
11251154

1126-
for (Constraint constr : getAllTypeConstraints(records)) {
1155+
for (Constraint constr : constraints) {
11271156
FmtContext ctx;
1128-
ctx.withSelf("type");
1157+
ctx.withSelf(selfName);
11291158
std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
1130-
os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
1159+
os << strfmt(constraintDef, *constr.getCppFunctionName(), parameterTypeName,
1160+
selfName, condition);
11311161
}
11321162
}
11331163

1164+
static void emitTypeConstraintDefs(const RecordKeeper &records,
1165+
raw_ostream &os) {
1166+
emitConstraintDefs(getAllCppTypeConstraints(records), os, "::mlir::Type",
1167+
"type");
1168+
}
1169+
1170+
static void emitAttrConstraintDefs(const RecordKeeper &records,
1171+
raw_ostream &os) {
1172+
emitConstraintDefs(getAllCppAttrConstraints(records), os, "::mlir::Attribute",
1173+
"attr");
1174+
}
1175+
11341176
//===----------------------------------------------------------------------===//
11351177
// GEN: Registration hooks
11361178
//===----------------------------------------------------------------------===//
@@ -1158,6 +1200,21 @@ static mlir::GenRegistration
11581200
return generator.emitDecls(attrDialect);
11591201
});
11601202

1203+
static mlir::GenRegistration
1204+
genAttrConstrDefs("gen-attr-constraint-defs",
1205+
"Generate attribute constraint definitions",
1206+
[](const RecordKeeper &records, raw_ostream &os) {
1207+
emitAttrConstraintDefs(records, os);
1208+
return false;
1209+
});
1210+
static mlir::GenRegistration
1211+
genAttrConstrDecls("gen-attr-constraint-decls",
1212+
"Generate attribute constraint declarations",
1213+
[](const RecordKeeper &records, raw_ostream &os) {
1214+
emitAttrConstraintDecls(records, os);
1215+
return false;
1216+
});
1217+
11611218
//===----------------------------------------------------------------------===//
11621219
// TypeDef
11631220
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)