@@ -26,6 +26,264 @@ void IndexDialect::registerOperations() {
26
26
>();
27
27
}
28
28
29
+ Operation *IndexDialect::materializeConstant (OpBuilder &b, Attribute value,
30
+ Type type, Location loc) {
31
+ // Materialize bool constants as `i1`.
32
+ if (auto boolValue = dyn_cast<BoolAttr>(value)) {
33
+ if (!type.isSignlessInteger (1 ))
34
+ return nullptr ;
35
+ return b.create <BoolConstantOp>(loc, type, boolValue);
36
+ }
37
+
38
+ // Materialize integer attributes as `index`.
39
+ if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
40
+ if (!indexValue.getType ().isa <IndexType>() || !type.isa <IndexType>())
41
+ return nullptr ;
42
+ assert (indexValue.getValue ().getBitWidth () ==
43
+ IndexType::kInternalStorageBitWidth );
44
+ return b.create <ConstantOp>(loc, indexValue);
45
+ }
46
+
47
+ return nullptr ;
48
+ }
49
+
50
+ // ===----------------------------------------------------------------------===//
51
+ // Fold Utilities
52
+ // ===----------------------------------------------------------------------===//
53
+
54
+ // / Fold an index operation irrespective of the target bitwidth. The
55
+ // / operation must satisfy the property:
56
+ // /
57
+ // / ```
58
+ // / trunc(f(a, b)) = f(trunc(a), trunc(b))
59
+ // / ```
60
+ // /
61
+ // / For all values of `a` and `b`. The function accepts a lambda that computes
62
+ // / the integer result, which in turn must satisfy the above property.
63
+ static OpFoldResult foldBinaryOpUnchecked (
64
+ ArrayRef<Attribute> operands,
65
+ function_ref<APInt(const APInt &, const APInt &)> calculate) {
66
+ assert (operands.size () == 2 && " binary operation expected 2 operands" );
67
+ auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0 ]);
68
+ auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1 ]);
69
+ if (!lhs || !rhs)
70
+ return {};
71
+
72
+ APInt result = calculate (lhs.getValue (), rhs.getValue ());
73
+ assert (result.trunc (32 ) ==
74
+ calculate (lhs.getValue ().trunc (32 ), rhs.getValue ().trunc (32 )));
75
+ return IntegerAttr::get (IndexType::get (lhs.getContext ()), std::move (result));
76
+ }
77
+
78
+ // / Fold an index operation only if the truncated 64-bit result matches the
79
+ // / 32-bit result for operations that don't satisfy the above property. These
80
+ // / are operations where the upper bits of the operands can affect the lower
81
+ // / bits of the results.
82
+ // /
83
+ // / The function accepts a lambda that computes the integer result in both
84
+ // / 64-bit and 32-bit. If either call returns `None`, the operation is not
85
+ // / folded.
86
+ static OpFoldResult foldBinaryOpChecked (
87
+ ArrayRef<Attribute> operands,
88
+ function_ref<Optional<APInt>(const APInt &, const APInt &lhs)> calculate) {
89
+ assert (operands.size () == 2 && " binary operation expected 2 operands" );
90
+ auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0 ]);
91
+ auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1 ]);
92
+ // Only fold index operands.
93
+ if (!lhs || !rhs)
94
+ return {};
95
+
96
+ // Compute the 64-bit result and the 32-bit result.
97
+ Optional<APInt> result64 = calculate (lhs.getValue (), rhs.getValue ());
98
+ if (!result64)
99
+ return {};
100
+ Optional<APInt> result32 =
101
+ calculate (lhs.getValue ().trunc (32 ), rhs.getValue ().trunc (32 ));
102
+ if (!result32)
103
+ return {};
104
+ // Compare the truncated 64-bit result to the 32-bit result.
105
+ if (result64->trunc (32 ) != *result32)
106
+ return {};
107
+ // The operation can be folded for these particular operands.
108
+ return IntegerAttr::get (IndexType::get (lhs.getContext ()),
109
+ std::move (*result64));
110
+ }
111
+
112
+ // ===----------------------------------------------------------------------===//
113
+ // AddOp
114
+ // ===----------------------------------------------------------------------===//
115
+
116
+ OpFoldResult AddOp::fold (ArrayRef<Attribute> operands) {
117
+ return foldBinaryOpUnchecked (
118
+ operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
119
+ }
120
+
121
+ // ===----------------------------------------------------------------------===//
122
+ // SubOp
123
+ // ===----------------------------------------------------------------------===//
124
+
125
+ OpFoldResult SubOp::fold (ArrayRef<Attribute> operands) {
126
+ return foldBinaryOpUnchecked (
127
+ operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
128
+ }
129
+
130
+ // ===----------------------------------------------------------------------===//
131
+ // MulOp
132
+ // ===----------------------------------------------------------------------===//
133
+
134
+ OpFoldResult MulOp::fold (ArrayRef<Attribute> operands) {
135
+ return foldBinaryOpUnchecked (
136
+ operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
137
+ }
138
+
139
+ // ===----------------------------------------------------------------------===//
140
+ // DivSOp
141
+ // ===----------------------------------------------------------------------===//
142
+
143
+ OpFoldResult DivSOp::fold (ArrayRef<Attribute> operands) {
144
+ return foldBinaryOpChecked (
145
+ operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
146
+ // Don't fold division by zero.
147
+ if (rhs.isZero ())
148
+ return None;
149
+ return lhs.sdiv (rhs);
150
+ });
151
+ }
152
+
153
+ // ===----------------------------------------------------------------------===//
154
+ // DivUOp
155
+ // ===----------------------------------------------------------------------===//
156
+
157
+ OpFoldResult DivUOp::fold (ArrayRef<Attribute> operands) {
158
+ return foldBinaryOpChecked (
159
+ operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
160
+ // Don't fold division by zero.
161
+ if (rhs.isZero ())
162
+ return None;
163
+ return lhs.udiv (rhs);
164
+ });
165
+ }
166
+
167
+ // ===----------------------------------------------------------------------===//
168
+ // CeilDivSOp
169
+ // ===----------------------------------------------------------------------===//
170
+
171
+ // / Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
172
+ // / `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
173
+ static Optional<APInt> calculateCeilDivS (const APInt &n, const APInt &m) {
174
+ // Don't fold division by zero.
175
+ if (m.isZero ())
176
+ return None;
177
+ // Short-circuit the zero case.
178
+ if (n.isZero ())
179
+ return n;
180
+
181
+ bool mGtZ = m.sgt (0 );
182
+ if (n.sgt (0 ) != mGtZ ) {
183
+ // If the operands have different signs, compute the negative result. Signed
184
+ // division overflow is not possible, since if `m == -1`, `n` can be at most
185
+ // `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
186
+ return -(-n).sdiv (m);
187
+ }
188
+ // Otherwise, compute the positive result. Signed division overflow is not
189
+ // possible since if `m == -1`, `x` will be `1`.
190
+ int64_t x = mGtZ ? -1 : 1 ;
191
+ return (n + x).sdiv (m) + 1 ;
192
+ }
193
+
194
+ OpFoldResult CeilDivSOp::fold (ArrayRef<Attribute> operands) {
195
+ return foldBinaryOpChecked (operands, calculateCeilDivS);
196
+ }
197
+
198
+ // ===----------------------------------------------------------------------===//
199
+ // CeilDivUOp
200
+ // ===----------------------------------------------------------------------===//
201
+
202
+ OpFoldResult CeilDivUOp::fold (ArrayRef<Attribute> operands) {
203
+ // Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
204
+ return foldBinaryOpChecked (
205
+ operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
206
+ // Don't fold division by zero.
207
+ if (m.isZero ())
208
+ return None;
209
+ // Short-circuit the zero case.
210
+ if (n.isZero ())
211
+ return n;
212
+
213
+ return (n - 1 ).udiv (m) + 1 ;
214
+ });
215
+ }
216
+
217
+ // ===----------------------------------------------------------------------===//
218
+ // FloorDivSOp
219
+ // ===----------------------------------------------------------------------===//
220
+
221
+ // / Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
222
+ // / `n*m < 0 ? -1 - (x-n)/m : n/m`.
223
+ static Optional<APInt> calculateFloorDivS (const APInt &n, const APInt &m) {
224
+ // Don't fold division by zero.
225
+ if (m.isZero ())
226
+ return None;
227
+ // Short-circuit the zero case.
228
+ if (n.isZero ())
229
+ return n;
230
+
231
+ bool mLtZ = m.slt (0 );
232
+ if (n.slt (0 ) == mLtZ ) {
233
+ // If the operands have the same sign, compute the positive result.
234
+ return n.sdiv (m);
235
+ }
236
+ // If the operands have different signs, compute the negative result. Signed
237
+ // division overflow is not possible since if `m == -1`, `x` will be 1 and
238
+ // `n` can be at most `INT_MAX`.
239
+ int64_t x = mLtZ ? 1 : -1 ;
240
+ return -1 - (x - n).sdiv (m);
241
+ }
242
+
243
+ OpFoldResult FloorDivSOp::fold (ArrayRef<Attribute> operands) {
244
+ return foldBinaryOpChecked (operands, calculateFloorDivS);
245
+ }
246
+
247
+ // ===----------------------------------------------------------------------===//
248
+ // RemSOp
249
+ // ===----------------------------------------------------------------------===//
250
+
251
+ OpFoldResult RemSOp::fold (ArrayRef<Attribute> operands) {
252
+ return foldBinaryOpChecked (operands, [](const APInt &lhs, const APInt &rhs) {
253
+ return lhs.srem (rhs);
254
+ });
255
+ }
256
+
257
+ // ===----------------------------------------------------------------------===//
258
+ // RemUOp
259
+ // ===----------------------------------------------------------------------===//
260
+
261
+ OpFoldResult RemUOp::fold (ArrayRef<Attribute> operands) {
262
+ return foldBinaryOpChecked (operands, [](const APInt &lhs, const APInt &rhs) {
263
+ return lhs.urem (rhs);
264
+ });
265
+ }
266
+
267
+ // ===----------------------------------------------------------------------===//
268
+ // MaxSOp
269
+ // ===----------------------------------------------------------------------===//
270
+
271
+ OpFoldResult MaxSOp::fold (ArrayRef<Attribute> operands) {
272
+ return foldBinaryOpChecked (operands, [](const APInt &lhs, const APInt &rhs) {
273
+ return lhs.sgt (rhs) ? lhs : rhs;
274
+ });
275
+ }
276
+
277
+ // ===----------------------------------------------------------------------===//
278
+ // MaxUOp
279
+ // ===----------------------------------------------------------------------===//
280
+
281
+ OpFoldResult MaxUOp::fold (ArrayRef<Attribute> operands) {
282
+ return foldBinaryOpChecked (operands, [](const APInt &lhs, const APInt &rhs) {
283
+ return lhs.ugt (rhs) ? lhs : rhs;
284
+ });
285
+ }
286
+
29
287
// ===----------------------------------------------------------------------===//
30
288
// CastSOp
31
289
// ===----------------------------------------------------------------------===//
@@ -42,6 +300,74 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
42
300
return lhsTypes.front ().isa <IndexType>() != rhsTypes.front ().isa <IndexType>();
43
301
}
44
302
303
+ // ===----------------------------------------------------------------------===//
304
+ // CmpOp
305
+ // ===----------------------------------------------------------------------===//
306
+
307
+ // / Compare two integers according to the comparison predicate.
308
+ bool compareIndices (const APInt &lhs, const APInt &rhs,
309
+ IndexCmpPredicate pred) {
310
+ switch (pred) {
311
+ case IndexCmpPredicate::EQ:
312
+ return lhs.eq (rhs);
313
+ case IndexCmpPredicate::NE:
314
+ return lhs.ne (rhs);
315
+ case IndexCmpPredicate::SGE:
316
+ return lhs.sge (rhs);
317
+ case IndexCmpPredicate::SGT:
318
+ return lhs.sgt (rhs);
319
+ case IndexCmpPredicate::SLE:
320
+ return lhs.sle (rhs);
321
+ case IndexCmpPredicate::SLT:
322
+ return lhs.slt (rhs);
323
+ case IndexCmpPredicate::UGE:
324
+ return lhs.uge (rhs);
325
+ case IndexCmpPredicate::UGT:
326
+ return lhs.ugt (rhs);
327
+ case IndexCmpPredicate::ULE:
328
+ return lhs.ule (rhs);
329
+ case IndexCmpPredicate::ULT:
330
+ return lhs.ult (rhs);
331
+ }
332
+ llvm_unreachable (" unhandled IndexCmpPredicate predicate" );
333
+ }
334
+
335
+ OpFoldResult CmpOp::fold (ArrayRef<Attribute> operands) {
336
+ assert (operands.size () == 2 && " compare expected 2 operands" );
337
+ auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0 ]);
338
+ auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1 ]);
339
+ if (!lhs || !rhs)
340
+ return {};
341
+
342
+ // Perform the comparison in 64-bit and 32-bit.
343
+ bool result64 = compareIndices (lhs.getValue (), rhs.getValue (), getPred ());
344
+ bool result32 = compareIndices (lhs.getValue ().trunc (32 ),
345
+ rhs.getValue ().trunc (32 ), getPred ());
346
+ if (result64 != result32)
347
+ return {};
348
+ return BoolAttr::get (getContext (), result64);
349
+ }
350
+
351
+ // ===----------------------------------------------------------------------===//
352
+ // ConstantOp
353
+ // ===----------------------------------------------------------------------===//
354
+
355
+ OpFoldResult ConstantOp::fold (ArrayRef<Attribute> operands) {
356
+ return getValueAttr ();
357
+ }
358
+
359
+ void ConstantOp::build (OpBuilder &b, OperationState &state, int64_t value) {
360
+ build (b, state, b.getIndexType (), b.getIndexAttr (value));
361
+ }
362
+
363
+ // ===----------------------------------------------------------------------===//
364
+ // BoolConstantOp
365
+ // ===----------------------------------------------------------------------===//
366
+
367
+ OpFoldResult BoolConstantOp::fold (ArrayRef<Attribute> operands) {
368
+ return getValueAttr ();
369
+ }
370
+
45
371
// ===----------------------------------------------------------------------===//
46
372
// ODS-Generated Definitions
47
373
// ===----------------------------------------------------------------------===//
0 commit comments