Skip to content

Commit 83c3eeb

Browse files
author
Jeff Niu
committed
[mlir][index] Add folders for index ops
This patch adds folders for `index` dialect ops. Ths folders are careful to ensure that fold results are valid on both 32-bit and 64-bit targets. Depends on D135689 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D135694
1 parent ddf87d6 commit 83c3eeb

File tree

5 files changed

+651
-0
lines changed

5 files changed

+651
-0
lines changed

mlir/include/mlir/Dialect/Index/IR/IndexDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def IndexDialect : Dialect {
8181
void registerOperations();
8282
}];
8383

84+
let hasConstantMaterializer = 1;
8485
let useDefaultAttributePrinterParser = 1;
8586
}
8687

mlir/include/mlir/Dialect/Index/IR/IndexOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class IndexBinaryOp<string mnemonic, list<Trait> traits = []>
3434
let arguments = (ins Index:$lhs, Index:$rhs);
3535
let results = (outs Index:$result);
3636
let assemblyFormat = "$lhs `,` $rhs attr-dict";
37+
let hasFolder = 1;
3738
}
3839

3940
//===----------------------------------------------------------------------===//
@@ -378,6 +379,7 @@ def Index_CmpOp : IndexOp<"cmp"> {
378379
let arguments = (ins IndexCmpPredicateAttr:$pred, Index:$lhs, Index:$rhs);
379380
let results = (outs I1:$result);
380381
let assemblyFormat = "`` $pred `(` $lhs `,` $rhs `)` attr-dict";
382+
let hasFolder = 1;
381383
}
382384

383385
//===----------------------------------------------------------------------===//
@@ -422,6 +424,7 @@ def Index_ConstantOp : IndexOp<"constant", [ConstantLike]> {
422424
let arguments = (ins IndexAttr:$value);
423425
let results = (outs Index:$result);
424426
let assemblyFormat = "attr-dict $value";
427+
let hasFolder = 1;
425428

426429
let builders = [OpBuilder<(ins "int64_t":$value)>];
427430
}
@@ -449,6 +452,7 @@ def Index_BoolConstantOp : IndexOp<"bool.constant", [ConstantLike]> {
449452
let arguments = (ins BoolAttr:$value);
450453
let results = (outs I1:$result);
451454
let assemblyFormat = "attr-dict $value";
455+
let hasFolder = 1;
452456
}
453457

454458
#endif // INDEX_OPS

