Skip to content

Commit 9ef3146

Browse files
author
Jeff Niu
committed
[mlir][index] Add shl, shrs, and shru ops
This patch adds the left shift, signed right shift, and unsigned right shift operations to the index dialects with folders and LLVM lowerings. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D137349
1 parent 374e646 commit 9ef3146

File tree

6 files changed

+239
-5
lines changed

6 files changed

+239
-5
lines changed

mlir/include/mlir/Dialect/Index/IR/IndexOps.td

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,69 @@ def Index_MaxUOp : IndexBinaryOp<"maxu"> {
280280
}];
281281
}
282282

283+
//===----------------------------------------------------------------------===//
284+
// ShlOp
285+
//===----------------------------------------------------------------------===//
286+
287+
def Index_ShlOp : IndexBinaryOp<"shl"> {
288+
let summary = "index shift left";
289+
let description = [{
290+
The `index.shl` operation shifts an index value to the left by a variable
291+
amount. The low order bits are filled with zeroes. The RHS operand is always
292+
treated as unsigned. If the RHS operand is equal to or greater than the
293+
index bitwidth, the operation is undefined.
294+
295+
Example:
296+
297+
```mlir
298+
// c = a << b
299+
%c = index.shl %a, %b
300+
```
301+
}];
302+
}
303+
304+
//===----------------------------------------------------------------------===//
305+
// ShrSOp
306+
//===----------------------------------------------------------------------===//
307+
308+
def Index_ShrSOp : IndexBinaryOp<"shrs"> {
309+
let summary = "signed index shift right";
310+
let description = [{
311+
The `index.shrs` operation shifts an index value to the right by a variable
312+
amount. The LHS operand is treated as signed. The high order bits are filled
313+
with copies of the most significant bit. If the RHS operand is equal to or
314+
greater than the index bitwidth, the operation is undefined.
315+
316+
Example:
317+
318+
```mlir
319+
// c = a >> b
320+
%c = index.shrs %a, %b
321+
```
322+
}];
323+
}
324+
325+
//===----------------------------------------------------------------------===//
326+
// ShrUOp
327+
//===----------------------------------------------------------------------===//
328+
329+
def Index_ShrUOp : IndexBinaryOp<"shru"> {
330+
let summary = "unsigned index shift right";
331+
let description = [{
332+
The `index.shru` operation shifts an index value to the right by a variable
333+
amount. The LHS operand is treated as unsigned. The high order bits are
334+
filled with zeroes. If the RHS operand is equal to or greater than the index
335+
bitwidth, the operation is undefined.
336+
337+
Example:
338+
339+
```mlir
340+
// c = a >> b
341+
%c = index.shru %a, %b
342+
```
343+
}];
344+
}
345+
283346
//===----------------------------------------------------------------------===//
284347
// CastSOp
285348
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,11 @@ using ConvertIndexMaxS =
268268
mlir::OneToOneConvertToLLVMPattern<MaxSOp, LLVM::SMaxOp>;
269269
using ConvertIndexMaxU =
270270
mlir::OneToOneConvertToLLVMPattern<MaxUOp, LLVM::UMaxOp>;
271+
using ConvertIndexShl = mlir::OneToOneConvertToLLVMPattern<ShlOp, LLVM::ShlOp>;
272+
using ConvertIndexShrS =
273+
mlir::OneToOneConvertToLLVMPattern<ShrSOp, LLVM::AShrOp>;
274+
using ConvertIndexShrU =
275+
mlir::OneToOneConvertToLLVMPattern<ShrUOp, LLVM::LShrOp>;
271276
using ConvertIndexBoolConstant =
272277
mlir::OneToOneConvertToLLVMPattern<BoolConstantOp, LLVM::ConstantOp>;
273278

