-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][LLVM] Add nsw and nuw flags #74508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The implementation of these are modelled after the existing fastmath flags for floating point arithmetic.
@llvm/pr-subscribers-mlir Author: Tom Eccles (tblah) ChangesThe implementation of these are modeled after the existing fastmath flags for floating point arithmetic. Full diff: https://github.com/llvm/llvm-project/pull/74508.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f05230526c21f..5cde4980ae17d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr<
let printBitEnumPrimaryGroups = 1;
}
+//===----------------------------------------------------------------------===//
+// IntegerArithFlags
+//===----------------------------------------------------------------------===//
+
+def IAFnone : I32BitEnumAttrCaseNone<"none">;
+def IAFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IAFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerArithFlags : I32BitEnumAttr<
+ "IntegerArithFlags",
+ "LLVM integer arithmetic flags",
+ [IAFnone, IAFnsw, IAFnuw]> {
+ let separator = ", ";
+ let cppNamespace = "::mlir::LLVM";
+ let genSpecializedAttr = 0;
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def LLVM_IntegerArithFlagsAttr :
+ EnumAttr<LLVM_Dialect, IntegerArithFlags, "arith"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// FastmathFlags
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index c5d65f792254e..3d3388ac50aff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
];
}
+def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
+ let description = [{
+ Access to op integer overflow flags.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a IntegerArithFlagsAttr attribute for the operation",
+ /*returnType=*/ "IntegerArithFlagsAttr",
+ /*methodName=*/ "getArithAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getArithFlagsAttr();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "hasNuw",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+ return bitEnumContainsAll(flags, IntegerArithFlags::nuw);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "hasNsw",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+ return bitEnumContainsAll(flags, IntegerArithFlags::nsw);
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the IntegerArithFlagsAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getIntegerArithAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "arithFlags";
+ }]
+ >
+ ];
+}
+
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8f166f0cc7cf5..4a2ef07f505b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -55,6 +55,21 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
$res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
}];
}
+class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
+ list<Trait> traits = []> :
+ LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
+ !listconcat([DeclareOpInterfaceMethods<IntegerArithFlagsInterface>], traits)> {
+ dag iafArg = (
+ ins DefaultValuedAttr<LLVM_IntegerArithFlagsAttr, "{}">:$arithFlags);
+ let arguments = !con(commonArgs, iafArg);
+ string mlirBuilder = [{
+ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+ moduleImport.setIntegerFlagsAttr(inst, op);
+ $res = op;
+ }];
+ let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $arithFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
+ string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNuw(), op.hasNsw());";
+}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -90,9 +105,9 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
}
// Integer binary operations.
-def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>;
-def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">;
-def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>;
+def LLVM_AddOp : LLVM_IntArithmeticOpWithFlag<"add", "Add", [Commutative]>;
+def LLVM_SubOp : LLVM_IntArithmeticOpWithFlag<"sub", "Sub", []>;
+def LLVM_MulOp : LLVM_IntArithmeticOpWithFlag<"mul", "Mul", [Commutative]>;
def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
@@ -102,7 +117,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
let hasFolder = 1;
}
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+def LLVM_ShlOp : LLVM_IntArithmeticOpWithFlag<"shl", "Shl", []> {
let hasFolder = 1;
}
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b8e449dc11df1..de52476636aed 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -172,6 +172,11 @@ class ModuleImport {
/// attributes of LLVMFuncOp `funcOp`.
void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);
+ /// Sets the integer arithmetic flags (nsw/nuw) attribute for the imported
+ /// operation `op` given the original instruction `inst`. Asserts if the
+ /// operation does not implement the integer arithmetic flag interface.
+ void setIntegerFlagsAttr(llvm::Instruction *inst, Operation *op) const;
+
/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 28445945f07d6..3d78970cf6c14 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -62,6 +62,14 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
return filteredAttrs;
}
+static auto processIntArithAttr(ArrayRef<NamedAttribute> attrs) {
+ SmallVector<NamedAttribute, 8> filteredAttrs(
+ llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
+ return attr.getName() != "arithFlags";
+ }));
+ return filteredAttrs;
+}
+
static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
NamedAttrList &result) {
return parser.parseOptionalAttrDict(result);
@@ -69,7 +77,8 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
DictionaryAttr attrs) {
- printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
+ printer.printOptionalAttrDict(
+ processFMFAttr(processIntArithAttr(attrs.getValue())));
}
/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2d1aaa9229cd2..edd0120dcbb71 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
}
}
+void ModuleImport::setIntegerFlagsAttr(llvm::Instruction *inst,
+ Operation *op) const {
+ IntegerArithFlagsInterface iface = cast<IntegerArithFlagsInterface>(op);
+
+ IntegerArithFlags value = {};
+ value = bitEnumSet(value, IntegerArithFlags::nsw, inst->hasNoSignedWrap());
+ value = bitEnumSet(value, IntegerArithFlags::nuw, inst->hasNoUnsignedWrap());
+
+ IntegerArithFlagsAttr attr =
+ IntegerArithFlagsAttr::get(op->getContext(), value);
+ iface->setAttr(iface.getIntegerArithAttrName(), attr);
+}
+
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index ee724a482cfb5..dc0f9f453057d 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
%vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
%typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
+// Integer arithmetic flags
+// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] flags <nsw> : i32
+// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] flags <nuw> : i32
+// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+ %add_flag = llvm.add %arg0, %arg0 flags <nsw> : i32
+ %sub_flag = llvm.sub %arg0, %arg0 flags <nuw> : i32
+ %mul_flag = llvm.mul %arg0, %arg0 flags <nsw, nuw> : i32
+ %shl_flag = llvm.shl %arg0, %arg0 flags <nuw, nsw> : i32
+
// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
new file mode 100644
index 0000000000000..2ea0425ec0ff7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -0,0 +1,14 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @intflag_inst
+define void @intflag_inst(i64 %arg1, i64 %arg2) {
+ ; CHECK: llvm.add %{{.*}}, %{{.*}} flags <nsw> : i64
+ %1 = add nsw i64 %arg1, %arg2
+ ; CHECK: llvm.sub %{{.*}}, %{{.*}} flags <nuw> : i64
+ %2 = sub nuw i64 %arg1, %arg2
+ ; CHECK: llvm.mul %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+ %3 = mul nsw nuw i64 %arg1, %arg2
+ ; CHECK: llvm.shl %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+ %4 = shl nuw nsw i64 %arg1, %arg2
+ ret void
+}
diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
new file mode 100644
index 0000000000000..4a7a39bb570c3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @intflags_func
+llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
+ %0 = llvm.add %arg0, %arg1 flags <nsw> : i64
+ // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
+ %1 = llvm.sub %arg0, %arg1 flags <nuw> : i64
+ // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
+ %2 = llvm.mul %arg0, %arg1 flags <nsw, nuw> : i64
+ // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
+ %3 = llvm.shl %arg0, %arg1 flags <nsw, nuw> : i64
+ llvm.return
+}
|
@llvm/pr-subscribers-mlir-llvm Author: Tom Eccles (tblah) ChangesThe implementation of these are modeled after the existing fastmath flags for floating point arithmetic. Full diff: https://github.com/llvm/llvm-project/pull/74508.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
index f05230526c21f..5cde4980ae17d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
@@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr<
let printBitEnumPrimaryGroups = 1;
}
+//===----------------------------------------------------------------------===//
+// IntegerArithFlags
+//===----------------------------------------------------------------------===//
+
+def IAFnone : I32BitEnumAttrCaseNone<"none">;
+def IAFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
+def IAFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
+
+def IntegerArithFlags : I32BitEnumAttr<
+ "IntegerArithFlags",
+ "LLVM integer arithmetic flags",
+ [IAFnone, IAFnsw, IAFnuw]> {
+ let separator = ", ";
+ let cppNamespace = "::mlir::LLVM";
+ let genSpecializedAttr = 0;
+ let printBitEnumPrimaryGroups = 1;
+}
+
+def LLVM_IntegerArithFlagsAttr :
+ EnumAttr<LLVM_Dialect, IntegerArithFlags, "arith"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
//===----------------------------------------------------------------------===//
// FastmathFlags
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index c5d65f792254e..3d3388ac50aff 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
];
}
+def IntegerArithFlagsInterface : OpInterface<"IntegerArithFlagsInterface"> {
+ let description = [{
+ Access to op integer overflow flags.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/ "Returns a IntegerArithFlagsAttr attribute for the operation",
+ /*returnType=*/ "IntegerArithFlagsAttr",
+ /*methodName=*/ "getArithAttr",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ return op.getArithFlagsAttr();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "hasNuw",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+ return bitEnumContainsAll(flags, IntegerArithFlags::nuw);
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
+ /*returnType=*/ "bool",
+ /*methodName=*/ "hasNsw",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ auto op = cast<ConcreteOp>(this->getOperation());
+ IntegerArithFlags flags = op.getArithFlagsAttr().getValue();
+ return bitEnumContainsAll(flags, IntegerArithFlags::nsw);
+ }]
+ >,
+ StaticInterfaceMethod<
+ /*desc=*/ [{Returns the name of the IntegerArithFlagsAttr attribute
+ for the operation}],
+ /*returnType=*/ "StringRef",
+ /*methodName=*/ "getIntegerArithAttrName",
+ /*args=*/ (ins),
+ /*methodBody=*/ [{}],
+ /*defaultImpl=*/ [{
+ return "arithFlags";
+ }]
+ >
+ ];
+}
+
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 8f166f0cc7cf5..4a2ef07f505b4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -55,6 +55,21 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
$res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
}];
}
+class LLVM_IntArithmeticOpWithFlag<string mnemonic, string instName,
+ list<Trait> traits = []> :
+ LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
+ !listconcat([DeclareOpInterfaceMethods<IntegerArithFlagsInterface>], traits)> {
+ dag iafArg = (
+ ins DefaultValuedAttr<LLVM_IntegerArithFlagsAttr, "{}">:$arithFlags);
+ let arguments = !con(commonArgs, iafArg);
+ string mlirBuilder = [{
+ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+ moduleImport.setIntegerFlagsAttr(inst, op);
+ $res = op;
+ }];
+ let assemblyFormat = "$lhs `,` $rhs (`flags` ` ` $arithFlags^)? custom<LLVMOpAttrs>(attr-dict) `:` type($res)";
+ string llvmBuilder = "$res = builder.Create" # instName # "($lhs, $rhs, /*Name=*/\"\", op.hasNuw(), op.hasNsw());";
+}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -90,9 +105,9 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
}
// Integer binary operations.
-def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>;
-def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">;
-def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>;
+def LLVM_AddOp : LLVM_IntArithmeticOpWithFlag<"add", "Add", [Commutative]>;
+def LLVM_SubOp : LLVM_IntArithmeticOpWithFlag<"sub", "Sub", []>;
+def LLVM_MulOp : LLVM_IntArithmeticOpWithFlag<"mul", "Mul", [Commutative]>;
def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
@@ -102,7 +117,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
let hasFolder = 1;
}
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
-def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
+def LLVM_ShlOp : LLVM_IntArithmeticOpWithFlag<"shl", "Shl", []> {
let hasFolder = 1;
}
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index b8e449dc11df1..de52476636aed 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -172,6 +172,11 @@ class ModuleImport {
/// attributes of LLVMFuncOp `funcOp`.
void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);
+ /// Sets the integer arithmetic flags (nsw/nuw) attribute for the imported
+ /// operation `op` given the original instruction `inst`. Asserts if the
+ /// operation does not implement the integer arithmetic flag interface.
+ void setIntegerFlagsAttr(llvm::Instruction *inst, Operation *op) const;
+
/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 28445945f07d6..3d78970cf6c14 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -62,6 +62,14 @@ static auto processFMFAttr(ArrayRef<NamedAttribute> attrs) {
return filteredAttrs;
}
+static auto processIntArithAttr(ArrayRef<NamedAttribute> attrs) {
+ SmallVector<NamedAttribute, 8> filteredAttrs(
+ llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
+ return attr.getName() != "arithFlags";
+ }));
+ return filteredAttrs;
+}
+
static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
NamedAttrList &result) {
return parser.parseOptionalAttrDict(result);
@@ -69,7 +77,8 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
DictionaryAttr attrs) {
- printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
+ printer.printOptionalAttrDict(
+ processFMFAttr(processIntArithAttr(attrs.getValue())));
}
/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 2d1aaa9229cd2..edd0120dcbb71 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
}
}
+void ModuleImport::setIntegerFlagsAttr(llvm::Instruction *inst,
+ Operation *op) const {
+ IntegerArithFlagsInterface iface = cast<IntegerArithFlagsInterface>(op);
+
+ IntegerArithFlags value = {};
+ value = bitEnumSet(value, IntegerArithFlags::nsw, inst->hasNoSignedWrap());
+ value = bitEnumSet(value, IntegerArithFlags::nuw, inst->hasNoUnsignedWrap());
+
+ IntegerArithFlagsAttr attr =
+ IntegerArithFlagsAttr::get(op->getContext(), value);
+ iface->setAttr(iface.getIntegerArithAttrName(), attr);
+}
+
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index ee724a482cfb5..dc0f9f453057d 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
%vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
%typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
+// Integer arithmetic flags
+// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] flags <nsw> : i32
+// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] flags <nuw> : i32
+// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] flags <nsw, nuw> : i32
+ %add_flag = llvm.add %arg0, %arg0 flags <nsw> : i32
+ %sub_flag = llvm.sub %arg0, %arg0 flags <nuw> : i32
+ %mul_flag = llvm.mul %arg0, %arg0 flags <nsw, nuw> : i32
+ %shl_flag = llvm.shl %arg0, %arg0 flags <nuw, nsw> : i32
+
// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
new file mode 100644
index 0000000000000..2ea0425ec0ff7
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -0,0 +1,14 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @intflag_inst
+define void @intflag_inst(i64 %arg1, i64 %arg2) {
+ ; CHECK: llvm.add %{{.*}}, %{{.*}} flags <nsw> : i64
+ %1 = add nsw i64 %arg1, %arg2
+ ; CHECK: llvm.sub %{{.*}}, %{{.*}} flags <nuw> : i64
+ %2 = sub nuw i64 %arg1, %arg2
+ ; CHECK: llvm.mul %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+ %3 = mul nsw nuw i64 %arg1, %arg2
+ ; CHECK: llvm.shl %{{.*}}, %{{.*}} flags <nsw, nuw> : i64
+ %4 = shl nuw nsw i64 %arg1, %arg2
+ ret void
+}
diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
new file mode 100644
index 0000000000000..4a7a39bb570c3
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @intflags_func
+llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
+ %0 = llvm.add %arg0, %arg1 flags <nsw> : i64
+ // CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
+ %1 = llvm.sub %arg0, %arg1 flags <nuw> : i64
+ // CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
+ %2 = llvm.mul %arg0, %arg1 flags <nsw, nuw> : i64
+ // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
+ %3 = llvm.shl %arg0, %arg1 flags <nsw, nuw> : i64
+ llvm.return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this addition. Dropped a bunch of nit comments for now, but this is already looking very good.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding the overflow flags!
✅ With the latest revision this PR passed the C/C++ code formatter. |
Thanks for the speedy review everyone |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo a few places that are not yet following the new naming scheme. Feel free to pick shorter better names if you have something better in mind.
The implementation of these are modeled after the existing fastmath flags for floating point arithmetic.