Skip to content

Commit 930b9a5

Browse files
committed
Refactor tblgen-to-irdl script and support more types
1 parent 768598b commit 930b9a5

File tree

4 files changed

+224
-17
lines changed

4 files changed

+224
-17
lines changed

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,10 @@ class AllOfType<list<Type> allowedTypeList, string summary = "",
198198
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
199199
string cppType = type.cppType> : Type<
200200
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
201-
summary, cppType>;
201+
summary, cppType> {
202+
Type baseType = type;
203+
list<Pred> predicateList = predicates;
204+
}
202205

203206
// Integer types.
204207

mlir/test/tblgen-to-irdl/CMathDialect.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
2525

2626
// CHECK: irdl.operation @identity {
2727
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
28-
// CHECK-NEXT: irdl.operands()
2928
// CHECK-NEXT: irdl.results(%0)
3029
// CHECK-NEXT: }
3130
def CMath_IdentityOp : CMath_Op<"identity"> {

mlir/test/tblgen-to-irdl/TestDialect.td

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ def Test_AndOp : Test_Op<"and"> {
2828
// CHECK-LABEL: irdl.operation @and {
2929
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
3030
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
31-
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
31+
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
3232
// CHECK-NEXT: irdl.operands(%[[v2]])
33-
// CHECK-NEXT: irdl.results()
3433
// CHECK-NEXT: }
3534

3635

@@ -41,9 +40,37 @@ def Test_AnyOp : Test_Op<"any"> {
4140
// CHECK-LABEL: irdl.operation @any {
4241
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
4342
// CHECK-NEXT: irdl.operands(%[[v0]])
44-
// CHECK-NEXT: irdl.results()
4543
// CHECK-NEXT: }
4644

45+
// Check confined types are converted correctly.
46+
def Test_ConfinedOp : Test_Op<"confined"> {
47+
let arguments = (ins ConfinedType<I32, [IntNonNegative.predicate]>:$confined,
48+
ConfinedType<I8, [And<[IntMinValue<1>.predicate, IntMaxValue<2>.predicate]>]>:$bounded);
49+
}
50+
// CHECK-LABEL: irdl.operation @confined {
51+
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i32
52+
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.c_pred "{{.*}}"
53+
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
54+
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.is i8
55+
// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.c_pred "{{.*}}"
56+
// CHECK-NEXT: %[[v5:[^ ]*]] = irdl.c_pred "{{.*}}"
57+
// CHECK-NEXT: %[[v6:[^ ]*]] = irdl.all_of(%[[v4]], %[[v5]])
58+
// CHECK-NEXT: %[[v7:[^ ]*]] = irdl.all_of(%[[v3]], %[[v6]])
59+
// CHECK-NEXT: irdl.operands(%[[v2]], %[[v7]])
60+
// CHECK-NEXT: }
61+
62+
def Test_Integers : Test_Op<"integers"> {
63+
let arguments = (ins AnyI8:$any_int,
64+
AnyInteger:$any_integer);
65+
}
66+
// CHECK-LABEL: irdl.operation @integers {
67+
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i8
68+
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.is si8
69+
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.is ui8
70+
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
71+
// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.base "!builtin.integer"
72+
// CHECK-NEXT: irdl.operands(%[[v3]], %[[v4]])
73+
// CHECK-NEXT: }
4774

4875
// Check that AnyTypeOf is converted correctly.
4976
def Test_OrOp : Test_Op<"or"> {
@@ -53,11 +80,30 @@ def Test_OrOp : Test_Op<"or"> {
5380
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
5481
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
5582
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
56-
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
83+
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
5784
// CHECK-NEXT: irdl.operands(%[[v3]])
58-
// CHECK-NEXT: irdl.results()
5985
// CHECK-NEXT: }
6086

87+
// Check that various types are converted correctly.
88+
def Test_TypesOp : Test_Op<"types"> {
89+
let arguments = (ins I32:$a,
90+
SI64:$b,
91+
UI8:$c,
92+
Index:$d,
93+
F32:$e,
94+
NoneType:$f,
95+
Complex<F8E4M3FN>);
96+
}
97+
// CHECK-LABEL: irdl.operation @types {
98+
// CHECK-NEXT: %{{.*}} = irdl.is i32
99+
// CHECK-NEXT: %{{.*}} = irdl.is si64
100+
// CHECK-NEXT: %{{.*}} = irdl.is ui8
101+
// CHECK-NEXT: %{{.*}} = irdl.is index
102+
// CHECK-NEXT: %{{.*}} = irdl.is f32
103+
// CHECK-NEXT: %{{.*}} = irdl.is none
104+
// CHECK-NEXT: %{{.*}} = irdl.is complex<f8E4M3FN>
105+
// CHECK-NEXT: irdl.operands({{.*}})
106+
// CHECK-NEXT: }
61107

62108
// Check that variadics and optionals are converted correctly.
63109
def Test_VariadicityOp : Test_Op<"variadicity"> {
@@ -70,5 +116,4 @@ def Test_VariadicityOp : Test_Op<"variadicity"> {
70116
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
71117
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
72118
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
73-
// CHECK-NEXT: irdl.results()
74119
// CHECK-NEXT: }

mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp

Lines changed: 169 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,130 @@ llvm::cl::opt<std::string>
3939
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
4040
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
4141

42+
Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
43+
MLIRContext *ctx = builder.getContext();
44+
45+
if (pred.isCombined()) {
46+
auto combiner = pred.getDef().getValueAsDef("kind")->getName();
47+
if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
48+
std::vector<Value> constraints;
49+
for (auto *child : pred.getDef().getValueAsListOfDefs("children")) {
50+
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
51+
}
52+
if (combiner == "PredCombinerAnd") {
53+
auto op =
54+
builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
55+
return op.getOutput();
56+
}
57+
auto op =
58+
builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
59+
return op.getOutput();
60+
}
61+
}
62+
63+
std::string condition = pred.getCondition();
64+
// Build a CPredOp to match the C constraint built.
65+
irdl::CPredOp op = builder.create<irdl::CPredOp>(
66+
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
67+
return op;
68+
}
69+
70+
Value typeToConstraint(OpBuilder &builder, MLIRContext *ctx, Type type) {
71+
auto op =
72+
builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
73+
return op.getOutput();
74+
}
75+
76+
std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {
77+
78+
if (predRec.isSubClassOf("I")) {
79+
auto width = predRec.getValueAsInt("bitwidth");
80+
return IntegerType::get(ctx, width, IntegerType::Signless);
81+
}
82+
83+
if (predRec.isSubClassOf("SI")) {
84+
auto width = predRec.getValueAsInt("bitwidth");
85+
return IntegerType::get(ctx, width, IntegerType::Signed);
86+
}
87+
88+
if (predRec.isSubClassOf("UI")) {
89+
auto width = predRec.getValueAsInt("bitwidth");
90+
return IntegerType::get(ctx, width, IntegerType::Unsigned);
91+
}
92+
93+
// Index type
94+
if (predRec.getName() == "Index") {
95+
return IndexType::get(ctx);
96+
}
97+
98+
// Float types
99+
if (predRec.isSubClassOf("F")) {
100+
auto width = predRec.getValueAsInt("bitwidth");
101+
switch (width) {
102+
case 16:
103+
return FloatType::getF16(ctx);
104+
case 32:
105+
return FloatType::getF32(ctx);
106+
case 64:
107+
return FloatType::getF64(ctx);
108+
case 80:
109+
return FloatType::getF80(ctx);
110+
case 128:
111+
return FloatType::getF128(ctx);
112+
}
113+
}
114+
115+
if (predRec.getName() == "NoneType") {
116+
return NoneType::get(ctx);
117+
}
118+
119+
if (predRec.getName() == "BF16") {
120+
return FloatType::getBF16(ctx);
121+
}
122+
123+
if (predRec.getName() == "TF32") {
124+
return FloatType::getTF32(ctx);
125+
}
126+
127+
if (predRec.getName() == "F8E4M3FN") {
128+
return FloatType::getFloat8E4M3FN(ctx);
129+
}
130+
131+
if (predRec.getName() == "F8E5M2") {
132+
return FloatType::getFloat8E5M2(ctx);
133+
}
134+
135+
if (predRec.getName() == "F8E4M3") {
136+
return FloatType::getFloat8E4M3(ctx);
137+
}
138+
139+
if (predRec.getName() == "F8E4M3FNUZ") {
140+
return FloatType::getFloat8E4M3FNUZ(ctx);
141+
}
142+
143+
if (predRec.getName() == "F8E4M3B11FNUZ") {
144+
return FloatType::getFloat8E4M3B11FNUZ(ctx);
145+
}
146+
147+
if (predRec.getName() == "F8E5M2FNUZ") {
148+
return FloatType::getFloat8E5M2FNUZ(ctx);
149+
}
150+
151+
if (predRec.getName() == "F8E3M4") {
152+
return FloatType::getFloat8E3M4(ctx);
153+
}
154+
155+
if (predRec.isSubClassOf("Complex")) {
156+
const Record *elementRec = predRec.getValueAsDef("elementType");
157+
auto elementType = recordToType(ctx, *elementRec);
158+
if (elementType.has_value()) {
159+
return ComplexType::get(elementType.value());
160+
}
161+
}
162+
163+
return std::nullopt;
164+
}
165+
42166
Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
43167
MLIRContext *ctx = builder.getContext();
44168
const Record &predRec = constraint.getDef();
@@ -78,11 +202,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
78202
return op.getOutput();
79203
}
80204

81-
std::string condition = constraint.getPredicate().getCondition();
82-
// Build a CPredOp to match the C constraint built.
83-
irdl::CPredOp op = builder.create<irdl::CPredOp>(
84-
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
85-
return op;
205+
// Integer types
206+
if (predRec.getName() == "AnyInteger") {
207+
auto op = builder.create<irdl::BaseOp>(
208+
UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
209+
return op.getOutput();
210+
}
211+
212+
if (predRec.isSubClassOf("AnyI")) {
213+
auto width = predRec.getValueAsInt("bitwidth");
214+
std::vector<Value> types = {
215+
typeToConstraint(builder, ctx,
216+
IntegerType::get(ctx, width, IntegerType::Signless)),
217+
typeToConstraint(builder, ctx,
218+
IntegerType::get(ctx, width, IntegerType::Signed)),
219+
typeToConstraint(builder, ctx,
220+
IntegerType::get(ctx, width, IntegerType::Unsigned))};
221+
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
222+
return op.getOutput();
223+
}
224+
225+
auto type = recordToType(ctx, predRec);
226+
227+
if (type.has_value()) {
228+
return typeToConstraint(builder, ctx, type.value());
229+
}
230+
231+
// Confined type
232+
if (predRec.isSubClassOf("ConfinedType")) {
233+
std::vector<Value> constraints;
234+
constraints.push_back(createConstraint(
235+
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
236+
for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
237+
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
238+
}
239+
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
240+
return op.getOutput();
241+
}
242+
243+
return createPredicate(builder, constraint.getPredicate());
86244
}
87245

88246
/// Returns the name of the operation without the dialect prefix.
@@ -131,10 +289,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
131289
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
132290

133291
// Create the operands and results operations.
134-
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
135-
operandVariadicity);
136-
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
137-
resultVariadicity);
292+
if (!operands.empty())
293+
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
294+
operandVariadicity);
295+
if (!results.empty())
296+
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
297+
resultVariadicity);
138298

139299
return op;
140300
}

0 commit comments

Comments
 (0)