Skip to content

[mlir][ODS] Optionally generate public C++ functions for type constraints #104577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions mlir/docs/DefiningDialects/Constraints.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Constraints

[TOC]

## Attribute / Type Constraints

When defining the arguments of an operation in TableGen, users can specify
either plain attributes/types or use attribute/type constraints to levy
additional requirements on the attribute value or operand type.

```tablegen
def My_Type1 : MyDialect_Type<"Type1", "type1"> { ... }
def My_Type2 : MyDialect_Type<"Type2", "type2"> { ... }

// Plain type
let arguments = (ins MyType1:$val);
// Type constraint
let arguments = (ins AnyTypeOf<[MyType1, MyType2]>:$val);
```

`AnyTypeOf` is an example for a type constraints. Many useful type constraints
can be found in `mlir/IR/CommonTypeConstraints.td`. Additional verification
code is generated for type/attribute constraints. Type constraints can not only
be used when defining operation arguments, but also when defining type
parameters.

Optionally, C++ functions can be generated, so that type constraints can be
checked from C++. The name of the C++ function must be specified in the
`cppFunctionName` field. If no function name is specified, no C++ function is
emitted.

```tablegen
// Example: Element type constraint for VectorType
def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
let cppFunctionName = "isValidVectorTypeElementType";
}
```

The above example tranlates into the following C++ code:
```c++
bool isValidVectorTypeElementType(::mlir::Type type) {
return (((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type))));
}
```

An extra TableGen rule is needed to emit C++ code for type constraints. This
will generate only the declarations/definitions of the type constaraints that
are defined in the specified `.td` file, but not those that are in included
`.td` files.

```cmake
mlir_tablegen(<Your Dialect>TypeConstraints.h.inc -gen-type-constraint-decls)
mlir_tablegen(<Your Dialect>TypeConstraints.cpp.inc -gen-type-constraint-defs)
```

The generated `<Your Dialect>TypeConstraints.h.inc` will need to be included
whereever you are referencing the type constraint in C++. Note that no C++
namespace will be emitted by the code generator. The `#include` statements of
the `.h.inc`/`.cpp.inc` files should be wrapped in C++ namespaces by the user.
1 change: 1 addition & 0 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
#include "mlir/IR/BuiltinTypes.h.inc"

namespace mlir {
#include "mlir/IR/BuiltinTypeConstraints.h.inc"

//===----------------------------------------------------------------------===//
// MemRefType
Expand Down
14 changes: 7 additions & 7 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,10 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
// VectorType
//===----------------------------------------------------------------------===//

def Builtin_VectorTypeElementType : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if this should go into a separate .td file. Otherwise there could be multiple definitions of the same C++ function if BuiltinTypes.td is included in another C++ file for which type constraints are also generated.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was running into something like this too, I think at least for all of these a dialect is specified. So one could either filter on dialect in generation, or we could require that every dialect has one "main" file where it sets the current dialect and then the generator would always just consider that one for given main file unless overridden (I feel like there is too much filter command line flags while the common case matches this)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should do the same thing that we do for interfaces: Ignore everything not in the direct file that is being generated. I am a bit concerned about how we handle constraint generation, because there are so many of them spread across nearly every .td file. I think we need to be careful about how/which ones get generated (probably even more so than interfaces).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore everything not in the direct file that is being generated.

This sounds like a good strategy to me. I changed the implementation accordingly.

I think we need to be careful about how/which ones get generated

A C++ function is generated only if the cppFunctionName field is set to a non-empty string. By default it is empty, so this entire process is on an opt-in basis.

Could we have the one called from the other? Then the difference is just whether its exposed in header and perhaps whether instantiated in anonymous/impl namespace.

If a constraint calls a constraint via TableGen, it is fully inlined. E.g.: AnyTypeOf<[I32, AnyTypeOf<[I16, I64]>]> will generate a single C++ function that checks for i32, i16, i64. So this will work out of the box.

If a constraint calls a constraint via CPred<"...">, arbitrary C++ code can be specified and I'd say it's the user's responsibility to ensure that all needed C++ functions are pre-declared. By making sure that the corresponding header file (.h not .inc.h) is included in the .cpp file that contains the constraint definition. (Same as with ops that implement op interfaces: the user must include the .h file that declares the op interface in the .h file that contains the operation class.)

let cppFunctionName = "isValidVectorTypeElementType";
}

