Skip to content

Commit a64975f

Browse files
authored
[mlir][irdl] Add support for basic structural constraints in tblgen-to-irdl (#82862)
1 parent 1c2b79a commit a64975f

File tree

4 files changed

+134
-22
lines changed

4 files changed

+134
-22
lines changed

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,24 +168,28 @@ def NoneType : Type<CPred<"::llvm::isa<::mlir::NoneType>($_self)">, "none type",
168168
BuildableType<"$_builder.getType<::mlir::NoneType>()">;
169169

170170
// Any type from the given list
171-
class AnyTypeOf<list<Type> allowedTypes, string summary = "",
171+
class AnyTypeOf<list<Type> allowedTypeList, string summary = "",
172172
string cppClassName = "::mlir::Type"> : Type<
173173
// Satisfy any of the allowed types' conditions.
174-
Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
174+
Or<!foreach(allowedtype, allowedTypeList, allowedtype.predicate)>,
175175
!if(!eq(summary, ""),
176-
!interleave(!foreach(t, allowedTypes, t.summary), " or "),
176+
!interleave(!foreach(t, allowedTypeList, t.summary), " or "),
177177
summary),
178-
cppClassName>;
178+
cppClassName> {
179+
list<Type> allowedTypes = allowedTypeList;
180+
}
179181

180182
// A type that satisfies the constraints of all given types.
181-
class AllOfType<list<Type> allowedTypes, string summary = "",
183+
class AllOfType<list<Type> allowedTypeList, string summary = "",
182184
string cppClassName = "::mlir::Type"> : Type<
183-
// Satisfy all of the allowedf types' conditions.
184-
And<!foreach(allowedType, allowedTypes, allowedType.predicate)>,
185+
// Satisfy all of the allowed types' conditions.
186+
And<!foreach(allowedType, allowedTypeList, allowedType.predicate)>,
185187
!if(!eq(summary, ""),
186-
!interleave(!foreach(t, allowedTypes, t.summary), " and "),
188+
!interleave(!foreach(t, allowedTypeList, t.summary), " and "),
187189
summary),
188-
cppClassName>;
190+
cppClassName> {
191+
list<Type> allowedTypes = allowedTypeList;
192+
}
189193

190194
// A type that satisfies additional predicates.
191195
class ConfinedType<Type type, list<Pred> predicates, string summary = "",

mlir/test/tblgen-to-irdl/CMathDialect.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
2424
}
2525

2626
// CHECK: irdl.operation @identity {
27-
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
27+
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
2828
// CHECK-NEXT: irdl.operands()
2929
// CHECK-NEXT: irdl.results(%0)
3030
// CHECK-NEXT: }
@@ -33,9 +33,9 @@ def CMath_IdentityOp : CMath_Op<"identity"> {
3333
}
3434

3535
// CHECK: irdl.operation @mul {
36-
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
37-
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
38-
// CHECK-NEXT: %2 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
36+
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
37+
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
38+
// CHECK-NEXT: %2 = irdl.base "!cmath.complex"
3939
// CHECK-NEXT: irdl.operands(%0, %1)
4040
// CHECK-NEXT: irdl.results(%2)
4141
// CHECK-NEXT: }
@@ -45,8 +45,8 @@ def CMath_MulOp : CMath_Op<"mul"> {
4545
}
4646

4747
// CHECK: irdl.operation @norm {
48-
// CHECK-NEXT: %0 = irdl.c_pred "(true)"
49-
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
48+
// CHECK-NEXT: %0 = irdl.any
49+
// CHECK-NEXT: %1 = irdl.base "!cmath.complex"
5050
// CHECK-NEXT: irdl.operands(%0)
5151
// CHECK-NEXT: irdl.results(%1)
5252
// CHECK-NEXT: }
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=test | FileCheck %s
2+
3+
include "mlir/IR/OpBase.td"
4+
include "mlir/IR/AttrTypeBase.td"
5+
6+
// CHECK-LABEL: irdl.dialect @test {
7+
def Test_Dialect : Dialect {
8+
let name = "test";
9+
}
10+
11+
class Test_Type<string name, string typeMnemonic, list<Trait> traits = []>
12+
: TypeDef<Test_Dialect, name, traits> {
13+
let mnemonic = typeMnemonic;
14+
}
15+
16+
class Test_Op<string mnemonic, list<Trait> traits = []>
17+
: Op<Test_Dialect, mnemonic, traits>;
18+
19+
def Test_SingletonAType : Test_Type<"SingletonAType", "singleton_a"> {}
20+
def Test_SingletonBType : Test_Type<"SingletonBType", "singleton_b"> {}
21+
def Test_SingletonCType : Test_Type<"SingletonCType", "singleton_c"> {}
22+
23+
24+
// Check that AllOfType is converted correctly.
25+
def Test_AndOp : Test_Op<"and"> {
26+
let arguments = (ins AllOfType<[Test_SingletonAType, AnyType]>:$in);
27+
}
28+
// CHECK-LABEL: irdl.operation @and {
29+
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
30+
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
31+
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
32+
// CHECK-NEXT: irdl.operands(%[[v2]])
33+
// CHECK-NEXT: irdl.results()
34+
// CHECK-NEXT: }
35+
36+
37+
// Check that AnyType is converted correctly.
38+
def Test_AnyOp : Test_Op<"any"> {
39+
let arguments = (ins AnyType:$in);
40+
}
41+
// CHECK-LABEL: irdl.operation @any {
42+
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
43+
// CHECK-NEXT: irdl.operands(%[[v0]])
44+
// CHECK-NEXT: irdl.results()
45+
// CHECK-NEXT: }
46+
47+
48+
// Check that AnyTypeOf is converted correctly.
49+
def Test_OrOp : Test_Op<"or"> {
50+
let arguments = (ins AnyTypeOf<[Test_SingletonAType, Test_SingletonBType, Test_SingletonCType]>:$in);
51+
}
52+
// CHECK-LABEL: irdl.operation @or {
53+
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
54+
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
55+
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
56+
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
57+
// CHECK-NEXT: irdl.operands(%[[v3]])
58+
// CHECK-NEXT: irdl.results()
59+
// CHECK-NEXT: }
60+
61+
62+
// Check that variadics and optionals are converted correctly.
63+
def Test_VariadicityOp : Test_Op<"variadicity"> {
64+
let arguments = (ins Variadic<Test_SingletonAType>:$variadic,
65+
Optional<Test_SingletonBType>:$optional,
66+
Test_SingletonCType:$required);
67+
}
68+
// CHECK-LABEL: irdl.operation @variadicity {
69+
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
70+
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
71+
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
72+
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
73+
// CHECK-NEXT: irdl.results()
74+
// CHECK-NEXT: }

mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,49 @@ llvm::cl::opt<std::string>
3939
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
4040
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
4141

42-
irdl::CPredOp createConstraint(OpBuilder &builder,
43-
NamedTypeConstraint namedConstraint) {
42+
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
4443
MLIRContext *ctx = builder.getContext();
45-
// Build the constraint as a string.
46-
std::string constraint =
47-
namedConstraint.constraint.getPredicate().getCondition();
44+
const Record &predRec = constraint.getDef();
45+
46+
if (predRec.isSubClassOf("Variadic") || predRec.isSubClassOf("Optional"))
47+
return createConstraint(builder, predRec.getValueAsDef("baseType"));
48+
49+
if (predRec.getName() == "AnyType") {
50+
auto op = builder.create<irdl::AnyOp>(UnknownLoc::get(ctx));
51+
return op.getOutput();
52+
}
53+
54+
if (predRec.isSubClassOf("TypeDef")) {
55+
std::string typeName = ("!" + predRec.getValueAsString("typeName")).str();
56+
auto op = builder.create<irdl::BaseOp>(UnknownLoc::get(ctx),
57+
StringAttr::get(ctx, typeName));
58+
return op.getOutput();
59+
}
60+
61+
if (predRec.isSubClassOf("AnyTypeOf")) {
62+
std::vector<Value> constraints;
63+
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
64+
constraints.push_back(
65+
createConstraint(builder, tblgen::Constraint(child)));
66+
}
67+
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
68+
return op.getOutput();
69+
}
70+
71+
if (predRec.isSubClassOf("AllOfType")) {
72+
std::vector<Value> constraints;
73+
for (Record *child : predRec.getValueAsListOfDefs("allowedTypes")) {
74+
constraints.push_back(
75+
createConstraint(builder, tblgen::Constraint(child)));
76+
}
77+
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
78+
return op.getOutput();
79+
}
80+
81+
std::string condition = constraint.getPredicate().getCondition();
4882
// Build a CPredOp to match the C constraint built.
4983
irdl::CPredOp op = builder.create<irdl::CPredOp>(
50-
UnknownLoc::get(ctx), StringAttr::get(ctx, constraint));
84+
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
5185
return op;
5286
}
5387

@@ -74,7 +108,7 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
74108
SmallVector<Value> operands;
75109
SmallVector<irdl::VariadicityAttr> variadicity;
76110
for (const NamedTypeConstraint &namedCons : namedCons) {
77-
auto operand = createConstraint(consBuilder, namedCons);
111+
auto operand = createConstraint(consBuilder, namedCons.constraint);
78112
operands.push_back(operand);
79113

80114
irdl::VariadicityAttr var;

0 commit comments

Comments
 (0)