@@ -39,6 +39,131 @@ llvm::cl::opt<std::string>
39
39
selectedDialect (" dialect" , llvm::cl::desc(" The dialect to gen for" ),
40
40
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
41
41
42
+ Value createPredicate (OpBuilder &builder, tblgen::Pred pred) {
43
+ MLIRContext *ctx = builder.getContext ();
44
+
45
+ if (pred.isCombined ()) {
46
+ auto combiner = pred.getDef ().getValueAsDef (" kind" )->getName ();
47
+ if (combiner == " PredCombinerAnd" || combiner == " PredCombinerOr" ) {
48
+ std::vector<Value> constraints;
49
+ for (auto *child : pred.getDef ().getValueAsListOfDefs (" children" )) {
50
+ constraints.push_back (createPredicate (builder, tblgen::Pred (child)));
51
+ }
52
+ if (combiner == " PredCombinerAnd" ) {
53
+ auto op =
54
+ builder.create <irdl::AllOfOp>(UnknownLoc::get (ctx), constraints);
55
+ return op.getOutput ();
56
+ }
57
+ auto op =
58
+ builder.create <irdl::AnyOfOp>(UnknownLoc::get (ctx), constraints);
59
+ return op.getOutput ();
60
+ }
61
+ }
62
+
63
+ std::string condition = pred.getCondition ();
64
+ // Build a CPredOp to match the C constraint built.
65
+ irdl::CPredOp op = builder.create <irdl::CPredOp>(
66
+ UnknownLoc::get (ctx), StringAttr::get (ctx, condition));
67
+ return op;
68
+ }
69
+
70
+ Value typeToConstraint (OpBuilder &builder, Type type) {
71
+ MLIRContext *ctx = builder.getContext ();
72
+ auto op =
73
+ builder.create <irdl::IsOp>(UnknownLoc::get (ctx), TypeAttr::get (type));
74
+ return op.getOutput ();
75
+ }
76
+
77
+ std::optional<Type> recordToType (MLIRContext *ctx, const Record &predRec) {
78
+
79
+ if (predRec.isSubClassOf (" I" )) {
80
+ auto width = predRec.getValueAsInt (" bitwidth" );
81
+ return IntegerType::get (ctx, width, IntegerType::Signless);
82
+ }
83
+
84
+ if (predRec.isSubClassOf (" SI" )) {
85
+ auto width = predRec.getValueAsInt (" bitwidth" );
86
+ return IntegerType::get (ctx, width, IntegerType::Signed);
87
+ }
88
+
89
+ if (predRec.isSubClassOf (" UI" )) {
90
+ auto width = predRec.getValueAsInt (" bitwidth" );
91
+ return IntegerType::get (ctx, width, IntegerType::Unsigned);
92
+ }
93
+
94
+ // Index type
95
+ if (predRec.getName () == " Index" ) {
96
+ return IndexType::get (ctx);
97
+ }
98
+
99
+ // Float types
100
+ if (predRec.isSubClassOf (" F" )) {
101
+ auto width = predRec.getValueAsInt (" bitwidth" );
102
+ switch (width) {
103
+ case 16 :
104
+ return FloatType::getF16 (ctx);
105
+ case 32 :
106
+ return FloatType::getF32 (ctx);
107
+ case 64 :
108
+ return FloatType::getF64 (ctx);
109
+ case 80 :
110
+ return FloatType::getF80 (ctx);
111
+ case 128 :
112
+ return FloatType::getF128 (ctx);
113
+ }
114
+ }
115
+
116
+ if (predRec.getName () == " NoneType" ) {
117
+ return NoneType::get (ctx);
118
+ }
119
+
120
+ if (predRec.getName () == " BF16" ) {
121
+ return FloatType::getBF16 (ctx);
122
+ }
123
+
124
+ if (predRec.getName () == " TF32" ) {
125
+ return FloatType::getTF32 (ctx);
126
+ }
127
+
128
+ if (predRec.getName () == " F8E4M3FN" ) {
129
+ return FloatType::getFloat8E4M3FN (ctx);
130
+ }
131
+
132
+ if (predRec.getName () == " F8E5M2" ) {
133
+ return FloatType::getFloat8E5M2 (ctx);
134
+ }
135
+
136
+ if (predRec.getName () == " F8E4M3" ) {
137
+ return FloatType::getFloat8E4M3 (ctx);
138
+ }
139
+
140
+ if (predRec.getName () == " F8E4M3FNUZ" ) {
141
+ return FloatType::getFloat8E4M3FNUZ (ctx);
142
+ }
143
+
144
+ if (predRec.getName () == " F8E4M3B11FNUZ" ) {
145
+ return FloatType::getFloat8E4M3B11FNUZ (ctx);
146
+ }
147
+
148
+ if (predRec.getName () == " F8E5M2FNUZ" ) {
149
+ return FloatType::getFloat8E5M2FNUZ (ctx);
150
+ }
151
+
152
+ if (predRec.getName () == " F8E3M4" ) {
153
+ return FloatType::getFloat8E3M4 (ctx);
154
+ }
155
+
156
+ if (predRec.isSubClassOf (" Complex" )) {
157
+ const Record *elementRec = predRec.getValueAsDef (" elementType" );
158
+ auto elementType = recordToType (ctx, *elementRec);
159
+ if (elementType.has_value ()) {
160
+ return ComplexType::get (elementType.value ());
161
+ }
162
+ }
163
+
164
+ return std::nullopt;
165
+ }
166
+
42
167
Value createConstraint (OpBuilder &builder, tblgen::Constraint constraint) {
43
168
MLIRContext *ctx = builder.getContext ();
44
169
const Record &predRec = constraint.getDef ();
@@ -78,11 +203,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
78
203
return op.getOutput ();
79
204
}
80
205
81
- std::string condition = constraint.getPredicate ().getCondition ();
82
- // Build a CPredOp to match the C constraint built.
83
- irdl::CPredOp op = builder.create <irdl::CPredOp>(
84
- UnknownLoc::get (ctx), StringAttr::get (ctx, condition));
85
- return op;
206
+ // Integer types
207
+ if (predRec.getName () == " AnyInteger" ) {
208
+ auto op = builder.create <irdl::BaseOp>(
209
+ UnknownLoc::get (ctx), StringAttr::get (ctx, " !builtin.integer" ));
210
+ return op.getOutput ();
211
+ }
212
+
213
+ if (predRec.isSubClassOf (" AnyI" )) {
214
+ auto width = predRec.getValueAsInt (" bitwidth" );
215
+ std::vector<Value> types = {
216
+ typeToConstraint (builder,
217
+ IntegerType::get (ctx, width, IntegerType::Signless)),
218
+ typeToConstraint (builder,
219
+ IntegerType::get (ctx, width, IntegerType::Signed)),
220
+ typeToConstraint (builder,
221
+ IntegerType::get (ctx, width, IntegerType::Unsigned))};
222
+ auto op = builder.create <irdl::AnyOfOp>(UnknownLoc::get (ctx), types);
223
+ return op.getOutput ();
224
+ }
225
+
226
+ auto type = recordToType (ctx, predRec);
227
+
228
+ if (type.has_value ()) {
229
+ return typeToConstraint (builder, type.value ());
230
+ }
231
+
232
+ // Confined type
233
+ if (predRec.isSubClassOf (" ConfinedType" )) {
234
+ std::vector<Value> constraints;
235
+ constraints.push_back (createConstraint (
236
+ builder, tblgen::Constraint (predRec.getValueAsDef (" baseType" ))));
237
+ for (Record *child : predRec.getValueAsListOfDefs (" predicateList" )) {
238
+ constraints.push_back (createPredicate (builder, tblgen::Pred (child)));
239
+ }
240
+ auto op = builder.create <irdl::AllOfOp>(UnknownLoc::get (ctx), constraints);
241
+ return op.getOutput ();
242
+ }
243
+
244
+ return createPredicate (builder, constraint.getPredicate ());
86
245
}
87
246
88
247
// / Returns the name of the operation without the dialect prefix.
@@ -131,10 +290,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
131
290
auto [results, resultVariadicity] = getValues (tblgenOp.getResults ());
132
291
133
292
// Create the operands and results operations.
134
- consBuilder.create <irdl::OperandsOp>(UnknownLoc::get (ctx), operands,
135
- operandVariadicity);
136
- consBuilder.create <irdl::ResultsOp>(UnknownLoc::get (ctx), results,
137
- resultVariadicity);
293
+ if (!operands.empty ())
294
+ consBuilder.create <irdl::OperandsOp>(UnknownLoc::get (ctx), operands,
295
+ operandVariadicity);
296
+ if (!results.empty ())
297
+ consBuilder.create <irdl::ResultsOp>(UnknownLoc::get (ctx), results,
298
+ resultVariadicity);
138
299
139
300
return op;
140
301
}
0 commit comments