def Builtin_Vector : Builtin_Type<"Vector", "vector",
[ShapedTypeInterface, ValueSemantics], "Type"> {
let summary = "Multi-dimensional SIMD vector type";
Expand Down Expand Up @@ -1147,7 +1151,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
AnyTypeOf<[AnyInteger, Index, AnyFloat]>:$elementType,
Builtin_VectorTypeElementType:$elementType,
ArrayRefParameter<"bool">:$scalableDims
);
let builders = [
Expand All @@ -1171,12 +1175,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector",
class Builder;

/// Returns true if the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
static bool isValidElementType(Type t) {
// TODO: Auto-generate this function from $elementType.
return ::llvm::isa<IntegerType, IndexType, FloatType>(t);
}
/// type. See "Builtin_VectorTypeElementType" for allowed types.
static bool isValidElementType(Type t);

/// Returns true if the vector contains scalable dimensions.
bool isScalable() const {
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td)
mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls)
mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs)
add_public_tablegen_target(MLIRBuiltinTypesIncGen)
mlir_tablegen(BuiltinTypeConstraints.h.inc -gen-type-constraint-decls)
mlir_tablegen(BuiltinTypeConstraints.cpp.inc -gen-type-constraint-defs)
add_public_tablegen_target(MLIRBuiltinTypeConstraintsIncGen)

set(LLVM_TARGET_DEFINITIONS BuiltinTypeInterfaces.td)
mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
Expand Down
6 changes: 5 additions & 1 deletion mlir/include/mlir/IR/Constraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,14 @@ class Constraint<Pred pred, string desc = ""> {

// Subclass for constraints on a type.
class TypeConstraint<Pred predicate, string summary = "",
string cppTypeParam = "::mlir::Type"> :
string cppTypeParam = "::mlir::Type",
string cppFunctionNameParam = ""> :
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or Type if not.
string cppType = cppTypeParam;
// The name of the C++ function that is generated for this type constraint.
// If empty, no C++ function is generated.
string cppFunctionName = cppFunctionNameParam;
}

