Skip to content

[MLIR][ODS] Optionally generate public C++ functions for attribute constraints #144275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions mlir/docs/DefiningDialects/Constraints.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ code is generated for type/attribute constraints. Type constraints can not only
be used when defining operation arguments, but also when defining type
parameters.

Optionally, C++ functions can be generated, so that type constraints can be
checked from C++. The name of the C++ function must be specified in the
Optionally, C++ functions can be generated, so that type/attribute constraints
can be checked from C++. The name of the C++ function must be specified in the
`cppFunctionName` field. If no function name is specified, no C++ function is
emitted.

Expand All @@ -43,17 +43,20 @@ bool isValidVectorTypeElementType(::mlir::Type type) {
}
```

An extra TableGen rule is needed to emit C++ code for type constraints. This
will generate only the declarations/definitions of the type constaraints that
are defined in the specified `.td` file, but not those that are in included
`.td` files.
An extra TableGen rule is needed to emit C++ code for type/attribute
constraints. This will generate only the declarations/definitions of the
type/attribute constaraints that are defined in the specified `.td` file, but
not those that are in included `.td` files.

```cmake
mlir_tablegen(<Your Dialect>TypeConstraints.h.inc -gen-type-constraint-decls)
mlir_tablegen(<Your Dialect>TypeConstraints.cpp.inc -gen-type-constraint-defs)
mlir_tablegen(<Your Dialect>AttrConstraints.h.inc -gen-attr-constraint-decls)
mlir_tablegen(<Your Dialect>AttrConstraints.cpp.inc -gen-attr-constraint-defs)
```

The generated `<Your Dialect>TypeConstraints.h.inc` will need to be included
whereever you are referencing the type constraint in C++. Note that no C++
namespace will be emitted by the code generator. The `#include` statements of
the `.h.inc`/`.cpp.inc` files should be wrapped in C++ namespaces by the user.
The generated `<Your Dialect>TypeConstraints.h.inc` respectivelly
`<Your Dialect>AttrConstraints.h.inc` will need to be included whereever you are
referencing the type/attributes constraint in C++. Note that no C++ namespace
will be emitted by the code generator. The `#include` statements of the
`.h.inc`/`.cpp.inc` files should be wrapped in C++ namespaces by the user.
19 changes: 13 additions & 6 deletions mlir/include/mlir/IR/Constraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ class Constraint<Pred pred, string desc = ""> {
string summary = desc;
}

// Base class for constraints on types and attributes.
class AttrTypeConstraint<Pred pred, string summary = "",
string cppFunctionNameParam = ""> :
Constraint<pred, summary> {
// The name of the C++ function that is generated for this constraint.
// If empty, no C++ function is generated.
string cppFunctionName = cppFunctionNameParam;
}