@@ -290,6 +295,9 @@ void index::populateIndexToLLVMConversionPatterns(
290295
ConvertIndexRemU,
291296
ConvertIndexMaxS,
292297
ConvertIndexMaxU,
298+
ConvertIndexShl,
299+
ConvertIndexShrS,
300+
ConvertIndexShrU,
293301
ConvertIndexCeilDivS,
294302
ConvertIndexCeilDivU,
295303
ConvertIndexFloorDivS,

mlir/lib/Dialect/Index/IR/IndexOps.cpp

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,19 @@ Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
6262
/// the integer result, which in turn must satisfy the above property.
6363
static OpFoldResult foldBinaryOpUnchecked(
6464
ArrayRef<Attribute> operands,
65-
function_ref<APInt(const APInt &, const APInt &)> calculate) {
65+
function_ref<Optional<APInt>(const APInt &, const APInt &)> calculate) {
6666
assert(operands.size() == 2 && "binary operation expected 2 operands");
6767
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
6868
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
6969
if (!lhs || !rhs)
7070
return {};
7171

72-
APInt result = calculate(lhs.getValue(), rhs.getValue());
73-
assert(result.trunc(32) ==
72+
Optional<APInt> result = calculate(lhs.getValue(), rhs.getValue());
73+
if (!result)
74+
return {};
75+
assert(result->trunc(32) ==
7476
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
75-
return IntegerAttr::get(IndexType::get(lhs.getContext()), std::move(result));
77+
return IntegerAttr::get(IndexType::get(lhs.getContext()), std::move(*result));
7678
}
7779

7880
/// Fold an index operation only if the truncated 64-bit result matches the
@@ -284,6 +286,50 @@ OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
284286
});
285287
}
286288

289+
//===----------------------------------------------------------------------===//
290+
// ShlOp
291+
//===----------------------------------------------------------------------===//
292+
293+
OpFoldResult ShlOp::fold(ArrayRef<Attribute> operands) {
294+
return foldBinaryOpUnchecked(
295+
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
296+
// We cannot fold if the RHS is greater than or equal to 32 because
297+
// this would be UB in 32-bit systems but not on 64-bit systems. RHS is
298+
// already treated as unsigned.
299+
if (rhs.uge(32))
300+
return {};
301+
return lhs << rhs;
302+
});
303+
}
304+
305+
//===----------------------------------------------------------------------===//
306+
// ShrSOp
307+
//===----------------------------------------------------------------------===//
308+
309+
OpFoldResult ShrSOp::fold(ArrayRef<Attribute> operands) {
310+
return foldBinaryOpChecked(
311+
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
312+
// Don't fold if RHS is greater than or equal to 32.
313+
if (rhs.uge(32))
314+
return {};
315+
return lhs.ashr(rhs);
316+
});
317+
}
318+
319+
//===----------------------------------------------------------------------===//
320+
// ShrUOp
321+
//===----------------------------------------------------------------------===//
322+
323+
OpFoldResult ShrUOp::fold(ArrayRef<Attribute> operands) {
324+
return foldBinaryOpChecked(
325+
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
326+
// Don't fold if RHS is greater than or equal to 32.
327+
if (rhs.uge(32))
328+
return {};
329+
return lhs.lshr(rhs);
330+
});
331+
}
332+
287333
//===----------------------------------------------------------------------===//
288334
// CastSOp
289335
//===----------------------------------------------------------------------===//

mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,14 @@ func.func @trivial_ops(%a: index, %b: index) {
2222
%7 = index.maxs %a, %b
2323
// CHECK: llvm.intr.umax
2424
%8 = index.maxu %a, %b
25+
// CHECK: llvm.shl
26+
%9 = index.shl %a, %b
27+
// CHECK: llvm.ashr
28+
%10 = index.shrs %a, %b
29+
// CHECK: llvm.lshr
30+
%11 = index.shru %a, %b
2531
// CHECK: llvm.mlir.constant(true
26-
%9 = index.bool.constant true
32+
%12 = index.bool.constant true
2733
return
2834
}
2935

mlir/test/Dialect/Index/index-canonicalize.mlir

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,111 @@ func.func @maxu() -> index {
279279
return %0 : index
280280
}
281281

