Skip to content

[mlir][LLVM] Add nsw and nuw flags #74508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 7, 2023
Merged

[mlir][LLVM] Add nsw and nuw flags #74508

merged 11 commits into from
Dec 7, 2023

Conversation

tblah
Copy link
Contributor

@tblah tblah commented Dec 5, 2023

The implementation of these are modeled after the existing fastmath flags for floating point arithmetic.

The implementation of these are modelled after the existing fastmath flags for
floating point arithmetic.
@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2023

@llvm/pr-subscribers-mlir

Author: Tom Eccles (tblah)

Changes

The implementation of these are modeled after the existing fastmath flags for floating point arithmetic.


Full diff: https://github.com/llvm/llvm-project/pull/74508.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td (+23)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+57)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+19-4)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+5)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+10-1)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+13)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+10)
  • (added) mlir/test/Target/LLVMIR/Import/nsw_nuw.ll (+14)
  • (added) mlir/test/Target/LLVMIR/nsw_nuw.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f05230526c21f..5cde4980ae17d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr<
   let printBitEnumPrimaryGroups = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// IntegerArithFlags
+//===----------------------------------------------------------------------===//
+
+def IAFnone : I32BitEnumAttrCaseNone<"none">;
+def IAFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IAFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerArithFlags : I32BitEnumAttr<
+    "IntegerArithFlags",
+    "LLVM integer arithmetic flags",
+    [IAFnone, IAFnsw, IAFnuw]> {
+  let separator = ", ";
+  let cppNamespace = "::mlir::LLVM";
+  let genSpecializedAttr = 0;
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def LLVM_IntegerArithFlagsAttr :
+    EnumAttr<LLVM_Dialect, IntegerArithFlags, "arith"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // FastmathFlags
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index c5d65f792254e..3d3388ac50aff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
   ];
 }
 
+def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
+  let description = [{
+    Access to op integer overflow flags.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a IntegerArithFlagsAttr attribute for the operation",
+      /*returnType=*/  "IntegerArithFlagsAttr",
+      /*methodName=*/  "getArithAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getArithFlagsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNuw",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerArithFlags::nuw);
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNsw",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerArithFlags::nsw);
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the IntegerArithFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getIntegerArithAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "arithFlags";
+      }]
+      >
+  ];
+}
+
 def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
   let description = [{
     An interface for operations that can carry branch weights metadata. It
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8f166f0cc7cf5..4a2ef07f505b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -55,6 +55,21 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
     $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
   }];
 }
+class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
+                                   list<Trait> traits = []> :
+    LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
+    !listconcat([DeclareOpInterfaceMethods<IntegerArithFlagsInterface>], traits)> {
+  dag iafArg = (
+    ins DefaultValuedAttr<LLVM_IntegerArithFlagsAttr, "{}">:$arithFlags);
+  let arguments = !con(commonArgs, iafArg);
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+    moduleImport.setIntegerFlagsAttr(inst, op);
+    $res = op;
+  }];
+  let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $arithFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
+  string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNuw(), op.hasNsw());";
+}
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -90,9 +105,9 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
 }
 
 // Integer binary operations.
-def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>;
-def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">;
-def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>;
+def LLVM_AddOp : LLVM_IntArithmeticOpWithFlag<"add", "Add", [Commutative]>;
+def LLVM_SubOp : LLVM_IntArithmeticOpWithFlag<"sub", "Sub", []>;
+def LLVM_MulOp : LLVM_IntArithmeticOpWithFlag<"mul", "Mul", [Commutative]>;
 def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
 def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
 def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
@@ -102,7 +117,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
   let hasFolder = 1;
 }
 def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+def LLVM_ShlOp : LLVM_IntArithmeticOpWithFlag<"shl", "Shl", []> {
   let hasFolder = 1;
 }
 def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b8e449dc11df1..de52476636aed 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -172,6 +172,11 @@ class ModuleImport {
   /// attributes of LLVMFuncOp `funcOp`.
   void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);
 
