2
2
#include < cstdint>
3
3
#include < llvm/ADT/APFloat.h>
4
4
#include < llvm/ADT/APInt.h>
5
+ #include < llvm/ADT/APSInt.h>
5
6
#include < llvm/ADT/ArrayRef.h>
7
+ #include < llvm/ADT/TypeSwitch.h>
6
8
#include < llvm/Support/Casting.h>
7
9
#include < llvm/Support/ErrorHandling.h>
8
10
#include < mlir/Dialect/PDL/IR/Builtins.h>
@@ -56,52 +58,200 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
56
58
return rewriter.getArrayAttr (values);
57
59
}
58
60
59
- LogicalResult add (mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
60
- llvm::ArrayRef<mlir::PDLValue> args) {
61
- assert (args.size () == 2 && " Expected 2 arguments" );
61
+ template <UnaryOpKind T>
62
+ LogicalResult unaryOp (PatternRewriter &rewriter, PDLResultList &results,
63
+ ArrayRef<PDLValue> args) {
64
+ assert (args.size () == 1 && " Expected one operand for unary operation" );
65
+ auto operandAttr = args[0 ].cast <Attribute>();
66
+
67
+ if (auto operandIntAttr = dyn_cast_or_null<IntegerAttr>(operandAttr)) {
68
+ auto integerType = cast<IntegerType>(operandIntAttr.getType ());
69
+ auto bitWidth = integerType.getIntOrFloatBitWidth ();
70
+
71
+ if constexpr (T == UnaryOpKind::exp2) {
72
+ uint64_t resultVal =
73
+ integerType.isUnsigned () || integerType.isSignless ()
74
+ ? std::pow (2 , operandIntAttr.getValue ().getZExtValue ())
75
+ : std::pow (2 , operandIntAttr.getValue ().getSExtValue ());
76
+
77
+ APInt resultInt (bitWidth, resultVal, integerType.isSigned ());
78
+
79
+ bool isOverflow = integerType.isSigned ()
80
+ ? resultInt.slt (operandIntAttr.getValue ())
81
+ : resultInt.ult (operandIntAttr.getValue ());
82
+
83
+ if (isOverflow)
84
+ return failure ();
85
+
86
+ results.push_back (rewriter.getIntegerAttr (integerType, resultInt));
87
+ } else if constexpr (T == UnaryOpKind::log2) {
88
+ auto getIntegerAsAttr = [&](const APSInt &value) {
89
+ int32_t log2Value = value.exactLogBase2 ();
90
+ assert (log2Value >= 0 &&
91
+ " log2 of an integer is expected to return an exact integer." );
92
+ return rewriter.getIntegerAttr (
93
+ integerType,
94
+ APSInt (APInt (bitWidth, log2Value), integerType.isUnsigned ()));
95
+ };
96
+ // for log2 we treat signless integer as signed
97
+ if (integerType.isSignless ())
98
+ results.push_back (
99
+ getIntegerAsAttr (APSInt (operandIntAttr.getValue (), false )));
100
+ else
101
+ results.push_back (getIntegerAsAttr (operandIntAttr.getAPSInt ()));
102
+ } else {
103
+ llvm::llvm_unreachable_internal (
104
+ " encountered an unsupported unary operator" );
105
+ return failure ();
106
+ }
107
+ return success ();
108
+ }
109
+
110
+ if (auto operandFloatAttr = dyn_cast_or_null<FloatAttr>(operandAttr)) {
111
+ // auto floatType = operandFloatAttr.getType();
112
+
113
+ if constexpr (T == UnaryOpKind::exp2) {
114
+ // auto maxVal = APFloat::getLargest(llvm::APFloat::IEEEhalf());
115
+ // auto minVal = APFloat::getSmallest(llvm::APFloat::IEEEhalf());
116
+
117
+ auto type = operandFloatAttr.getType ();
118
+
119
+ return TypeSwitch<Type, LogicalResult>(type)
120
+ .template Case <Float64Type>([&results, &rewriter,
121
+ &operandFloatAttr](auto floatType) {
122
+ APFloat resultAPFloat (
123
+ std::exp2 (operandFloatAttr.getValue ().convertToDouble ()));
124
+
125
+ // check overflow
126
+ if (!resultAPFloat.isNormal ())
127
+ return failure ();
128
+
129
+ results.push_back (rewriter.getFloatAttr (floatType, resultAPFloat));
130
+ return success ();
131
+ })
132
+ .template Case <Float32Type, Float16Type, BFloat16Type>(
133
+ [&results, &rewriter, &operandFloatAttr](auto floatType) {
134
+ APFloat resultAPFloat (
135
+ std::exp2 (operandFloatAttr.getValue ().convertToFloat ()));
136
+
137
+ // check overflow and underflow
138
+ // If overflow happens, resultAPFloat is inf
139
+ // If underflow happens, resultAPFloat is 0
140
+ if (!resultAPFloat.isNormal ())
141
+ return failure ();
142
+
143
+ results.push_back (
144
+ rewriter.getFloatAttr (floatType, resultAPFloat));
145
+ return success ();
146
+ })
147
+ .Default ([](Type /* type*/ ) { return failure (); });
148
+ } else if constexpr (T == UnaryOpKind::log2) {
149
+ auto minF32 = APFloat::getSmallest (llvm::APFloat::IEEEsingle ());
150
+
151
+ APFloat resultFloat ((float )operandFloatAttr.getValue ().getExactLog2 ());
152
+ results.push_back (
153
+ rewriter.getFloatAttr (operandFloatAttr.getType (), resultFloat));
154
+ }
155
+ return success ();
156
+ }
157
+ return failure ();
158
+ }
159
+
160
+ template <BinaryOpKind T>
161
+ LogicalResult binaryOp (PatternRewriter &rewriter, PDLResultList &results,
162
+ llvm::ArrayRef<PDLValue> args) {
163
+ assert (args.size () == 2 && " Expected two operands for binary operation" );
62
164
auto lhsAttr = args[0 ].cast <Attribute>();
63
165
auto rhsAttr = args[1 ].cast <Attribute>();
64
166
65
- // Integer
66
167
if (auto lhsIntAttr = dyn_cast_or_null<IntegerAttr>(lhsAttr)) {
67
168
auto rhsIntAttr = dyn_cast_or_null<IntegerAttr>(rhsAttr);
68
- if (!rhsIntAttr || lhsIntAttr.getType () != rhsIntAttr.getType ())
169
+ if (!rhsIntAttr || lhsIntAttr.getType () != rhsIntAttr.getType ()) {
69
170
return failure ();
171
+ }
70
172
71
173
auto integerType = lhsIntAttr.getType ();
72
-
73
- bool isOverflow;
74
- llvm::APInt resultAPInt;
75
- if (integerType.isUnsignedInteger () || integerType.isSignlessInteger ()) {
76
- resultAPInt =
77
- lhsIntAttr.getValue ().uadd_ov (rhsIntAttr.getValue (), isOverflow);
174
+ APInt resultAPInt;
175
+ bool isOverflow = false ;
176
+ if constexpr (T == BinaryOpKind::add) {
177
+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
178
+ resultAPInt =
179
+ lhsIntAttr.getValue ().uadd_ov (rhsIntAttr.getValue (), isOverflow);
180
+ } else {
181
+ resultAPInt =
182
+ lhsIntAttr.getValue ().sadd_ov (rhsIntAttr.getValue (), isOverflow);
183
+ }
184
+ } else if constexpr (T == BinaryOpKind::sub) {
185
+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
186
+ resultAPInt =
187
+ lhsIntAttr.getValue ().usub_ov (rhsIntAttr.getValue (), isOverflow);
188
+ } else {
189
+ resultAPInt =
190
+ lhsIntAttr.getValue ().ssub_ov (rhsIntAttr.getValue (), isOverflow);
191
+ }
192
+ } else if constexpr (T == BinaryOpKind::mul) {
193
+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
194
+ resultAPInt =
195
+ lhsIntAttr.getValue ().umul_ov (rhsIntAttr.getValue (), isOverflow);
196
+ } else {
197
+ resultAPInt =
198
+ lhsIntAttr.getValue ().smul_ov (rhsIntAttr.getValue (), isOverflow);
199
+ }
200
+ } else if constexpr (T == BinaryOpKind::div) {
201
+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
202
+ resultAPInt = lhsIntAttr.getValue ().udiv (rhsIntAttr.getValue ());
203
+ } else {
204
+ resultAPInt =
205
+ lhsIntAttr.getValue ().sdiv_ov (rhsIntAttr.getValue (), isOverflow);
206
+ }
207
+ } else if constexpr (T == BinaryOpKind::mod) {
208
+ if (integerType.isSignlessInteger () || integerType.isUnsignedInteger ()) {
209
+ resultAPInt = lhsIntAttr.getValue ().urem (rhsIntAttr.getValue ());
210
+ } else {
211
+ resultAPInt = lhsIntAttr.getValue ().srem (rhsIntAttr.getValue ());
212
+ }
78
213
} else {
79
- resultAPInt =
80
- lhsIntAttr.getValue ().sadd_ov (rhsIntAttr.getValue (), isOverflow);
214
+ assert (false && " Unsupported binary operator" );
81
215
}
82
216
83
- if (isOverflow) {
217
+ if (isOverflow)
84
218
return failure ();
85
- }
86
219
87
220
results.push_back (rewriter.getIntegerAttr (integerType, resultAPInt));
88
221
return success ();
89
222
}
90
223
91
- // Float
92
224
if (auto lhsFloatAttr = dyn_cast_or_null<FloatAttr>(lhsAttr)) {
93
225
auto rhsFloatAttr = dyn_cast_or_null<FloatAttr>(rhsAttr);
94
- if (!rhsFloatAttr || lhsFloatAttr.getType () != rhsFloatAttr.getType ())
226
+ if (!rhsFloatAttr || lhsFloatAttr.getType () != rhsFloatAttr.getType ()) {
95
227
return failure ();
228
+ }
96
229
97
230
APFloat lhsVal = lhsFloatAttr.getValue ();
98
231
APFloat rhsVal = rhsFloatAttr.getValue ();
99
232
APFloat resultVal (lhsVal);
100
233
auto floatType = lhsFloatAttr.getType ();
101
234
102
- bool isOverflow =
103
- resultVal.add (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
104
- if (isOverflow) {
235
+ APFloat::opStatus operationStatus;
236
+ if constexpr (T == BinaryOpKind::add) {
237
+ operationStatus =
238
+ resultVal.add (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
239
+ } else if constexpr (T == BinaryOpKind::sub) {
240
+ operationStatus =
241
+ resultVal.subtract (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
242
+ } else if constexpr (T == BinaryOpKind::mul) {
243
+ operationStatus =
244
+ resultVal.multiply (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
245
+ } else if constexpr (T == BinaryOpKind::div) {
246
+ operationStatus =
247
+ resultVal.divide (rhsVal, llvm::APFloatBase::rmNearestTiesToEven);
248
+ } else if constexpr (T == BinaryOpKind::mod) {
249
+ operationStatus = resultVal.mod (rhsVal);
250
+ } else {
251
+ assert (false && " Unsupported binary operator" );
252
+ }
253
+
254
+ if (operationStatus != APFloat::opOK) {
105
255
return failure ();
106
256
}
107
257
@@ -110,6 +260,41 @@ LogicalResult add(mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
110
260
}
111
261
return failure ();
112
262
}
263
+
264
+ LogicalResult add (mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
265
+ llvm::ArrayRef<mlir::PDLValue> args) {
266
+ return binaryOp<BinaryOpKind::add>(rewriter, results, args);
267
+ }
268
+
269
+ LogicalResult sub (mlir::PatternRewriter &rewriter, mlir::PDLResultList &results,
270
+ llvm::ArrayRef<mlir::PDLValue> args) {
271
+ return binaryOp<BinaryOpKind::sub>(rewriter, results, args);
272
+ }
273
+
274
+ LogicalResult mul (PatternRewriter &rewriter, PDLResultList &results,
275
+ llvm::ArrayRef<PDLValue> args) {
276
+ return binaryOp<BinaryOpKind::mul>(rewriter, results, args);
277
+ }
278
+
279
+ LogicalResult div (PatternRewriter &rewriter, PDLResultList &results,
280
+ llvm::ArrayRef<PDLValue> args) {
281
+ return binaryOp<BinaryOpKind::div>(rewriter, results, args);
282
+ }
283
+
284
+ LogicalResult mod (PatternRewriter &rewriter, PDLResultList &results,
285
+ ArrayRef<PDLValue> args) {
286
+ return binaryOp<BinaryOpKind::mod>(rewriter, results, args);
287
+ }
288
+
289
+ LogicalResult exp2 (PatternRewriter &rewriter, PDLResultList &results,
290
+ llvm::ArrayRef<PDLValue> args) {
291
+ return unaryOp<UnaryOpKind::exp2>(rewriter, results, args);
292
+ }
293
+
294
+ LogicalResult log2 (PatternRewriter &rewriter, PDLResultList &results,
295
+ llvm::ArrayRef<PDLValue> args) {
296
+ return unaryOp<UnaryOpKind::log2>(rewriter, results, args);
297
+ }
113
298
} // namespace builtin
114
299
115
300
void registerBuiltins (PDLPatternModule &pdlPattern) {
@@ -128,6 +313,27 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
128
313
pdlPattern.registerConstraintFunctionWithResults (
129
314
" __builtin_addEntryToDictionaryAttr_constraint" ,
130
315
addEntryToDictionaryAttr);
131
- pdlPattern.registerConstraintFunctionWithResults (" __builtin_add" , add);
316
+ pdlPattern.registerRewriteFunction (" __builtin_mulRewrite" , mul);
317
+ pdlPattern.registerRewriteFunction (" __builtin_divRewrite" , div);
318
+ pdlPattern.registerRewriteFunction (" __builtin_modRewrite" , mod);
319
+ pdlPattern.registerRewriteFunction (" __builtin_addRewrite" , add);
320
+ pdlPattern.registerRewriteFunction (" __builtin_subRewrite" , sub);
321
+ pdlPattern.registerRewriteFunction (" __builtin_log2Rewrite" , log2);
322
+ pdlPattern.registerRewriteFunction (" __builtin_exp2Rewrite" , exp2);
323
+
324
+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_mulConstraint" ,
325
+ mul);
326
+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_divConstraint" ,
327
+ div);
328
+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_modConstraint" ,
329
+ mod);
330
+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_addConstraint" ,
331
+ add);
332
+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_subConstraint" ,
333
+ sub);
334
+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_log2Constraint" ,
335
+ log2);
336
+ pdlPattern.registerConstraintFunctionWithResults (" __builtin_exp2Constraint" ,
337
+ exp2);
132
338
}
133
- } // namespace mlir::pdl
339
+ } // namespace mlir::pdl
0 commit comments