@@ -39,6 +39,130 @@ llvm::cl::opt<std::string>
39
39
selectedDialect (" dialect" , llvm::cl::desc(" The dialect to gen for" ),
40
40
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
41
41
42
+ Value createPredicate (OpBuilder &builder, tblgen::Pred pred) {
43
+ MLIRContext *ctx = builder.getContext ();
44
+
45
+ if (pred.isCombined ()) {
46
+ auto combiner = pred.getDef ().getValueAsDef (" kind" )->getName ();
47
+ if (combiner == " PredCombinerAnd" || combiner == " PredCombinerOr" ) {
48
+ std::vector<Value> constraints;
49
+ for (auto *child : pred.getDef ().getValueAsListOfDefs (" children" )) {
50
+ constraints.push_back (createPredicate (builder, tblgen::Pred (child)));
51
+ }
52
+ if (combiner == " PredCombinerAnd" ) {
53
+ auto op =
54
+ builder.create <irdl::AllOfOp>(UnknownLoc::get (ctx), constraints);
55
+ return op.getOutput ();
56
+ }
57
+ auto op =
58
+ builder.create <irdl::AnyOfOp>(UnknownLoc::get (ctx), constraints);
59
+ return op.getOutput ();
60
+ }
61
+ }
62
+
63
+ std::string condition = pred.getCondition ();
64
+ // Build a CPredOp to match the C constraint built.
65
+ irdl::CPredOp op = builder.create <irdl::CPredOp>(
66
+ UnknownLoc::get (ctx), StringAttr::get (ctx, condition));
67
+ return op;
68
+ }
69
+
70
+ Value typeToConstraint (OpBuilder &builder, MLIRContext *ctx, Type type) {
71
+ auto op =
72
+ builder.create <irdl::IsOp>(UnknownLoc::get (ctx), TypeAttr::get (type));
73
+ return op.getOutput ();
74
+ }
75
+
76
+ std::optional<Type> recordToType (MLIRContext *ctx, const Record &predRec) {
77
+
78
+ if (predRec.isSubClassOf (" I" )) {
79
+ auto width = predRec.getValueAsInt (" bitwidth" );
80
+ return IntegerType::get (ctx, width, IntegerType::Signless);
81
+ }
82
+
83
+ if (predRec.isSubClassOf (" SI" )) {
84
+ auto width = predRec.getValueAsInt (" bitwidth" );
85
+ return IntegerType::get (ctx, width, IntegerType::Signed);
86
+ }
87
+
88
+ if (predRec.isSubClassOf (" UI" )) {
89
+ auto width = predRec.getValueAsInt (" bitwidth" );
90
+ return IntegerType::get (ctx, width, IntegerType::Unsigned);
91
+ }
92
+
93
+ // Index type
94
+ if (predRec.getName () == " Index" ) {
95
+ return IndexType::get (ctx);
96
+ }
97
+
98
+ // Float types
99
+ if (predRec.isSubClassOf (" F" )) {
100
+ auto width = predRec.getValueAsInt (" bitwidth" );
101
+ switch (width) {
102
+ case 16 :
103
+ return FloatType::getF16 (ctx);
104
+ case 32 :
105
+ return FloatType::getF32 (ctx);
106
+ case 64 :
107
+ return FloatType::getF64 (ctx);
108
+ case 80 :
109
+ return FloatType::getF80 (ctx);
110
+ case 128 :
111
+ return FloatType::getF128 (ctx);
112
+ }
113
+ }
114
+
115
+ if (predRec.getName () == " NoneType" ) {
116
+ return NoneType::get (ctx);
117
+ }
118
+
119
+ if (predRec.getName () == " BF16" ) {
120
+ return FloatType::getBF16 (ctx);
121
+ }
122
+
123
+ if (predRec.getName () == " TF32" ) {
124
+ return FloatType::getTF32 (ctx);
125
+ }
126
+
127
+ if (predRec.getName () == " F8E4M3FN" ) {
128
+ return FloatType::getFloat8E4M3FN (ctx);
129
+ }
130
+
131
+ if (predRec.getName () == " F8E5M2" ) {
132
+ return FloatType::getFloat8E5M2 (ctx);
133
+ }
134
+
135
+ if (predRec.getName () == " F8E4M3" ) {
136
+ return FloatType::getFloat8E4M3 (ctx);
137
+ }
138
+
139
+ if (predRec.getName () == " F8E4M3FNUZ" ) {
140
+ return FloatType::getFloat8E4M3FNUZ (ctx);
141
+ }
142
+
143
+ if (predRec.getName () == " F8E4M3B11FNUZ" ) {
144
+ return FloatType::getFloat8E4M3B11FNUZ (ctx);
145
+ }
146
+
147
+ if (predRec.getName () == " F8E5M2FNUZ" ) {
148
+ return FloatType::getFloat8E5M2FNUZ (ctx);
149
+ }
150
+
151
+ if (predRec.getName () == " F8E3M4" ) {
152
+ return FloatType::getFloat8E3M4 (ctx);
153
+ }
154
+
155
+ if (predRec.isSubClassOf (" Complex" )) {
156
+ const Record *elementRec = predRec.getValueAsDef (" elementType" );
157
+ auto elementType = recordToType (ctx, *elementRec);
158
+ if (elementType.has_value ()) {
159
+ return ComplexType::get (elementType.value ());
160
+ }
161
+ }
162
+
163
+ return std::nullopt;
164
+ }
165
+
42
166
Value createConstraint (OpBuilder &builder, tblgen::Constraint constraint) {
43
167
MLIRContext *ctx = builder.getContext ();
44
168
const Record &predRec = constraint.getDef ();
@@ -78,11 +202,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
78
202
return op.getOutput ();
79
203
}
80
204
81
- std::string condition = constraint.getPredicate ().getCondition ();
82
- // Build a CPredOp to match the C constraint built.
83
- irdl::CPredOp op = builder.create <irdl::CPredOp>(
84
- UnknownLoc::get (ctx), StringAttr::get (ctx, condition));
85
- return op;
205
+ // Integer types
206
+ if (predRec.getName () == " AnyInteger" ) {
207
+ auto op = builder.create <irdl::BaseOp>(
208
+ UnknownLoc::get (ctx), StringAttr::get (ctx, " !builtin.integer" ));
209
+ return op.getOutput ();
210
+ }
211
+
212
+ if (predRec.isSubClassOf (" AnyI" )) {
213
+ auto width = predRec.getValueAsInt (" bitwidth" );
214
+ std::vector<Value> types = {
215
+ typeToConstraint (builder, ctx,
216
+ IntegerType::get (ctx, width, IntegerType::Signless)),
217
+ typeToConstraint (builder, ctx,
218
+ IntegerType::get (ctx, width, IntegerType::Signed)),
219
+ typeToConstraint (builder, ctx,
220
+ IntegerType::get (ctx, width, IntegerType::Unsigned))};
221
+ auto op = builder.create <irdl::AnyOfOp>(UnknownLoc::get (ctx), types);
222
+ return op.getOutput ();
223
+ }
224
+
225
+ auto type = recordToType (ctx, predRec);
226
+
227
+ if (type.has_value ()) {
228
+ return typeToConstraint (builder, ctx, type.value ());
229
+ }
230
+
231
+ // Confined type
232
+ if (predRec.isSubClassOf (" ConfinedType" )) {
233
+ std::vector<Value> constraints;
234
+ constraints.push_back (createConstraint (
235
+ builder, tblgen::Constraint (predRec.getValueAsDef (" baseType" ))));
236
+ for (Record *child : predRec.getValueAsListOfDefs (" predicateList" )) {
237
+ constraints.push_back (createPredicate (builder, tblgen::Pred (child)));
238
+ }
239
+ auto op = builder.create <irdl::AllOfOp>(UnknownLoc::get (ctx), constraints);
240
+ return op.getOutput ();
241
+ }
242
+
243
+ return createPredicate (builder, constraint.getPredicate ());
86
244
}
87
245
88
246
// / Returns the name of the operation without the dialect prefix.
@@ -131,10 +289,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
131
289
auto [results, resultVariadicity] = getValues (tblgenOp.getResults ());
132
290
133
291
// Create the operands and results operations.
134
- consBuilder.create <irdl::OperandsOp>(UnknownLoc::get (ctx), operands,
135
- operandVariadicity);
136
- consBuilder.create <irdl::ResultsOp>(UnknownLoc::get (ctx), results,
137
- resultVariadicity);
292
+ if (!operands.empty ())
293
+ consBuilder.create <irdl::OperandsOp>(UnknownLoc::get (ctx), operands,
294
+ operandVariadicity);
295
+ if (!results.empty ())
296
+ consBuilder.create <irdl::ResultsOp>(UnknownLoc::get (ctx), results,
297
+ resultVariadicity);
138
298
139
299
return op;
140
300
}
0 commit comments