// Subclasses used to differentiate different constraint kinds. These are used
// as markers for the TableGen backend to handle different constraint kinds
// differently if needed. Constraints not deriving from the following subclasses
Expand All @@ -157,17 +166,15 @@ class Constraint<Pred pred, string desc = ""> {
class TypeConstraint<Pred predicate, string summary = "",
string cppTypeParam = "::mlir::Type",
string cppFunctionNameParam = ""> :
Constraint<predicate, summary> {
AttrTypeConstraint<predicate, summary, cppFunctionNameParam> {
// The name of the C++ Type class if known, or Type if not.
string cppType = cppTypeParam;
// The name of the C++ function that is generated for this type constraint.
// If empty, no C++ function is generated.
string cppFunctionName = cppFunctionNameParam;
}

// Subclass for constraints on an attribute.
class AttrConstraint<Pred predicate, string summary = ""> :
Constraint<predicate, summary>;
class AttrConstraint<Pred predicate, string summary = "",
string cppFunctionNameParam = ""> :
AttrTypeConstraint<predicate, summary, cppFunctionNameParam>;

// Subclass for constraints on a property.
class PropConstraint<Pred predicate, string summary = "", string interfaceTypeParam = ""> :
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/mlir-tblgen/attr-constraints.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-tblgen -gen-attr-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-attr-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF

include "mlir/IR/CommonAttrConstraints.td"

def DummyConstraint : AnyAttrOf<[APIntAttr, ArrayAttr, UnitAttr]> {
let cppFunctionName = "isValidDummy";
}

// DECL: bool isValidDummy(::mlir::Attribute attr);

// DEF: bool isValidDummy(::mlir::Attribute attr) {
// DEF: return (((::llvm::isa<::mlir::IntegerAttr>(attr))) || ((::llvm::isa<::mlir::ArrayAttr>(attr))) || ((::llvm::isa<::mlir::UnitAttr>(attr))));
// DEF: }
91 changes: 74 additions & 17 deletions mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1083,15 +1083,15 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
}

//===----------------------------------------------------------------------===//
// Type Constraints
// Constraints
//===----------------------------------------------------------------------===//

/// Find all type constraints for which a C++ function should be generated.
static std::vector<Constraint>
getAllTypeConstraints(const RecordKeeper &records) {
static std::vector<Constraint> getAllCppConstraints(const RecordKeeper &records,
StringRef constraintKind) {
std::vector<Constraint> result;
for (const Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
records.getAllDerivedDefinitionsIfDefined(constraintKind)) {
// Ignore constraints defined outside of the top-level file.
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
llvm::SrcMgr.getMainFileID())
Expand All @@ -1105,32 +1105,74 @@ getAllTypeConstraints(const RecordKeeper &records) {
return result;
}

static std::vector<Constraint>
getAllCppTypeConstraints(const RecordKeeper &records) {
return getAllCppConstraints(records, "TypeConstraint");
}

static std::vector<Constraint>
getAllCppAttrConstraints(const RecordKeeper &records) {
return getAllCppConstraints(records, "AttrConstraint");
}

/// Emit the declarations for the given constraints, of the form:
/// `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>);`
static void emitConstraintDecls(const std::vector<Constraint> &constraints,
raw_ostream &os, StringRef parameterTypeName,
StringRef parameterName) {
static const char *const constraintDecl = "bool {0}({1} {2});\n";
for (Constraint constr : constraints)
os << strfmt(constraintDecl, *constr.getCppFunctionName(),
parameterTypeName, parameterName);
}

static void emitTypeConstraintDecls(const RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDecl = R"(
bool {0}(::mlir::Type type);
)";
emitConstraintDecls(getAllCppTypeConstraints(records), os, "::mlir::Type",
"type");
}

for (Constraint constr : getAllTypeConstraints(records))
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
static void emitAttrConstraintDecls(const RecordKeeper &records,
raw_ostream &os) {
emitConstraintDecls(getAllCppAttrConstraints(records), os,
"::mlir::Attribute", "attr");
}

static void emitTypeConstraintDefs(const RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDef = R"(
bool {0}(::mlir::Type type) {
return ({1});
/// Emit the definitions for the given constraints, of the form:
/// `bool <constraintCppFunctionName>(<parameterTypeName> <parameterName>) {
/// return (<condition>); }`
/// where `<condition>` is the condition template with the `self` variable
/// replaced with the `selfName` parameter.
static void emitConstraintDefs(const std::vector<Constraint> &constraints,
raw_ostream &os, StringRef parameterTypeName,
StringRef selfName) {
static const char *const constraintDef = R"(
bool {0}({1} {2}) {
return ({3});
}
)";

for (Constraint constr : getAllTypeConstraints(records)) {
for (Constraint constr : constraints) {
FmtContext ctx;
ctx.withSelf("type");
ctx.withSelf(selfName);
std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
os << strfmt(constraintDef, *constr.getCppFunctionName(), parameterTypeName,
selfName, condition);
}
}

static void emitTypeConstraintDefs(const RecordKeeper &records,
raw_ostream &os) {
emitConstraintDefs(getAllCppTypeConstraints(records), os, "::mlir::Type",
"type");
}

static void emitAttrConstraintDefs(const RecordKeeper &records,
raw_ostream &os) {
emitConstraintDefs(getAllCppAttrConstraints(records), os, "::mlir::Attribute",
"attr");
}

//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1158,6 +1200,21 @@ static mlir::GenRegistration
return generator.emitDecls(attrDialect);
});

static mlir::GenRegistration
genAttrConstrDefs("gen-attr-constraint-defs",
"Generate attribute constraint definitions",
[](const RecordKeeper &records, raw_ostream &os) {
emitAttrConstraintDefs(records, os);
return false;
});
static mlir::GenRegistration
genAttrConstrDecls("gen-attr-constraint-decls",
"Generate attribute constraint declarations",
[](const RecordKeeper &records, raw_ostream &os) {
emitAttrConstraintDecls(records, os);
return false;
});

//===----------------------------------------------------------------------===//
// TypeDef
//===----------------------------------------------------------------------===//
Expand Down
Loading