Skip to content

Commit a703d15

Browse files
committed
[mlir][Index][NFC] Migrate index dialect to the new fold API
See https://discourse.llvm.org/t/psa-new-improved-fold-method-signature-has-landed-please-update-your-downstream-projects/67618 for context Similar to the patch for the arith dialect, the index dialects fold implementations make heavy use of generic fold functions, hence the change being comparatively mechanical and mostly changing the function signature. Differential Revision: https://reviews.llvm.org/D141502
1 parent f2f3b1a commit a703d15

File tree

2 files changed

+68
-56
lines changed

2 files changed

+68
-56
lines changed

mlir/include/mlir/Dialect/Index/IR/IndexDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def IndexDialect : Dialect {
8383

8484
let hasConstantMaterializer = 1;
8585
let useDefaultAttributePrinterParser = 1;
86+
let useFoldAPI = kEmitFoldAdaptorFolder;
8687
}
8788

8889
#endif // INDEX_DIALECT

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 67 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -115,36 +115,40 @@ static OpFoldResult foldBinaryOpChecked(
115115
// AddOp
116116
//===----------------------------------------------------------------------===//
117117

118-
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
118+
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
119119
return foldBinaryOpUnchecked(
120-
operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
120+
adaptor.getOperands(),
121+
[](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
121122
}
122123

123124
//===----------------------------------------------------------------------===//
124125
// SubOp
125126
//===----------------------------------------------------------------------===//
126127

127-
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
128+
OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
128129
return foldBinaryOpUnchecked(
129-
operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
130+
adaptor.getOperands(),
131+
[](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
130132
}
131133

132134
//===----------------------------------------------------------------------===//
133135
// MulOp
134136
//===----------------------------------------------------------------------===//
135137

136-
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
138+
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
137139
return foldBinaryOpUnchecked(
138-
operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
140+
adaptor.getOperands(),
141+
[](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
139142
}
140143

141144
//===----------------------------------------------------------------------===//
142145
// DivSOp
143146
//===----------------------------------------------------------------------===//
144147

145-
OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
148+
OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
146149
return foldBinaryOpChecked(
147-
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
150+
adaptor.getOperands(),
151+
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
148152
// Don't fold division by zero.
149153
if (rhs.isZero())
150154
return std::nullopt;
@@ -156,9 +160,10 @@ OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
156160
// DivUOp
157161
//===----------------------------------------------------------------------===//
158162

159-
OpFoldResult DivUOp::fold(ArrayRef<Attribute> operands) {
163+
OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
160164
return foldBinaryOpChecked(
161-
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
165+
adaptor.getOperands(),
166+
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
162167
// Don't fold division by zero.
163168
if (rhs.isZero())
164169
return std::nullopt;
@@ -193,18 +198,19 @@ static Optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
193198
return (n + x).sdiv(m) + 1;
194199
}
195200

196-
OpFoldResult CeilDivSOp::fold(ArrayRef<Attribute> operands) {
197-
return foldBinaryOpChecked(operands, calculateCeilDivS);
201+
OpFoldResult CeilDivSOp::fold(FoldAdaptor adaptor) {
202+
return foldBinaryOpChecked(adaptor.getOperands(), calculateCeilDivS);
198203
}
199204

200205
//===----------------------------------------------------------------------===//
201206
// CeilDivUOp
202207
//===----------------------------------------------------------------------===//
203208

204-
OpFoldResult CeilDivUOp::fold(ArrayRef<Attribute> operands) {
209+
OpFoldResult CeilDivUOp::fold(FoldAdaptor adaptor) {
205210
// Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
206211
return foldBinaryOpChecked(
207-
operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
212+
adaptor.getOperands(),
213+
[](const APInt &n, const APInt &m) -> Optional<APInt> {
208214
// Don't fold division by zero.
209215
if (m.isZero())
210216
return std::nullopt;
@@ -242,56 +248,58 @@ static Optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
242248
return -1 - (x - n).sdiv(m);
243249
}
244250

245-
OpFoldResult FloorDivSOp::fold(ArrayRef<Attribute> operands) {
246-
return foldBinaryOpChecked(operands, calculateFloorDivS);
251+
OpFoldResult FloorDivSOp::fold(FoldAdaptor adaptor) {
252+
return foldBinaryOpChecked(adaptor.getOperands(), calculateFloorDivS);
247253
}
248254

249255
//===----------------------------------------------------------------------===//
250256
// RemSOp
251257
//===----------------------------------------------------------------------===//
252258

253-
OpFoldResult RemSOp::fold(ArrayRef<Attribute> operands) {
254-
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
255-
return lhs.srem(rhs);
256-
});
259+
OpFoldResult RemSOp::fold(FoldAdaptor adaptor) {
260+
return foldBinaryOpChecked(
261+
adaptor.getOperands(),
262+
[](const APInt &lhs, const APInt &rhs) { return lhs.srem(rhs); });
257263
}
258264

259265
//===----------------------------------------------------------------------===//
260266
// RemUOp
261267
//===----------------------------------------------------------------------===//
262268

263-
OpFoldResult RemUOp::fold(ArrayRef<Attribute> operands) {
264-
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
265-
return lhs.urem(rhs);
266-
});
269+
OpFoldResult RemUOp::fold(FoldAdaptor adaptor) {
270+
return foldBinaryOpChecked(
271+
adaptor.getOperands(),
272+
[](const APInt &lhs, const APInt &rhs) { return lhs.urem(rhs); });
267273
}
268274

269275
//===----------------------------------------------------------------------===//
270276
// MaxSOp
271277
//===----------------------------------------------------------------------===//
272278

273-
OpFoldResult MaxSOp::fold(ArrayRef<Attribute> operands) {
274-
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
275-
return lhs.sgt(rhs) ? lhs : rhs;
276-
});
279+
OpFoldResult MaxSOp::fold(FoldAdaptor adaptor) {
280+
return foldBinaryOpChecked(adaptor.getOperands(),
281+
[](const APInt &lhs, const APInt &rhs) {
282+
return lhs.sgt(rhs) ? lhs : rhs;
283+
});
277284
}
278285

279286
//===----------------------------------------------------------------------===//
280287
// MaxUOp
281288
//===----------------------------------------------------------------------===//
282289

283-
OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
284-
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
285-
return lhs.ugt(rhs) ? lhs : rhs;
286-
});
290+
OpFoldResult MaxUOp::fold(FoldAdaptor adaptor) {
291+
return foldBinaryOpChecked(adaptor.getOperands(),
292+
[](const APInt &lhs, const APInt &rhs) {
293+
return lhs.ugt(rhs) ? lhs : rhs;
294+
});
287295
}
288296

289297
//===----------------------------------------------------------------------===//
290298
// MinSOp
291299
//===----------------------------------------------------------------------===//
292300

293-
OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
294-
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
301+
OpFoldResult MinSOp::fold(FoldAdaptor adaptor) {
302+
return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
295303
return lhs.slt(rhs) ? lhs : rhs;
296304
});
297305
}
@@ -300,8 +308,8 @@ OpFoldResult MinSOp::fold(ArrayRef<Attribute> operands) {
300308
// MinUOp
301309
//===----------------------------------------------------------------------===//
302310

303-
OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
304-
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
311+
OpFoldResult MinUOp::fold(FoldAdaptor adaptor) {
312+
return foldBinaryOpChecked(adaptor.getOperands(), [](const APInt &lhs, const APInt &rhs) {
305313
return lhs.ult(rhs) ? lhs : rhs;
306314
});
307315
}
@@ -310,9 +318,10 @@ OpFoldResult MinUOp::fold(ArrayRef<Attribute> operands) {
310318
// ShlOp
311319
//===----------------------------------------------------------------------===//
312320

313-
OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
321+
OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
314322
return foldBinaryOpUnchecked(
315-
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
323+
adaptor.getOperands(),
324+
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
316325
// We cannot fold if the RHS is greater than or equal to 32 because
317326
// this would be UB in 32-bit systems but not on 64-bit systems. RHS is
318327
// already treated as unsigned.
@@ -326,9 +335,10 @@ OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
326335
// ShrSOp
327336
//===----------------------------------------------------------------------===//
328337

329-
OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
338+
OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
330339
return foldBinaryOpChecked(
331-
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
340+
adaptor.getOperands(),
341+
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
332342
// Don't fold if RHS is greater than or equal to 32.
333343
if (rhs.uge(32))
334344
return {};
@@ -340,9 +350,10 @@ OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
340350
// ShrUOp
341351
//===----------------------------------------------------------------------===//
342352

343-
OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
353+
OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
344354
return foldBinaryOpChecked(
345-
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
355+
adaptor.getOperands(),
356+
[](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
346357
// Don't fold if RHS is greater than or equal to 32.
347358
if (rhs.uge(32))
348359
return {};
@@ -354,27 +365,30 @@ OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
354365
// AndOp
355366
//===----------------------------------------------------------------------===//
356367

357-
OpFoldResult AndOp::fold(ArrayRef<Attribute> operands) {
368+
OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
358369
return foldBinaryOpUnchecked(
359-
operands, [](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
370+
adaptor.getOperands(),
371+
[](const APInt &lhs, const APInt &rhs) { return lhs & rhs; });
360372
}
361373

362374
//===----------------------------------------------------------------------===//
363375
// OrOp
364376
//===----------------------------------------------------------------------===//
365377

366-
OpFoldResult OrOp::fold(ArrayRef<Attribute> operands) {
378+
OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
367379
return foldBinaryOpUnchecked(
368-
operands, [](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
380+
adaptor.getOperands(),
381+
[](const APInt &lhs, const APInt &rhs) { return lhs | rhs; });
369382
}
370383

371384
//===----------------------------------------------------------------------===//
372385
// XOrOp
373386
//===----------------------------------------------------------------------===//
374387

375-
OpFoldResult XOrOp::fold(ArrayRef<Attribute> operands) {
388+
OpFoldResult XOrOp::fold(FoldAdaptor adaptor) {
376389
return foldBinaryOpUnchecked(
377-
operands, [](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
390+
adaptor.getOperands(),
391+
[](const APInt &lhs, const APInt &rhs) { return lhs ^ rhs; });
378392
}
379393

380394
//===----------------------------------------------------------------------===//
@@ -425,10 +439,9 @@ bool compareIndices(const APInt &lhs, const APInt &rhs,
425439
llvm_unreachable("unhandled IndexCmpPredicate predicate");
426440
}
427441

428-
OpFoldResult CmpOp::fold(ArrayRef<Attribute> operands) {
429-
assert(operands.size() == 2 && "compare expected 2 operands");
430-
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
431-
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
442+
OpFoldResult CmpOp::fold(FoldAdaptor adaptor) {
443+
auto lhs = dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
444+
auto rhs = dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
432445
if (!lhs || !rhs)
433446
return {};
434447

@@ -453,9 +466,7 @@ void ConstantOp::getAsmResultNames(
453466
setNameFn(getResult(), specialName.str());
454467
}
455468

456-
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
457-
return getValueAttr();
458-
}
469+
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
459470

460471
void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
461472
build(b, state, b.getIndexType(), b.getIndexAttr(value));
@@ -465,7 +476,7 @@ void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
465476
// BoolConstantOp
466477
//===----------------------------------------------------------------------===//
467478

468-
OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
479+
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
469480
return getValueAttr();
470481
}
471482

0 commit comments

Comments
 (0)