mlir/include/mlir/Support/LLVM.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ namespace mlir {
9999
using llvm::cast;
100100
using llvm::cast_or_null;
101101
using llvm::dyn_cast;
102+
using llvm::dyn_cast_if_present;
102103
using llvm::dyn_cast_or_null;
103104
using llvm::isa;
104105
using llvm::isa_and_nonnull;

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,264 @@ void IndexDialect::registerOperations() {
2626
>();
2727
}
2828

29+
Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
30+
Type type, Location loc) {
31+
// Materialize bool constants as `i1`.
32+
if (auto boolValue = dyn_cast<BoolAttr>(value)) {
33+
if (!type.isSignlessInteger(1))
34+
return nullptr;
35+
return b.create<BoolConstantOp>(loc, type, boolValue);
36+
}
37+
38+
// Materialize integer attributes as `index`.
39+
if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
40+
if (!indexValue.getType().isa<IndexType>() || !type.isa<IndexType>())
41+
return nullptr;
42+
assert(indexValue.getValue().getBitWidth() ==
43+
IndexType::kInternalStorageBitWidth);
44+
return b.create<ConstantOp>(loc, indexValue);
45+
}
46+
47+
return nullptr;
48+
}
49+
50+
//===----------------------------------------------------------------------===//
51+
// Fold Utilities
52+
//===----------------------------------------------------------------------===//
53+
54+
/// Fold an index operation irrespective of the target bitwidth. The
55+
/// operation must satisfy the property:
56+
///
57+
/// ```
58+
/// trunc(f(a, b)) = f(trunc(a), trunc(b))
59+
/// ```
60+
///
61+
/// For all values of `a` and `b`. The function accepts a lambda that computes
62+
/// the integer result, which in turn must satisfy the above property.
63+
static OpFoldResult foldBinaryOpUnchecked(
64+
ArrayRef<Attribute> operands,
65+
function_ref<APInt(const APInt &, const APInt &)> calculate) {
66+
assert(operands.size() == 2 && "binary operation expected 2 operands");
67+
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
68+
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
69+
if (!lhs || !rhs)
70+
return {};
71+
72+
APInt result = calculate(lhs.getValue(), rhs.getValue());
73+
assert(result.trunc(32) ==
74+
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
75+
return IntegerAttr::get(IndexType::get(lhs.getContext()), std::move(result));
76+
}
77+
78+
/// Fold an index operation only if the truncated 64-bit result matches the
79+
/// 32-bit result for operations that don't satisfy the above property. These
80+
/// are operations where the upper bits of the operands can affect the lower
81+
/// bits of the results.
82+
///
83+
/// The function accepts a lambda that computes the integer result in both
84+
/// 64-bit and 32-bit. If either call returns `None`, the operation is not
85+
/// folded.
86+
static OpFoldResult foldBinaryOpChecked(
87+
ArrayRef<Attribute> operands,
88+
function_ref<Optional<APInt>(const APInt &, const APInt &lhs)> calculate) {
89+
assert(operands.size() == 2 && "binary operation expected 2 operands");
90+
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
91+
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
92+
// Only fold index operands.
93+
if (!lhs || !rhs)
94+
return {};
95+
96+
// Compute the 64-bit result and the 32-bit result.
97+
Optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
98+
if (!result64)
99+
return {};
100+
Optional<APInt> result32 =
101+
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
102+
if (!result32)
103+
return {};
104+
// Compare the truncated 64-bit result to the 32-bit result.
105+
if (result64->trunc(32) != *result32)
106+
return {};
107+
// The operation can be folded for these particular operands.
108+
return IntegerAttr::get(IndexType::get(lhs.getContext()),
109+
std::move(*result64));
110+
}
111+
112+
//===----------------------------------------------------------------------===//
113+
// AddOp
114+
//===----------------------------------------------------------------------===//
115+
116+
OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
117+
return foldBinaryOpUnchecked(
118+
operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
119+
}
120+
121+
//===----------------------------------------------------------------------===//
122+
// SubOp
123+
//===----------------------------------------------------------------------===//
124+
125+
OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
126+
return foldBinaryOpUnchecked(
127+
operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
128+
}
129+
130+
//===----------------------------------------------------------------------===//
131+
// MulOp
132+
//===----------------------------------------------------------------------===//
133+
134+
OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
135+
return foldBinaryOpUnchecked(
136+
operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
137+
}
138+
139+
//===----------------------------------------------------------------------===//
140+
// DivSOp
141+
//===----------------------------------------------------------------------===//
142+
143+
OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
144+
return foldBinaryOpChecked(
145+
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
146+
// Don't fold division by zero.
147+
if (rhs.isZero())
148+
return None;
149+
return lhs.sdiv(rhs);
150+
});
151+
}
152+
153+
//===----------------------------------------------------------------------===//
154+
// DivUOp
155+
//===----------------------------------------------------------------------===//
156+
157+
OpFoldResult DivUOp::fold(ArrayRef<Attribute> operands) {
158+
return foldBinaryOpChecked(
159+
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
160+
// Don't fold division by zero.
161+
if (rhs.isZero())
162+
return None;
163+
return lhs.udiv(rhs);
164+
});
165+
}
166+
167+
//===----------------------------------------------------------------------===//
168+
// CeilDivSOp
169+
//===----------------------------------------------------------------------===//
170+
171+
/// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
172+
/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
173+
static Optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
174+
// Don't fold division by zero.
175+
if (m.isZero())
176+
return None;
177+
// Short-circuit the zero case.
178+
if (n.isZero())
179+
return n;
180+
181+
bool mGtZ = m.sgt(0);
182+
if (n.sgt(0) != mGtZ) {
183+
// If the operands have different signs, compute the negative result. Signed
184+
// division overflow is not possible, since if `m == -1`, `n` can be at most
185+
// `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
186+
return -(-n).sdiv(m);
187+
}
188+
// Otherwise, compute the positive result. Signed division overflow is not
189+
// possible since if `m == -1`, `x` will be `1`.
190+
int64_t x = mGtZ ? -1 : 1;
191+
return (n + x).sdiv(m) + 1;
192+
}
193+
194+
OpFoldResult CeilDivSOp::fold(ArrayRef<Attribute> operands) {
195+
return foldBinaryOpChecked(operands, calculateCeilDivS);
196+
}
197+
198+
//===----------------------------------------------------------------------===//
199+
// CeilDivUOp
200+
//===----------------------------------------------------------------------===//
201+
202+
OpFoldResult CeilDivUOp::fold(ArrayRef<Attribute> operands) {
203+
// Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
204+
return foldBinaryOpChecked(
205+
operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
206+
// Don't fold division by zero.
207+
if (m.isZero())
208+
return None;
209+
// Short-circuit the zero case.
210+
if (n.isZero())
211+
return n;
212+
213+
return (n - 1).udiv(m) + 1;
214+
});
215+
}
216+
217+
//===----------------------------------------------------------------------===//
218+
// FloorDivSOp
219+
//===----------------------------------------------------------------------===//
220+
221+
/// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
222+
/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
223+
static Optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
224+
// Don't fold division by zero.
225+
if (m.isZero())
226+
return None;
227+
// Short-circuit the zero case.
228+
if (n.isZero())
229+
return n;
230+
231+
bool mLtZ = m.slt(0);
232+
if (n.slt(0) == mLtZ) {
233+
// If the operands have the same sign, compute the positive result.
234+
return n.sdiv(m);
235+
}
236+
// If the operands have different signs, compute the negative result. Signed
237+
// division overflow is not possible since if `m == -1`, `x` will be 1 and
238+
// `n` can be at most `INT_MAX`.
239+
int64_t x = mLtZ ? 1 : -1;
240+
return -1 - (x - n).sdiv(m);
241+
}
242+
243+
OpFoldResult FloorDivSOp::fold(ArrayRef<Attribute> operands) {
244+
return foldBinaryOpChecked(operands, calculateFloorDivS);
245+
}
246+
247+
//===----------------------------------------------------------------------===//
248+
// RemSOp
249+
//===----------------------------------------------------------------------===//
250+
251+
OpFoldResult RemSOp::fold(ArrayRef<Attribute> operands) {
252+
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
253+
return lhs.srem(rhs);
254+
});
255+
}
256+
257+
//===----------------------------------------------------------------------===//
258+
// RemUOp
259+
//===----------------------------------------------------------------------===//
260+
261+
OpFoldResult RemUOp::fold(ArrayRef<Attribute> operands) {
262+
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
263+
return lhs.urem(rhs);
264+
});
265+
}
266+
267+
//===----------------------------------------------------------------------===//
268+
// MaxSOp
269+
//===----------------------------------------------------------------------===//
270+
271+
OpFoldResult MaxSOp::fold(ArrayRef<Attribute> operands) {
272+
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
273+
return lhs.sgt(rhs) ? lhs : rhs;
274+
});
275+
}
276+
277+
//===----------------------------------------------------------------------===//
278+
// MaxUOp
279+
//===----------------------------------------------------------------------===//
280+
281+
OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
282+
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
283+
return lhs.ugt(rhs) ? lhs : rhs;
284+
});
285+
}
286+
29287
//===----------------------------------------------------------------------===//
30288
// CastSOp
31289
//===----------------------------------------------------------------------===//
@@ -42,6 +300,74 @@ bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
42300
return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>();
43301
}
44302

303+
//===----------------------------------------------------------------------===//
304+
// CmpOp
305+
//===----------------------------------------------------------------------===//
306+
307+
/// Compare two integers according to the comparison predicate.
308+
bool compareIndices(const APInt &lhs, const APInt &rhs,
309+
IndexCmpPredicate pred) {
310+
switch (pred) {
311+
case IndexCmpPredicate::EQ:
312+
return lhs.eq(rhs);
313+
case IndexCmpPredicate::NE:
314+
return lhs.ne(rhs);
315+
case IndexCmpPredicate::SGE:
316+
return lhs.sge(rhs);
317+
case IndexCmpPredicate::SGT:
318+
return lhs.sgt(rhs);
319+
case IndexCmpPredicate::SLE:
320+
return lhs.sle(rhs);
321+
case IndexCmpPredicate::SLT:
322+
return lhs.slt(rhs);
323+
case IndexCmpPredicate::UGE:
324+
return lhs.uge(rhs);
325+
case IndexCmpPredicate::UGT:
326+
return lhs.ugt(rhs);
327+
case IndexCmpPredicate::ULE:
328+
return lhs.ule(rhs);
329+
case IndexCmpPredicate::ULT:
330+
return lhs.ult(rhs);
331+
}
332+
llvm_unreachable("unhandled IndexCmpPredicate predicate");
333+
}
334+
335+
OpFoldResult CmpOp::fold(ArrayRef<Attribute> operands) {
336+
assert(operands.size() == 2 && "compare expected 2 operands");
337+
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
338+
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
339+
if (!lhs || !rhs)
340+
return {};
341+
342+
// Perform the comparison in 64-bit and 32-bit.
343+
bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
344+
bool result32 = compareIndices(lhs.getValue().trunc(32),
345+
rhs.getValue().trunc(32), getPred());
346+
if (result64 != result32)
347+
return {};
348+
return BoolAttr::get(getContext(), result64);
349+
}
350+
351+
//===----------------------------------------------------------------------===//
352+
// ConstantOp
353+
//===----------------------------------------------------------------------===//
354+
355+
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
356+
return getValueAttr();
357+
}
358+
359+
void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
360+
build(b, state, b.getIndexType(), b.getIndexAttr(value));
361+
}
362+
363+
//===----------------------------------------------------------------------===//
364+
// BoolConstantOp
365+
//===----------------------------------------------------------------------===//
366+
367+
OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
368+
return getValueAttr();
369+
}
370+
45371
//===----------------------------------------------------------------------===//
46372
// ODS-Generated Definitions
47373
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)