@@ -1083,15 +1083,15 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
1083
1083
}
1084
1084
1085
1085
// ===----------------------------------------------------------------------===//
1086
- // Type Constraints
1086
+ // Constraints
1087
1087
// ===----------------------------------------------------------------------===//
1088
1088
1089
1089
// / Find all type constraints for which a C++ function should be generated.
1090
- static std::vector<Constraint>
1091
- getAllTypeConstraints ( const RecordKeeper &records ) {
1090
+ static std::vector<Constraint> getAllCppConstraints ( const RecordKeeper &records,
1091
+ StringRef constraintKind ) {
1092
1092
std::vector<Constraint> result;
1093
1093
for (const Record *def :
1094
- records.getAllDerivedDefinitionsIfDefined (" TypeConstraint " )) {
1094
+ records.getAllDerivedDefinitionsIfDefined (constraintKind )) {
1095
1095
// Ignore constraints defined outside of the top-level file.
1096
1096
if (llvm::SrcMgr.FindBufferContainingLoc (def->getLoc ()[0 ]) !=
1097
1097
llvm::SrcMgr.getMainFileID ())
@@ -1105,32 +1105,74 @@ getAllTypeConstraints(const RecordKeeper &records) {
1105
1105
return result;
1106
1106
}
1107
1107
1108
+ static std::vector<Constraint>
1109
+ getAllCppTypeConstraints (const RecordKeeper &records) {
1110
+ return getAllCppConstraints (records, " TypeConstraint" );
1111
+ }
1112
+
1113
+ static std::vector<Constraint>
1114
+ getAllCppAttrConstraints (const RecordKeeper &records) {
1115
+ return getAllCppConstraints (records, " AttrConstraint" );
1116
+ }
1117
+
1118
+ // / Emit the declarations for the given constraints, of the form:
1119
+ // / `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>);`
1120
+ static void emitConstraintDecls (const std::vector<Constraint> &constraints,
1121
+ raw_ostream &os, StringRef parameterTypeName,
1122
+ StringRef parameterName) {
1123
+ static const char *const constraintDecl = " bool {0}({1} {2});\n " ;
1124
+ for (Constraint constr : constraints)
1125
+ os << strfmt (constraintDecl, *constr.getCppFunctionName (),
1126
+ parameterTypeName, parameterName);
1127
+ }
1128
+
1108
1129
static void emitTypeConstraintDecls (const RecordKeeper &records,
1109
1130
raw_ostream &os) {
1110
- static const char * const typeConstraintDecl = R"(
1111
- bool {0}(::mlir::Type type);
1112
- )" ;
1131
+ emitConstraintDecls ( getAllCppTypeConstraints (records), os, " ::mlir::Type " ,
1132
+ " type" );
1133
+ }
1113
1134
1114
- for (Constraint constr : getAllTypeConstraints (records))
1115
- os << strfmt (typeConstraintDecl, *constr.getCppFunctionName ());
1135
+ static void emitAttrConstraintDecls (const RecordKeeper &records,
1136
+ raw_ostream &os) {
1137
+ emitConstraintDecls (getAllCppAttrConstraints (records), os,
1138
+ " ::mlir::Attribute" , " attr" );
1116
1139
}
1117
1140
1118
- static void emitTypeConstraintDefs (const RecordKeeper &records,
1119
- raw_ostream &os) {
1120
- static const char *const typeConstraintDef = R"(
1121
- bool {0}(::mlir::Type type) {
1122
- return ({1});
1141
+ // / Emit the definitions for the given constraints, of the form:
1142
+ // / `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>) {
1143
+ // / return (<condition>); }`
1144
+ // / where `<condition>` is the condition template with the `self` variable
1145
+ // / replaced with the `selfName` parameter.
1146
+ static void emitConstraintDefs (const std::vector<Constraint> &constraints,
1147
+ raw_ostream &os, StringRef parameterTypeName,
1148
+ StringRef selfName) {
1149
+ static const char *const constraintDef = R"(
1150
+ bool {0}({1} {2}) {
1151
+ return ({3});
1123
1152
}
1124
1153
)" ;
1125
1154
1126
- for (Constraint constr : getAllTypeConstraints (records) ) {
1155
+ for (Constraint constr : constraints ) {
1127
1156
FmtContext ctx;
1128
- ctx.withSelf (" type " );
1157
+ ctx.withSelf (selfName );
1129
1158
std::string condition = tgfmt (constr.getConditionTemplate (), &ctx);
1130
- os << strfmt (typeConstraintDef, *constr.getCppFunctionName (), condition);
1159
+ os << strfmt (constraintDef, *constr.getCppFunctionName (), parameterTypeName,
1160
+ selfName, condition);
1131
1161
}
1132
1162
}
1133
1163
1164
+ static void emitTypeConstraintDefs (const RecordKeeper &records,
1165
+ raw_ostream &os) {
1166
+ emitConstraintDefs (getAllCppTypeConstraints (records), os, " ::mlir::Type" ,
1167
+ " type" );
1168
+ }
1169
+
1170
+ static void emitAttrConstraintDefs (const RecordKeeper &records,
1171
+ raw_ostream &os) {
1172
+ emitConstraintDefs (getAllCppAttrConstraints (records), os, " ::mlir::Attribute" ,
1173
+ " attr" );
1174
+ }
1175
+
1134
1176
// ===----------------------------------------------------------------------===//
1135
1177
// GEN: Registration hooks
1136
1178
// ===----------------------------------------------------------------------===//
@@ -1158,6 +1200,21 @@ static mlir::GenRegistration
1158
1200
return generator.emitDecls (attrDialect);
1159
1201
});
1160
1202
1203
+ static mlir::GenRegistration
1204
+ genAttrConstrDefs (" gen-attr-constraint-defs" ,
1205
+ " Generate attribute constraint definitions" ,
1206
+ [](const RecordKeeper &records, raw_ostream &os) {
1207
+ emitAttrConstraintDefs (records, os);
1208
+ return false ;
1209
+ });
1210
+ static mlir::GenRegistration
1211
+ genAttrConstrDecls (" gen-attr-constraint-decls" ,
1212
+ " Generate attribute constraint declarations" ,
1213
+ [](const RecordKeeper &records, raw_ostream &os) {
1214
+ emitAttrConstraintDecls (records, os);
1215
+ return false ;
1216
+ });
1217
+
1161
1218
// ===----------------------------------------------------------------------===//
1162
1219
// TypeDef
1163
1220
// ===----------------------------------------------------------------------===//
0 commit comments