+  /// Sets the integer arithmetic flags (nsw/nuw) attribute for the imported
+  /// operation `op` given the original instruction `inst`. Asserts if the
+  /// operation does not implement the integer arithmetic flag interface.
+  void setIntegerFlagsAttr(llvm::Instruction *inst, Operation *op) const;
+
   /// Sets the fastmath flags attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
   /// implement the fastmath interface.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 28445945f07d6..3d78970cf6c14 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -62,6 +62,14 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
   return filteredAttrs;
 }
 
+static auto processIntArithAttr(ArrayRef<NamedAttribute> attrs) {
+  SmallVector<NamedAttribute, 8> filteredAttrs(
+      llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
+        return attr.getName() != "arithFlags";
+      }));
+  return filteredAttrs;
+}
+
 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
                                     NamedAttrList &result) {
   return parser.parseOptionalAttrDict(result);
@@ -69,7 +77,8 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
 
 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
                              DictionaryAttr attrs) {
-  printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
+  printer.printOptionalAttrDict(
+      processFMFAttr(processIntArithAttr(attrs.getValue())));
 }
 
 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2d1aaa9229cd2..edd0120dcbb71 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
   }
 }
 
+void ModuleImport::setIntegerFlagsAttr(llvm::Instruction *inst,
+                                       Operation *op) const {
+  IntegerArithFlagsInterface iface = cast<IntegerArithFlagsInterface>(op);
+
+  IntegerArithFlags value = {};
+  value = bitEnumSet(value, IntegerArithFlags::nsw, inst->hasNoSignedWrap());
+  value = bitEnumSet(value, IntegerArithFlags::nuw, inst->hasNoUnsignedWrap());
+
+  IntegerArithFlagsAttr attr =
+      IntegerArithFlagsAttr::get(op->getContext(), value);
+  iface->setAttr(iface.getIntegerArithAttrName(), attr);
+}
+
 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
                                         Operation *op) const {
   auto iface = cast<FastmathFlagsInterface>(op);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index ee724a482cfb5..dc0f9f453057d 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
   %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
 
+// Integer arithmetic flags
+// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] flags <nsw> : i32
+// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] flags <nuw> : i32
+// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+  %add_flag = llvm.add %arg0, %arg0 flags <nsw> : i32
+  %sub_flag = llvm.sub %arg0, %arg0 flags <nuw> : i32
+  %mul_flag = llvm.mul %arg0, %arg0 flags <nsw, nuw> : i32
+  %shl_flag = llvm.shl %arg0, %arg0 flags <nuw, nsw> : i32
+
 // Floating point binary operations.
 //
 // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
new file mode 100644
index 0000000000000..2ea0425ec0ff7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -0,0 +1,14 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @intflag_inst
+define void @intflag_inst(i64 %arg1, i64 %arg2) {
+  ; CHECK: llvm.add %{{.*}}, %{{.*}} flags <nsw> : i64
+  %1 = add nsw i64 %arg1, %arg2
+  ; CHECK: llvm.sub %{{.*}}, %{{.*}} flags <nuw> : i64
+  %2 = sub nuw i64 %arg1, %arg2
+  ; CHECK: llvm.mul %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+  %3 = mul nsw nuw i64 %arg1, %arg2
+  ; CHECK: llvm.shl %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+  %4 = shl nuw nsw i64 %arg1, %arg2
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
new file mode 100644
index 0000000000000..4a7a39bb570c3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @intflags_func
+llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
+  %0 = llvm.add %arg0, %arg1 flags <nsw> : i64
+  // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
+  %1 = llvm.sub %arg0, %arg1 flags <nuw> : i64
+  // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
+  %2 = llvm.mul %arg0, %arg1 flags <nsw, nuw> : i64
+  // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
+  %3 = llvm.shl %arg0, %arg1 flags <nsw, nuw> : i64
+  llvm.return
+}

@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2023

@llvm/pr-subscribers-mlir-llvm

Author: Tom Eccles (tblah)

Changes

The implementation of these are modeled after the existing fastmath flags for floating point arithmetic.


Full diff: https://github.com/llvm/llvm-project/pull/74508.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td (+23)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+57)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+19-4)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+5)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+10-1)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+13)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+10)
  • (added) mlir/test/Target/LLVMIR/Import/nsw_nuw.ll (+14)
  • (added) mlir/test/Target/LLVMIR/nsw_nuw.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f05230526c21f..5cde4980ae17d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr<
   let printBitEnumPrimaryGroups = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// IntegerArithFlags
+//===----------------------------------------------------------------------===//
+
+def IAFnone : I32BitEnumAttrCaseNone<"none">;
+def IAFnsw  : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IAFnuw  : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerArithFlags : I32BitEnumAttr<
+    "IntegerArithFlags",
+    "LLVM integer arithmetic flags",
+    [IAFnone, IAFnsw, IAFnuw]> {
+  let separator = ", ";
+  let cppNamespace = "::mlir::LLVM";
+  let genSpecializedAttr = 0;
+  let printBitEnumPrimaryGroups = 1;
+}
+
+def LLVM_IntegerArithFlagsAttr :
+    EnumAttr<LLVM_Dialect, IntegerArithFlags, "arith"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 //===----------------------------------------------------------------------===//
 // FastmathFlags
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index c5d65f792254e..3d3388ac50aff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
   ];
 }
 
