Skip to content

[mlir] [arith] add shl overflow flag in Arith and lower to SPIR-V and LLVMIR #79828

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 30, 2024

Conversation

yiwu0b11
Copy link
Contributor

There is no SHL used in canonicalization in arith

@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir-arith

Author: Yi Wu (yi-wu-arm)

Changes

There is no SHL used in canonicalization in arith


Full diff: https://github.com/llvm/llvm-project/pull/79828.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+7-1)
  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+3-1)
  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+1-1)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+2)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+4)
  • (modified) mlir/test/Dialect/Arith/ops.mlir (+2)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index cd0102f91ef1523..4b815a2208f6c3d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -782,7 +782,7 @@ def Arith_XOrIOp : Arith_TotalIntBinaryOp<"xori", [Commutative]> {
 // ShLIOp
 //===----------------------------------------------------------------------===//
 
-def Arith_ShLIOp : Arith_TotalIntBinaryOp<"shli"> {
+def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
   let summary = "integer left-shift";
   let description = [{
     The `shli` operation shifts the integer value of the first operand to the left 
@@ -791,11 +791,17 @@ def Arith_ShLIOp : Arith_TotalIntBinaryOp<"shli"> {
     operand is greater than the bitwidth of the first operand, then the 
     operation returns poison.
 
+    This op supports `nuw`/`nsw` overflow flags which stands stand for
+    "No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
+    `nsw` flags are present, and an unsigned/signed overflow occurs
+    (respectively), the result is poison.
+
     Example:
 
     ```mlir
     %1 = arith.constant 5 : i8                 // %1 is 0b00000101
     %2 = arith.constant 3 : i8
+    %a = arith.shli %1, %2 overflow<nsw, nuw> : (i8, i8) -> i8  
     %3 = arith.shli %1, %2 : (i8, i8) -> i8    // %3 is 0b00101000
     ```
   }];
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index cf46e0d3ac46ac3..1f01f4a75c5b3ef 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -96,7 +96,9 @@ using RemUIOpLowering =
     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
 using SelectOpLowering =
     VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
-using ShLIOpLowering = VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp>;
+using ShLIOpLowering =
+    VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
+                               arith::AttrConvertOverflowToLLVM>;
 using ShRSIOpLowering =
     VectorConvertToLLVMPattern<arith::ShRSIOp, LLVM::AShrOp>;
 using ShRUIOpLowering =
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 1abad1e9fa4d85c..edf81bd7a8f3963 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -1217,7 +1217,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
-    spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
+    ElementwiseArithOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
     spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
     spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
     spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 8937b24e0d174d1..29268eef47e8534 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -586,5 +586,7 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
   %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
   // CHECK: %{{.*}} = llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
   %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  // CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 8bf90ed0aec8ee1..ae47ae36ca51cd6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1422,6 +1422,8 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
   %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
   // CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
   %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  // CHECK: %{{.*}} = spirv.ShiftLeftLogical %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap} : i64
+  %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
 
@@ -1443,6 +1445,8 @@ func.func @ops_flags(%arg0: i64, %arg1: i64) {
   %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
   // CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
   %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  // CHECK: %{{.*}} = spirv.IMul %{{.*}}, %{{.*}} : i64
+  %3 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
 
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index 8ae3273f32c6b02..e499573e324b5fc 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -1147,5 +1147,7 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
   %1 = arith.subi %arg0, %arg1 overflow<nuw> : i64
   // CHECK: %{{.*}} = arith.muli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
   %2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
+  // CHECK: %{{.*}} = arith.shli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
+  %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }

Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo nit!

@yiwu0b11 yiwu0b11 merged commit f7ef73e into llvm:main Jan 30, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants