Skip to content

Commit 135bd31

Browse files
authored
[mlir] [tblgen-to-irdl] Refactor tblgen-to-irdl script and support more types (#105505)
Refactors the tblgen-to-irdl script slightly and adds support for - Various integer types - Various Float types - Confined types - Complex types (with fixed element type) Also doesn't add the operand and result ops if they are empty. I could potentially split this into smaller PRs if that'd be helpful (refactor + integer/float/complex, confined type, optional operand/result). @math-fehr
1 parent 6043321 commit 135bd31

File tree

4 files changed

+227
-17
lines changed

4 files changed

+227
-17
lines changed

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,10 @@ class AllOfType<list<Type> allowedTypeList, string summary = "",
198198
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
199199
string cppType = type.cppType> : Type<
200200
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
201-
summary, cppType>;
201+
summary, cppType> {
202+
Type baseType = type;
203+
list<Pred> predicateList = predicates;
204+
}
202205

203206
// Integer types.
204207

mlir/test/tblgen-to-irdl/CMathDialect.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
2525

2626
// CHECK: irdl.operation @identity {
2727
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
28-
// CHECK-NEXT: irdl.operands()
2928
// CHECK-NEXT: irdl.results(%0)
3029
// CHECK-NEXT: }
3130
def CMath_IdentityOp : CMath_Op<"identity"> {

mlir/test/tblgen-to-irdl/TestDialect.td

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ def Test_AndOp : Test_Op<"and"> {
2828
// CHECK-LABEL: irdl.operation @and {
2929
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
3030
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
31-
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
31+
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
3232
// CHECK-NEXT: irdl.operands(%[[v2]])
33-
// CHECK-NEXT: irdl.results()
3433
// CHECK-NEXT: }
3534

3635

@@ -41,9 +40,39 @@ def Test_AnyOp : Test_Op<"any"> {
4140
// CHECK-LABEL: irdl.operation @any {
4241
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
4342
// CHECK-NEXT: irdl.operands(%[[v0]])
44-
// CHECK-NEXT: irdl.results()
4543
// CHECK-NEXT: }
4644

45+
// Check confined types are converted correctly.
46+
def Test_ConfinedOp : Test_Op<"confined"> {
47+
let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
48+
ConfinedType<AnyType, [And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">
49+
, CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>]>:$vector);
50+
}
51+
// CHECK-LABEL: irdl.operation @confined {
52+
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
53+
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.c_pred "(::llvm::isa<::mlir::TensorType>($_self))"
54+
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
55+
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any
56+
// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.c_pred "(::llvm::isa<::mlir::VectorType>($_self))"
57+
// CHECK-NEXT: %[[v5:[^ ]*]] = irdl.c_pred "(::llvm::cast<::mlir::VectorType>($_self).getRank() > 0)"
58+
// CHECK-NEXT: %[[v6:[^ ]*]] = irdl.all_of(%[[v4]], %[[v5]])
59+
// CHECK-NEXT: %[[v7:[^ ]*]] = irdl.all_of(%[[v3]], %[[v6]])
60+
// CHECK-NEXT: irdl.operands(%[[v2]], %[[v7]])
61+
// CHECK-NEXT: }
62+
63+
// Check generic integer types are converted correctly.
64+
def Test_Integers : Test_Op<"integers"> {
65+
let arguments = (ins AnyI8:$any_int,
66+
AnyInteger:$any_integer);
67+
}
68+
// CHECK-LABEL: irdl.operation @integers {
69+
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i8
70+
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.is si8
71+
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.is ui8
72+
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
73+
// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.base "!builtin.integer"
74+
// CHECK-NEXT: irdl.operands(%[[v3]], %[[v4]])
75+
// CHECK-NEXT: }
4776

4877
// Check that AnyTypeOf is converted correctly.
4978
def Test_OrOp : Test_Op<"or"> {
@@ -53,11 +82,30 @@ def Test_OrOp : Test_Op<"or"> {
5382
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
5483
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
5584
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
56-
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
85+
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
5786
// CHECK-NEXT: irdl.operands(%[[v3]])
58-
// CHECK-NEXT: irdl.results()
5987
// CHECK-NEXT: }
6088

89+
// Check that various types are converted correctly.
90+
def Test_TypesOp : Test_Op<"types"> {
91+
let arguments = (ins I32:$a,
92+
SI64:$b,
93+
UI8:$c,
94+
Index:$d,
95+
F32:$e,
96+
NoneType:$f,
97+
Complex<F8E4M3FN>);
98+
}
99+
// CHECK-LABEL: irdl.operation @types {
100+
// CHECK-NEXT: %{{.*}} = irdl.is i32
101+
// CHECK-NEXT: %{{.*}} = irdl.is si64
102+
// CHECK-NEXT: %{{.*}} = irdl.is ui8
103+
// CHECK-NEXT: %{{.*}} = irdl.is index
104+
// CHECK-NEXT: %{{.*}} = irdl.is f32
105+
// CHECK-NEXT: %{{.*}} = irdl.is none
106+
// CHECK-NEXT: %{{.*}} = irdl.is complex<f8E4M3FN>
107+
// CHECK-NEXT: irdl.operands({{.*}})
108+
// CHECK-NEXT: }
61109

62110
// Check that variadics and optionals are converted correctly.
63111
def Test_VariadicityOp : Test_Op<"variadicity"> {
@@ -70,5 +118,4 @@ def Test_VariadicityOp : Test_Op<"variadicity"> {
70118
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
71119
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
72120
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
73-
// CHECK-NEXT: irdl.results()
74121
// CHECK-NEXT: }

mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp

Lines changed: 170 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,131 @@ llvm::cl::opt<std::string>
3939
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
4040
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
4141

42+
Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
43+
MLIRContext *ctx = builder.getContext();
44+
45+
if (pred.isCombined()) {
46+
auto combiner = pred.getDef().getValueAsDef("kind")->getName();
47+
if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
48+
std::vector<Value> constraints;
49+
for (auto *child : pred.getDef().getValueAsListOfDefs("children")) {
50+
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
51+
}
52+
if (combiner == "PredCombinerAnd") {
53+
auto op =
54+
builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
55+
return op.getOutput();
56+
}
57+
auto op =
58+
builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
59+
return op.getOutput();
60+
}
61+
}
62+
63+
std::string condition = pred.getCondition();
64+
// Build a CPredOp to match the C constraint built.
65+
irdl::CPredOp op = builder.create<irdl::CPredOp>(
66+
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
67+
return op;
68+
}
69+
70+
Value typeToConstraint(OpBuilder &builder, Type type) {
71+
MLIRContext *ctx = builder.getContext();
72+
auto op =
73+
builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
74+
return op.getOutput();
75+
}
76+
77+
std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
78+
79+
if (predRec.isSubClassOf("I")) {
80+
auto width = predRec.getValueAsInt("bitwidth");
81+
return IntegerType::get(ctx, width, IntegerType::Signless);
82+
}
83+
84+
if (predRec.isSubClassOf("SI")) {
85+
auto width = predRec.getValueAsInt("bitwidth");
86+
return IntegerType::get(ctx, width, IntegerType::Signed);
87+
}
88+
89+
if (predRec.isSubClassOf("UI")) {
90+
auto width = predRec.getValueAsInt("bitwidth");
91+
return IntegerType::get(ctx, width, IntegerType::Unsigned);
92+
}
93+
94+
// Index type
95+
if (predRec.getName() == "Index") {
96+
return IndexType::get(ctx);
97+
}
98+
99+
// Float types
100+
if (predRec.isSubClassOf("F")) {
101+
auto width = predRec.getValueAsInt("bitwidth");
102+
switch (width) {
103+
case 16:
104+
return FloatType::getF16(ctx);
105+
case 32:
106+
return FloatType::getF32(ctx);
107+
case 64:
108+
return FloatType::getF64(ctx);
109+
case 80:
110+
return FloatType::getF80(ctx);
111+
case 128:
112+
return FloatType::getF128(ctx);
113+
}
114+
}
115+
116+
if (predRec.getName() == "NoneType") {
117+
return NoneType::get(ctx);
118+
}
119+
120+
if (predRec.getName() == "BF16") {
121+
return FloatType::getBF16(ctx);
122+
}
123+
124+
if (predRec.getName() == "TF32") {
125+
return FloatType::getTF32(ctx);
126+
}
127+
128+
if (predRec.getName() == "F8E4M3FN") {
129+
return FloatType::getFloat8E4M3FN(ctx);
130+
}
131+
132+
if (predRec.getName() == "F8E5M2") {
133+
return FloatType::getFloat8E5M2(ctx);
134+
}
135+
136+
if (predRec.getName() == "F8E4M3") {
137+
return FloatType::getFloat8E4M3(ctx);
138+
}
139+
140+
if (predRec.getName() == "F8E4M3FNUZ") {
141+
return FloatType::getFloat8E4M3FNUZ(ctx);
142+
}
143+
144+
if (predRec.getName() == "F8E4M3B11FNUZ") {
145+
return FloatType::getFloat8E4M3B11FNUZ(ctx);
146+
}
147+
148+
if (predRec.getName() == "F8E5M2FNUZ") {
149+
return FloatType::getFloat8E5M2FNUZ(ctx);
150+
}
151+
152+
if (predRec.getName() == "F8E3M4") {
153+
return FloatType::getFloat8E3M4(ctx);
154+
}
155+
156+
if (predRec.isSubClassOf("Complex")) {
157+
const Record *elementRec = predRec.getValueAsDef("elementType");
158+
auto elementType = recordToType(ctx, *elementRec);
159+
if (elementType.has_value()) {
160+
return ComplexType::get(elementType.value());
161+
}
162+
}
163+
164+
return std::nullopt;
165+
}
166+
42167
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
43168
MLIRContext *ctx = builder.getContext();
44169
const Record &predRec = constraint.getDef();
@@ -78,11 +203,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
78203
return op.getOutput();
79204
}
80205

81-
std::string condition = constraint.getPredicate().getCondition();
82-
// Build a CPredOp to match the C constraint built.
83-
irdl::CPredOp op = builder.create<irdl::CPredOp>(
84-
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
85-
return op;
206+
// Integer types
207+
if (predRec.getName() == "AnyInteger") {
208+
auto op = builder.create<irdl::BaseOp>(
209+
UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
210+
return op.getOutput();
211+
}
212+
213+
if (predRec.isSubClassOf("AnyI")) {
214+
auto width = predRec.getValueAsInt("bitwidth");
215+
std::vector<Value> types = {
216+
typeToConstraint(builder,
217+
IntegerType::get(ctx, width, IntegerType::Signless)),
218+
typeToConstraint(builder,
219+
IntegerType::get(ctx, width, IntegerType::Signed)),
220+
typeToConstraint(builder,
221+
IntegerType::get(ctx, width, IntegerType::Unsigned))};
222+
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
223+
return op.getOutput();
224+
}
225+
226+
auto type = recordToType(ctx, predRec);
227+
228+
if (type.has_value()) {
229+
return typeToConstraint(builder, type.value());
230+
}
231+
232+
// Confined type
233+
if (predRec.isSubClassOf("ConfinedType")) {
234+
std::vector<Value> constraints;
235+
constraints.push_back(createConstraint(
236+
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
237+
for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
238+
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
239+
}
240+
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
241+
return op.getOutput();
242+
}
243+
244+
return createPredicate(builder, constraint.getPredicate());
86245
}
87246

88247
/// Returns the name of the operation without the dialect prefix.
@@ -131,10 +290,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
131290
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
132291

133292
// Create the operands and results operations.
134-
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
135-
operandVariadicity);
136-
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
137-
resultVariadicity);
293+
if (!operands.empty())
294+
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
295+
operandVariadicity);
296+
if (!results.empty())
297+
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
298+
resultVariadicity);
138299

139300
return op;
140301
}

0 commit comments

Comments
 (0)