+def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
+  let description = [{
+    Access to op integer overflow flags.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a IntegerArithFlagsAttr attribute for the operation",
+      /*returnType=*/  "IntegerArithFlagsAttr",
+      /*methodName=*/  "getArithAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        return op.getArithFlagsAttr();
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNuw",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerArithFlags::nuw);
+      }]
+      >,
+    InterfaceMethod<
+      /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
+      /*returnType=*/  "bool",
+      /*methodName=*/  "hasNsw",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        auto op = cast<ConcreteOp>(this->getOperation());
+        IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+        return bitEnumContainsAll(flags, IntegerArithFlags::nsw);
+      }]
+      >,
+    StaticInterfaceMethod<
+      /*desc=*/        [{Returns the name of the IntegerArithFlagsAttr attribute
+                         for the operation}],
+      /*returnType=*/  "StringRef",
+      /*methodName=*/  "getIntegerArithAttrName",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        return "arithFlags";
+      }]
+      >
+  ];
+}
+
 def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
   let description = [{
     An interface for operations that can carry branch weights metadata. It
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8f166f0cc7cf5..4a2ef07f505b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -55,6 +55,21 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
     $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
   }];
 }
+class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
+                                   list<Trait> traits = []> :
+    LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
+    !listconcat([DeclareOpInterfaceMethods<IntegerArithFlagsInterface>], traits)> {
+  dag iafArg = (
+    ins DefaultValuedAttr<LLVM_IntegerArithFlagsAttr, "{}">:$arithFlags);
+  let arguments = !con(commonArgs, iafArg);
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+    moduleImport.setIntegerFlagsAttr(inst, op);
+    $res = op;
+  }];
+  let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $arithFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
+  string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNuw(), op.hasNsw());";
+}
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -90,9 +105,9 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
 }
 
 // Integer binary operations.
-def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>;
-def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">;
-def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>;
+def LLVM_AddOp : LLVM_IntArithmeticOpWithFlag<"add", "Add", [Commutative]>;
+def LLVM_SubOp : LLVM_IntArithmeticOpWithFlag<"sub", "Sub", []>;
+def LLVM_MulOp : LLVM_IntArithmeticOpWithFlag<"mul", "Mul", [Commutative]>;
 def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
 def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
 def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
@@ -102,7 +117,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
   let hasFolder = 1;
 }
 def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+def LLVM_ShlOp : LLVM_IntArithmeticOpWithFlag<"shl", "Shl", []> {
   let hasFolder = 1;
 }
 def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b8e449dc11df1..de52476636aed 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -172,6 +172,11 @@ class ModuleImport {
   /// attributes of LLVMFuncOp `funcOp`.
   void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);
 