282+
// CHECK-LABEL: @shl
283+
func.func @shl() -> index {
284+
%lhs = index.constant 128
285+
%rhs = index.constant 2
286+
// CHECK: %[[A:.*]] = index.constant 512
287+
%0 = index.shl %lhs, %rhs
288+
// CHECK: return %[[A]]
289+
return %0 : index
290+
}
291+
292+
// CHECK-LABEL: @shl_32
293+
func.func @shl_32() -> index {
294+
%lhs = index.constant 1
295+
%rhs = index.constant 32
296+
// CHECK: index.shl
297+
%0 = index.shl %lhs, %rhs
298+
return %0 : index
299+
}
300+
301+
// CHECK-LABEL: @shl_edge
302+
func.func @shl_edge() -> index {
303+
%lhs = index.constant 4000000000
304+
%rhs = index.constant 31
305+
// CHECK: %[[A:.*]] = index.constant 858{{[0-9]+}}
306+
%0 = index.shl %lhs, %rhs
307+
// CHECK: return %[[A]]
308+
return %0 : index
309+
}
310+
311+
// CHECK-LABEL: @shrs
312+
func.func @shrs() -> index {
313+
%lhs = index.constant 128
314+
%rhs = index.constant 2
315+
// CHECK: %[[A:.*]] = index.constant 32
316+
%0 = index.shrs %lhs, %rhs
317+
// CHECK: return %[[A]]
318+
return %0 : index
319+
}
320+
321+
// CHECK-LABEL: @shrs_32
322+
func.func @shrs_32() -> index {
323+
%lhs = index.constant 4000000000000
324+
%rhs = index.constant 32
325+
// CHECK: index.shrs
326+
%0 = index.shrs %lhs, %rhs
327+
return %0 : index
328+
}
329+
330+
// CHECK-LABEL: @shrs_nofold
331+
func.func @shrs_nofold() -> index {
332+
%lhs = index.constant 0x100000000
333+
%rhs = index.constant 1
334+
// CHECK: index.shrs
335+
%0 = index.shrs %lhs, %rhs
336+
return %0 : index
337+
}
338+
339+
// CHECK-LABEL: @shrs_edge
340+
func.func @shrs_edge() -> index {
341+
%lhs = index.constant 0x10000000000
342+
%rhs = index.constant 3
343+
// CHECK: %[[A:.*]] = index.constant 137{{[0-9]+}}
344+
%0 = index.shrs %lhs, %rhs
345+
// CHECK: return %[[A]]
346+
return %0 : index
347+
}
348+
349+
// CHECK-LABEL: @shru
350+
func.func @shru() -> index {
351+
%lhs = index.constant 128
352+
%rhs = index.constant 2
353+
// CHECK: %[[A:.*]] = index.constant 32
354+
%0 = index.shru %lhs, %rhs
355+
// CHECK: return %[[A]]
356+
return %0 : index
357+
}
358+
359+
// CHECK-LABEL: @shru_32
360+
func.func @shru_32() -> index {
361+
%lhs = index.constant 4000000000000
362+
%rhs = index.constant 32
363+
// CHECK: index.shru
364+
%0 = index.shru %lhs, %rhs
365+
return %0 : index
366+
}
367+
368+
// CHECK-LABEL: @shru_nofold
369+
func.func @shru_nofold() -> index {
370+
%lhs = index.constant 0x100000000
371+
%rhs = index.constant 1
372+
// CHECK: index.shru
373+
%0 = index.shru %lhs, %rhs
374+
return %0 : index
375+
}
376+
377+
// CHECK-LABEL: @shru_edge
378+
func.func @shru_edge() -> index {
379+
%lhs = index.constant 0x10000000000
380+
%rhs = index.constant 3
381+
// CHECK: %[[A:.*]] = index.constant 137{{[0-9]+}}
382+
%0 = index.shru %lhs, %rhs
383+
// CHECK: return %[[A]]
384+
return %0 : index
385+
}
386+
282387
// CHECK-LABEL: @cmp
283388
func.func @cmp() -> (i1, i1, i1, i1) {
284389
%a = index.constant 0

mlir/test/Dialect/Index/index-ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ func.func @binary_ops(%a: index, %b: index) {
2727
%10 = index.maxs %a, %b
2828
// CHECK-NEXT: index.maxu %[[A]], %[[B]]
2929
%11 = index.maxu %a, %b
30+
// CHECK-NEXT: index.shl %[[A]], %[[B]]
31+
%12 = index.shl %a, %b
32+
// CHECK-NEXT: index.shrs %[[A]], %[[B]]
33+
%13 = index.shrs %a, %b
34+
// CHECK-NEXT: index.shru %[[A]], %[[B]]
35+
%14 = index.shru %a, %b
3036
return
3137
}
3238

0 commit comments

Comments
 (0)