@@ -115,36 +115,40 @@ static OpFoldResult foldBinaryOpChecked(
115
115
// AddOp
116
116
// ===----------------------------------------------------------------------===//
117
117
118
- OpFoldResult AddOp::fold (ArrayRef<Attribute> operands ) {
118
+ OpFoldResult AddOp::fold (FoldAdaptor adaptor ) {
119
119
return foldBinaryOpUnchecked (
120
- operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
120
+ adaptor.getOperands (),
121
+ [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
121
122
}
122
123
123
124
// ===----------------------------------------------------------------------===//
124
125
// SubOp
125
126
// ===----------------------------------------------------------------------===//
126
127
127
- OpFoldResult SubOp::fold (ArrayRef<Attribute> operands ) {
128
+ OpFoldResult SubOp::fold (FoldAdaptor adaptor ) {
128
129
return foldBinaryOpUnchecked (
129
- operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
130
+ adaptor.getOperands (),
131
+ [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
130
132
}
131
133
132
134
// ===----------------------------------------------------------------------===//
133
135
// MulOp
134
136
// ===----------------------------------------------------------------------===//
135
137
136
- OpFoldResult MulOp::fold (ArrayRef<Attribute> operands ) {
138
+ OpFoldResult MulOp::fold (FoldAdaptor adaptor ) {
137
139
return foldBinaryOpUnchecked (
138
- operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
140
+ adaptor.getOperands (),
141
+ [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
139
142
}
140
143
141
144
// ===----------------------------------------------------------------------===//
142
145
// DivSOp
143
146
// ===----------------------------------------------------------------------===//
144
147
145
- OpFoldResult DivSOp::fold (ArrayRef<Attribute> operands ) {
148
+ OpFoldResult DivSOp::fold (FoldAdaptor adaptor ) {
146
149
return foldBinaryOpChecked (
147
- operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
150
+ adaptor.getOperands (),
151
+ [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
148
152
// Don't fold division by zero.
149
153
if (rhs.isZero ())
150
154
return std::nullopt;
@@ -156,9 +160,10 @@ OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
156
160
// DivUOp
157
161
// ===----------------------------------------------------------------------===//
158
162
159
- OpFoldResult DivUOp::fold (ArrayRef<Attribute> operands ) {
163
+ OpFoldResult DivUOp::fold (FoldAdaptor adaptor ) {
160
164
return foldBinaryOpChecked (
161
- operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
165
+ adaptor.getOperands (),
166
+ [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
162
167
// Don't fold division by zero.
163
168
if (rhs.isZero ())
164
169
return std::nullopt;
@@ -193,18 +198,19 @@ static Optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
193
198
return (n + x).sdiv (m) + 1 ;
194
199
}
195
200
196
- OpFoldResult CeilDivSOp::fold (ArrayRef<Attribute> operands ) {
197
- return foldBinaryOpChecked (operands , calculateCeilDivS);
201
+ OpFoldResult CeilDivSOp::fold (FoldAdaptor adaptor ) {
202
+ return foldBinaryOpChecked (adaptor. getOperands () , calculateCeilDivS);
198
203
}
199
204
200
205
// ===----------------------------------------------------------------------===//
201
206
// CeilDivUOp
202
207
// ===----------------------------------------------------------------------===//
203
208
204
- OpFoldResult CeilDivUOp::fold (ArrayRef<Attribute> operands ) {
209
+ OpFoldResult CeilDivUOp::fold (FoldAdaptor adaptor ) {
205
210
// Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
206
211
return foldBinaryOpChecked (
207
- operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
212
+ adaptor.getOperands (),
213
+ [](const APInt &n, const APInt &m) -> Optional<APInt> {
208
214
// Don't fold division by zero.
209
215
if (m.isZero ())
210
216
return std::nullopt;
@@ -242,56 +248,58 @@ static Optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
242
248
return -1 - (x - n).sdiv (m);
243
249
}
244
250
245
- OpFoldResult FloorDivSOp::fold (ArrayRef<Attribute> operands ) {
246
- return foldBinaryOpChecked (operands , calculateFloorDivS);
251
+ OpFoldResult FloorDivSOp::fold (FoldAdaptor adaptor ) {
252
+ return foldBinaryOpChecked (adaptor. getOperands () , calculateFloorDivS);
247
253
}
248
254
249
255
// ===----------------------------------------------------------------------===//
250
256
// RemSOp
251
257
// ===----------------------------------------------------------------------===//
252
258
253
- OpFoldResult RemSOp::fold (ArrayRef<Attribute> operands ) {
254
- return foldBinaryOpChecked (operands, []( const APInt &lhs, const APInt &rhs) {
255
- return lhs. srem (rhs);
256
- });
259
+ OpFoldResult RemSOp::fold (FoldAdaptor adaptor ) {
260
+ return foldBinaryOpChecked (
261
+ adaptor. getOperands (),
262
+ []( const APInt &lhs, const APInt &rhs) { return lhs. srem (rhs); });
257
263
}
258
264
259
265
// ===----------------------------------------------------------------------===//
260
266
// RemUOp
261
267
// ===----------------------------------------------------------------------===//
262
268
263
- OpFoldResult RemUOp::fold (ArrayRef<Attribute> operands ) {
264
- return foldBinaryOpChecked (operands, []( const APInt &lhs, const APInt &rhs) {
265
- return lhs. urem (rhs);
266
- });
269
+ OpFoldResult RemUOp::fold (FoldAdaptor adaptor ) {
270
+ return foldBinaryOpChecked (
271
+ adaptor. getOperands (),
272
+ []( const APInt &lhs, const APInt &rhs) { return lhs. urem (rhs); });
267
273
}
268
274
269
275
// ===----------------------------------------------------------------------===//
270
276
// MaxSOp
271
277
// ===----------------------------------------------------------------------===//
272
278
273
- OpFoldResult MaxSOp::fold (ArrayRef<Attribute> operands) {
274
- return foldBinaryOpChecked (operands, [](const APInt &lhs, const APInt &rhs) {
275
- return lhs.sgt (rhs) ? lhs : rhs;
276
- });
279
+ OpFoldResult MaxSOp::fold (FoldAdaptor adaptor) {
280
+ return foldBinaryOpChecked (adaptor.getOperands (),
281
+ [](const APInt &lhs, const APInt &rhs) {
282
+ return lhs.sgt (rhs) ? lhs : rhs;
283
+ });
277
284
}
278
285
279
286
// ===----------------------------------------------------------------------===//
280
287
// MaxUOp
281
288
// ===----------------------------------------------------------------------===//
282
289
283
- OpFoldResult MaxUOp::fold (ArrayRef<Attribute> operands) {
284
- return foldBinaryOpChecked (operands, [](const APInt &lhs, const APInt &rhs) {
285
- return lhs.ugt (rhs) ? lhs : rhs;
286
- });
290
+ OpFoldResult MaxUOp::fold (FoldAdaptor adaptor) {
291
+ return foldBinaryOpChecked (adaptor.getOperands (),
292
+ [](const APInt &lhs, const APInt &rhs) {
293
+ return lhs.ugt (rhs) ? lhs : rhs;
294
+ });
287
295
}
288
296
289
297
// ===----------------------------------------------------------------------===//
290
298
// MinSOp
291
299
// ===----------------------------------------------------------------------===//
292
300
293
- OpFoldResult MinSOp::fold (ArrayRef<Attribute> operands ) {
294
- return foldBinaryOpChecked (operands , [](const APInt &lhs, const APInt &rhs) {
301
+ OpFoldResult MinSOp::fold (FoldAdaptor adaptor ) {
302
+ return foldBinaryOpChecked (adaptor. getOperands () , [](const APInt &lhs, const APInt &rhs) {
295
303
return lhs.slt (rhs) ? lhs : rhs;
296
304
});
297
305
}
@@ -300,8 +308,8 @@ OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
300
308
// MinUOp
301
309
// ===----------------------------------------------------------------------===//
302
310
303
- OpFoldResult MinUOp::fold (ArrayRef<Attribute> operands ) {
304
- return foldBinaryOpChecked (operands , [](const APInt &lhs, const APInt &rhs) {
311
+ OpFoldResult MinUOp::fold (FoldAdaptor adaptor ) {
312
+ return foldBinaryOpChecked (adaptor. getOperands () , [](const APInt &lhs, const APInt &rhs) {
305
313
return lhs.ult (rhs) ? lhs : rhs;
306
314
});
307
315
}
@@ -310,9 +318,10 @@ OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
310
318
// ShlOp
311
319
// ===----------------------------------------------------------------------===//
312
320
313
- OpFoldResult ShlOp::fold (ArrayRef<Attribute> operands ) {
321
+ OpFoldResult ShlOp::fold (FoldAdaptor adaptor ) {
314
322
return foldBinaryOpUnchecked (
315
- operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
323
+ adaptor.getOperands (),
324
+ [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
316
325
// We cannot fold if the RHS is greater than or equal to 32 because
317
326
// this would be UB in 32-bit systems but not on 64-bit systems. RHS is
318
327
// already treated as unsigned.
@@ -326,9 +335,10 @@ OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
326
335
// ShrSOp
327
336
// ===----------------------------------------------------------------------===//
328
337
329
- OpFoldResult ShrSOp::fold (ArrayRef<Attribute> operands ) {
338
+ OpFoldResult ShrSOp::fold (FoldAdaptor adaptor ) {
330
339
return foldBinaryOpChecked (
331
- operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
340
+ adaptor.getOperands (),
341
+ [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
332
342
// Don't fold if RHS is greater than or equal to 32.
333
343
if (rhs.uge (32 ))
334
344
return {};
@@ -340,9 +350,10 @@ OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
340
350
// ShrUOp
341
351
// ===----------------------------------------------------------------------===//
342
352
343
- OpFoldResult ShrUOp::fold (ArrayRef<Attribute> operands ) {
353
+ OpFoldResult ShrUOp::fold (FoldAdaptor adaptor ) {
344
354
return foldBinaryOpChecked (
345
- operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
355
+ adaptor.getOperands (),
356
+ [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
346
357
// Don't fold if RHS is greater than or equal to 32.
347
358
if (rhs.uge (32 ))
348
359
return {};
@@ -354,27 +365,30 @@ OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
354
365
// AndOp
355
366
// ===----------------------------------------------------------------------===//
356
367
357
- OpFoldResult AndOp::fold (ArrayRef<Attribute> operands ) {
368
+ OpFoldResult AndOp::fold (FoldAdaptor adaptor ) {
358
369
return foldBinaryOpUnchecked (
359
- operands, [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
370
+ adaptor.getOperands (),
371
+ [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
360
372
}
361
373
362
374
// ===----------------------------------------------------------------------===//
363
375
// OrOp
364
376
// ===----------------------------------------------------------------------===//
365
377
366
- OpFoldResult OrOp::fold (ArrayRef<Attribute> operands ) {
378
+ OpFoldResult OrOp::fold (FoldAdaptor adaptor ) {
367
379
return foldBinaryOpUnchecked (
368
- operands, [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
380
+ adaptor.getOperands (),
381
+ [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
369
382
}
370
383
371
384
// ===----------------------------------------------------------------------===//
372
385
// XOrOp
373
386
// ===----------------------------------------------------------------------===//
374
387
375
- OpFoldResult XOrOp::fold (ArrayRef<Attribute> operands ) {
388
+ OpFoldResult XOrOp::fold (FoldAdaptor adaptor ) {
376
389
return foldBinaryOpUnchecked (
377
- operands, [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
390
+ adaptor.getOperands (),
391
+ [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
378
392
}
379
393
380
394
// ===----------------------------------------------------------------------===//
@@ -425,10 +439,9 @@ bool compareIndices(const APInt &lhs, const APInt &rhs,
425
439
llvm_unreachable (" unhandled IndexCmpPredicate predicate" );
426
440
}
427
441
428
- OpFoldResult CmpOp::fold (ArrayRef<Attribute> operands) {
429
- assert (operands.size () == 2 && " compare expected 2 operands" );
430
- auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0 ]);
431
- auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1 ]);
442
+ OpFoldResult CmpOp::fold (FoldAdaptor adaptor) {
443
+ auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs ());
444
+ auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs ());
432
445
if (!lhs || !rhs)
433
446
return {};
434
447
@@ -453,9 +466,7 @@ void ConstantOp::getAsmResultNames(
453
466
setNameFn (getResult (), specialName.str ());
454
467
}
455
468
456
- OpFoldResult ConstantOp::fold (ArrayRef<Attribute> operands) {
457
- return getValueAttr ();
458
- }
469
+ OpFoldResult ConstantOp::fold (FoldAdaptor adaptor) { return getValueAttr (); }
459
470
460
471
void ConstantOp::build (OpBuilder &b, OperationState &state, int64_t value) {
461
472
build (b, state, b.getIndexType (), b.getIndexAttr (value));
@@ -465,7 +476,7 @@ void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
465
476
// BoolConstantOp
466
477
// ===----------------------------------------------------------------------===//
467
478
468
- OpFoldResult BoolConstantOp::fold (ArrayRef<Attribute> operands ) {
479
+ OpFoldResult BoolConstantOp::fold (FoldAdaptor adaptor ) {
469
480
return getValueAttr ();
470
481
}
471
482
0 commit comments