// Subclass for constraints on an attribute.
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/TableGen/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ class Constraint {
/// context on the def).
std::string getUniqueDefName() const;

/// Returns the name of the C++ function that should be generated for this
/// constraint, or std::nullopt if no C++ function should be generated.
std::optional<StringRef> getCppFunctionName() const;

Kind getKind() const { return kind; }

/// Return the underlying def.
Expand Down
16 changes: 14 additions & 2 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ using namespace mlir::detail;
#define GET_TYPEDEF_CLASSES
#include "mlir/IR/BuiltinTypes.cpp.inc"

namespace mlir {
#include "mlir/IR/BuiltinTypeConstraints.cpp.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// BuiltinDialect
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -230,6 +234,10 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
// VectorType
//===----------------------------------------------------------------------===//

bool VectorType::isValidElementType(Type t) {
return isValidVectorTypeElementType(t);
}

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
ArrayRef<bool> scalableDims) {
Expand Down Expand Up @@ -278,7 +286,9 @@ Type TensorType::getElementType() const {
[](auto type) { return type.getElementType(); });
}

bool TensorType::hasRank() const { return !llvm::isa<UnrankedTensorType>(*this); }
bool TensorType::hasRank() const {
return !llvm::isa<UnrankedTensorType>(*this);
}

ArrayRef<int64_t> TensorType::getShape() const {
return llvm::cast<RankedTensorType>(*this).getShape();
Expand Down Expand Up @@ -365,7 +375,9 @@ Type BaseMemRefType::getElementType() const {
[](auto type) { return type.getElementType(); });
}

bool BaseMemRefType::hasRank() const { return !llvm::isa<UnrankedMemRefType>(*this); }
bool BaseMemRefType::hasRank() const {
return !llvm::isa<UnrankedMemRefType>(*this);
}

ArrayRef<int64_t> BaseMemRefType::getShape() const {
return llvm::cast<MemRefType>(*this).getShape();
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ add_mlir_library(MLIRIR
MLIRBuiltinLocationAttributesIncGen
MLIRBuiltinOpsIncGen
MLIRBuiltinTypesIncGen
MLIRBuiltinTypeConstraintsIncGen
MLIRBuiltinTypeInterfacesIncGen
MLIRCallInterfacesIncGen
MLIRCastInterfacesIncGen
Expand Down
10 changes: 9 additions & 1 deletion mlir/lib/TableGen/Constraint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ Constraint::Constraint(const llvm::Record *record)
kind = CK_Region;
} else if (def->isSubClassOf("SuccessorConstraint")) {
kind = CK_Successor;
} else if(!def->isSubClassOf("Constraint")) {
} else if (!def->isSubClassOf("Constraint")) {
llvm::errs() << "Expected a constraint but got: \n" << *def << "\n";
llvm::report_fatal_error("Abort");
}
Expand Down Expand Up @@ -109,6 +109,14 @@ std::optional<StringRef> Constraint::getBaseDefName() const {
}
}

std::optional<StringRef> Constraint::getCppFunctionName() const {
std::optional<StringRef> name =
def->getValueAsOptionalString("cppFunctionName");
if (!name || *name == "")
return std::nullopt;
return name;
}

AppliedConstraint::AppliedConstraint(Constraint &&constraint,
llvm::StringRef self,
std::vector<std::string> &&entities)
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/mlir-tblgen/type-constraints.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-tblgen -gen-type-constraint-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
// RUN: mlir-tblgen -gen-type-constraint-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF

include "mlir/IR/CommonTypeConstraints.td"

def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> {
let cppFunctionName = "isValidDummy";
}

// DECL: bool isValidDummy(::mlir::Type type);

// DEF: bool isValidDummy(::mlir::Type type) {
// DEF: return (((::llvm::isa<::mlir::IntegerType>(type))) || ((::llvm::isa<::mlir::IndexType>(type))) || ((::llvm::isa<::mlir::FloatType>(type))));
// DEF: }
64 changes: 64 additions & 0 deletions mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,55 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) {
return false;
}

//===----------------------------------------------------------------------===//
// Type Constraints
//===----------------------------------------------------------------------===//

/// Find all type constraints for which a C++ function should be generated.
static std::vector<Constraint>
getAllTypeConstraints(const llvm::RecordKeeper &records) {
std::vector<Constraint> result;
for (llvm::Record *def :
records.getAllDerivedDefinitionsIfDefined("TypeConstraint")) {
// Ignore constraints defined outside of the top-level file.
if (llvm::SrcMgr.FindBufferContainingLoc(def->getLoc()[0]) !=
llvm::SrcMgr.getMainFileID())
continue;
Constraint constr(def);
// Generate C++ function only if "cppFunctionName" is set.
if (!constr.getCppFunctionName())
continue;
result.push_back(constr);
}
return result;
}

static void emitTypeConstraintDecls(const llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDecl = R"(
bool {0}(::mlir::Type type);
)";

for (Constraint constr : getAllTypeConstraints(records))
os << strfmt(typeConstraintDecl, *constr.getCppFunctionName());
}

static void emitTypeConstraintDefs(const llvm::RecordKeeper &records,
raw_ostream &os) {
static const char *const typeConstraintDef = R"(
bool {0}(::mlir::Type type) {
return ({1});
}
)";

for (Constraint constr : getAllTypeConstraints(records)) {
FmtContext ctx;
ctx.withSelf("type");
std::string condition = tgfmt(constr.getConditionTemplate(), &ctx);
os << strfmt(typeConstraintDef, *constr.getCppFunctionName(), condition);
}
}

//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1070,3 +1119,18 @@ static mlir::GenRegistration
TypeDefGenerator generator(records, os);
return generator.emitDecls(typeDialect);
});

static mlir::GenRegistration
genTypeConstrDefs("gen-type-constraint-defs",
"Generate type constraint definitions",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDefs(records, os);
return false;
});
static mlir::GenRegistration
genTypeConstrDecls("gen-type-constraint-decls",
"Generate type constraint declarations",
[](const llvm::RecordKeeper &records, raw_ostream &os) {
emitTypeConstraintDecls(records, os);
return false;
});
Loading