Skip to content

Commit f7ef73e

Browse files
yiwu0b11kuhargysit
authored
[mlir] [arith] add shl overflow flag in Arith and lower to SPIR-V and LLVMIR (#79828)
There is no `SHL` used in canonicalization in `arith` --------- Co-authored-by: Jakub Kuderski <[email protected]> Co-authored-by: Tobias Gysi <[email protected]>
1 parent e5054fb commit f7ef73e

File tree

6 files changed

+21
-5
lines changed

6 files changed

+21
-5
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,7 @@ def Arith_XOrIOp : Arith_TotalIntBinaryOp<"xori", [Commutative]> {
782782
// ShLIOp
783783
//===----------------------------------------------------------------------===//
784784

785-
def Arith_ShLIOp : Arith_TotalIntBinaryOp<"shli"> {
785+
def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
786786
let summary = "integer left-shift";
787787
let description = [{
788788
The `shli` operation shifts the integer value of the first operand to the left
@@ -791,12 +791,18 @@ def Arith_ShLIOp : Arith_TotalIntBinaryOp<"shli"> {
791791
operand is greater than the bitwidth of the first operand, then the
792792
operation returns poison.
793793

794+
This op supports `nuw`/`nsw` overflow flags which stands stand for
795+
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
796+
`nsw` flags are present, and an unsigned/signed overflow occurs
797+
(respectively), the result is poison.
798+
794799
Example:
795800

796801
```mlir
797-
%1 = arith.constant 5 : i8 // %1 is 0b00000101
802+
%1 = arith.constant 5 : i8 // %1 is 0b00000101
798803
%2 = arith.constant 3 : i8
799-
%3 = arith.shli %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000
804+
%3 = arith.shli %1, %2 : i8 // %3 is 0b00101000
805+
%4 = arith.shli %1, %2 overflow<nsw, nuw> : i8
800806
```
801807
}];
802808
let hasFolder = 1;

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,9 @@ using RemUIOpLowering =
9696
VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
9797
using SelectOpLowering =
9898
VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
99-
using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
99+
using ShLIOpLowering =
100+
VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
101+
arith::AttrConvertOverflowToLLVM>;
100102
using ShRSIOpLowering =
101103
VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
102104
using ShRUIOpLowering =

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1217,7 +1217,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
12171217
BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
12181218
BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
12191219
XOrIOpLogicalPattern, XOrIOpBooleanPattern,
1220-
spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
1220+
ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
12211221
spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
12221222
spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
12231223
spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,5 +586,7 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
586586
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
587587
// CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
588588
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
589+
// CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
590+
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
589591
return
590592
}

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1422,6 +1422,8 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
14221422
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
14231423
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
14241424
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
1425+
// CHECK: %{{.*}} = spirv.ShiftLeftLogical %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
1426+
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
14251427
return
14261428
}
14271429

@@ -1443,6 +1445,8 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
14431445
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
14441446
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
14451447
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
1448+
// CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
1449+
%3 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
14461450
return
14471451
}
14481452

mlir/test/Dialect/Arith/ops.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,5 +1147,7 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
11471147
%1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
11481148
// CHECK: %{{.*}} = arith.muli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
11491149
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
1150+
// CHECK: %{{.*}} = arith.shli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
1151+
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
11501152
return
11511153
}

0 commit comments

Comments
 (0)