Skip to content

[mlir][irdl] Add support for basic structural constraints in tblgen-to-irdl #82862

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -168,24 +168,28 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
BuildableType<"$_builder.getType<::mlir::NoneType>()">;

// Any type from the given list
class AnyTypeOf<list<Type> allowedTypes, string summary = "",
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
// Satisfy any of the allowed types' conditions.
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypes, t.summary), " or "),
!interleave(!foreach(t, allowedTypeList, t.summary), " or "),
summary),
cppClassName>;
cppClassName> {
list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies the constraints of all given types.
class AllOfType<list<Type> allowedTypes, string summary = "",
class AllOfType<list<Type> allowedTypeList, string summary = "",
string cppClassName = "::mlir::Type"> : Type<
// Satisfy all of the allowedf types' conditions.
And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
// Satisfy all of the allowed types' conditions.
And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
!if(!eq(summary, ""),
!interleave(!foreach(t, allowedTypes, t.summary), " and "),
!interleave(!foreach(t, allowedTypeList, t.summary), " and "),
summary),
cppClassName>;
cppClassName> {
list<Type> allowedTypes = allowedTypeList;
}

// A type that satisfies additional predicates.
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
Expand Down
12 changes: 6 additions & 6 deletions mlir/test/tblgen-to-irdl/CMathDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
}

// CHECK: irdl.operation @identity {
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands()
// CHECK-NEXT: irdl.results(%0)
// CHECK-NEXT: }
Expand All @@ -33,9 +33,9 @@ def CMath_IdentityOp : CMath_Op<"identity"> {
}

// CHECK: irdl.operation @mul {
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %2 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
// CHECK-NEXT: %2 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands(%0, %1)
// CHECK-NEXT: irdl.results(%2)
// CHECK-NEXT: }
Expand All @@ -45,8 +45,8 @@ def CMath_MulOp : CMath_Op<"mul"> {
}

// CHECK: irdl.operation @norm {
// CHECK-NEXT: %0 = irdl.c_pred "(true)"
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
// CHECK-NEXT: %0 = irdl.any
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands(%0)
// CHECK-NEXT: irdl.results(%1)
// CHECK-NEXT: }
Expand Down
74 changes: 74 additions & 0 deletions mlir/test/tblgen-to-irdl/TestDialect.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=test | FileCheck %s

include "mlir/IR/OpBase.td"
include "mlir/IR/AttrTypeBase.td"

// CHECK-LABEL: irdl.dialect @test {
def Test_Dialect : Dialect {
let name = "test";
}

class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<Test_Dialect, name, traits> {
let mnemonic = typeMnemonic;
}

class Test_Op<string mnemonic, list<Trait> traits = []>
: Op<Test_Dialect, mnemonic, traits>;

def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}


// Check that AllOfType is converted correctly.
def Test_AndOp : Test_Op<"and"> {
let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in);
}
// CHECK-LABEL: irdl.operation @and {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
// CHECK-NEXT: irdl.operands(%[[v2]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }


// Check that AnyType is converted correctly.
def Test_AnyOp : Test_Op<"any"> {
let arguments = (ins AnyType:$in);
}
// CHECK-LABEL: irdl.operation @any {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
// CHECK-NEXT: irdl.operands(%[[v0]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }


// Check that AnyTypeOf is converted correctly.
def Test_OrOp : Test_Op<"or"> {
let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in);
}
// CHECK-LABEL: irdl.operation @or {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.operands(%[[v3]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }


// Check that variadics and optionals are converted correctly.
def Test_VariadicityOp : Test_Op<"variadicity"> {
let arguments = (ins Variadic<Test_SingletonAType>:$variadic,
Optional<Test_SingletonBType>:$optional,
Test_SingletonCType:$required);
}
// CHECK-LABEL: irdl.operation @variadicity {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
48 changes: 41 additions & 7 deletions mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,49 @@ llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);

irdl::CPredOp createConstraint(OpBuilder &builder,
NamedTypeConstraint namedConstraint) {
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
// Build the constraint as a string.
std::string constraint =
namedConstraint.constraint.getPredicate().getCondition();
const Record &predRec = constraint.getDef();

if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
return createConstraint(builder, predRec.getValueAsDef("baseType"));

if (predRec.getName() == "AnyType") {
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
return op.getOutput();
}

if (predRec.isSubClassOf("TypeDef")) {
std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
StringAttr::get(ctx, typeName));
return op.getOutput();
}

if (predRec.isSubClassOf("AnyTypeOf")) {
std::vector<Value> constraints;
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}

if (predRec.isSubClassOf("AllOfType")) {
std::vector<Value> constraints;
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
constraints.push_back(
createConstraint(builder, tblgen::Constraint(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}

std::string condition = constraint.getPredicate().getCondition();
// Build a CPredOp to match the C constraint built.
irdl::CPredOp op = builder.create<irdl::CPredOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, constraint));
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
return op;
}

Expand All @@ -74,7 +108,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
SmallVector<Value> operands;
SmallVector<irdl::VariadicityAttr> variadicity;
for (const NamedTypeConstraint &namedCons : namedCons) {
auto operand = createConstraint(consBuilder, namedCons);
auto operand = createConstraint(consBuilder, namedCons.constraint);
operands.push_back(operand);

irdl::VariadicityAttr var;
Expand Down