+  /// Sets the integer arithmetic flags (nsw/nuw) attribute for the imported
+  /// operation `op` given the original instruction `inst`. Asserts if the
+  /// operation does not implement the integer arithmetic flag interface.
+  void setIntegerFlagsAttr(llvm::Instruction *inst, Operation *op) const;
+
   /// Sets the fastmath flags attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
   /// implement the fastmath interface.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 28445945f07d6..3d78970cf6c14 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -62,6 +62,14 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
   return filteredAttrs;
 }
 
+static auto processIntArithAttr(ArrayRef<NamedAttribute> attrs) {
+  SmallVector<NamedAttribute, 8> filteredAttrs(
+      llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
+        return attr.getName() != "arithFlags";
+      }));
+  return filteredAttrs;
+}
+
 static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
                                     NamedAttrList &result) {
   return parser.parseOptionalAttrDict(result);
@@ -69,7 +77,8 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
 
 static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
                              DictionaryAttr attrs) {
-  printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
+  printer.printOptionalAttrDict(
+      processFMFAttr(processIntArithAttr(attrs.getValue())));
 }
 
 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2d1aaa9229cd2..edd0120dcbb71 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
   }
 }
 
+void ModuleImport::setIntegerFlagsAttr(llvm::Instruction *inst,
+                                       Operation *op) const {
+  IntegerArithFlagsInterface iface = cast<IntegerArithFlagsInterface>(op);
+
+  IntegerArithFlags value = {};
+  value = bitEnumSet(value, IntegerArithFlags::nsw, inst->hasNoSignedWrap());
+  value = bitEnumSet(value, IntegerArithFlags::nuw, inst->hasNoUnsignedWrap());
+
+  IntegerArithFlagsAttr attr =
+      IntegerArithFlagsAttr::get(op->getContext(), value);
+  iface->setAttr(iface.getIntegerArithAttrName(), attr);
+}
+
 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
                                         Operation *op) const {
   auto iface = cast<FastmathFlagsInterface>(op);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index ee724a482cfb5..dc0f9f453057d 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
   %typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
 
+// Integer arithmetic flags
+// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] flags <nsw> : i32
+// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] flags <nuw> : i32
+// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+  %add_flag = llvm.add %arg0, %arg0 flags <nsw> : i32
+  %sub_flag = llvm.sub %arg0, %arg0 flags <nuw> : i32
+  %mul_flag = llvm.mul %arg0, %arg0 flags <nsw, nuw> : i32
+  %shl_flag = llvm.shl %arg0, %arg0 flags <nuw, nsw> : i32
+
 // Floating point binary operations.
 //
 // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
new file mode 100644
index 0000000000000..2ea0425ec0ff7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -0,0 +1,14 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @intflag_inst
+define void @intflag_inst(i64 %arg1, i64 %arg2) {
+  ; CHECK: llvm.add %{{.*}}, %{{.*}} flags <nsw> : i64
+  %1 = add nsw i64 %arg1, %arg2
+  ; CHECK: llvm.sub %{{.*}}, %{{.*}} flags <nuw> : i64
+  %2 = sub nuw i64 %arg1, %arg2
+  ; CHECK: llvm.mul %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+  %3 = mul nsw nuw i64 %arg1, %arg2
+  ; CHECK: llvm.shl %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+  %4 = shl nuw nsw i64 %arg1, %arg2
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
new file mode 100644
index 0000000000000..4a7a39bb570c3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @intflags_func
+llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
+  %0 = llvm.add %arg0, %arg1 flags <nsw> : i64
+  // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
+  %1 = llvm.sub %arg0, %arg1 flags <nuw> : i64
+  // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
+  %2 = llvm.mul %arg0, %arg1 flags <nsw, nuw> : i64
+  // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
+  %3 = llvm.shl %arg0, %arg1 flags <nsw, nuw> : i64
+  llvm.return
+}

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this addition. Dropped a bunch of nit comments for now, but this is already looking very good.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the overflow flags!

Copy link

github-actions bot commented Dec 6, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@tblah
Copy link
Contributor Author

tblah commented Dec 6, 2023

Thanks for the speedy review everyone

Copy link
Contributor

@Hardcode84 Hardcode84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM modulo a few places that are not yet following the new naming scheme. Feel free to pick shorter better names if you have something better in mind.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants