Skip to content

Commit e9e1c41

Browse files
authored
[mlir][LLVM] Add nsw and nuw flags (#74508)
The implementation of these are modeled after the existing fastmath flags for floating point arithmetic.
1 parent 22df088 commit e9e1c41

File tree

9 files changed

+170
-5
lines changed

9 files changed

+170
-5
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr<
428428
let printBitEnumPrimaryGroups = 1;
429429
}
430430

431+
//===----------------------------------------------------------------------===//
432+
// IntegerOverflowFlags
433+
//===----------------------------------------------------------------------===//
434+
435+
def IOFnone : I32BitEnumAttrCaseNone<"none">;
436+
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
437+
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
438+
439+
def IntegerOverflowFlags : I32BitEnumAttr<
440+
"IntegerOverflowFlags",
441+
"LLVM integer overflow flags",
442+
[IOFnone, IOFnsw, IOFnuw]> {
443+
let separator = ", ";
444+
let cppNamespace = "::mlir::LLVM";
445+
let genSpecializedAttr = 0;
446+
let printBitEnumPrimaryGroups = 1;
447+
}
448+
449+
def LLVM_IntegerOverflowFlagsAttr :
450+
EnumAttr<LLVM_Dialect, IntegerOverflowFlags, "overflow"> {
451+
let assemblyFormat = "`<` $value `>`";
452+
}
453+
431454
//===----------------------------------------------------------------------===//
432455
// FastmathFlags
433456
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
4848
];
4949
}
5050

51+
def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
52+
let description = [{
53+
Access to op integer overflow flags.
54+
}];
55+
56+
let cppNamespace = "::mlir::LLVM";
57+
58+
let methods = [
59+
InterfaceMethod<
60+
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
61+
/*returnType=*/ "IntegerOverflowFlagsAttr",
62+
/*methodName=*/ "getOverflowAttr",
63+
/*args=*/ (ins),
64+
/*methodBody=*/ [{}],
65+
/*defaultImpl=*/ [{
66+
auto op = cast<ConcreteOp>(this->getOperation());
67+
return op.getOverflowFlagsAttr();
68+
}]
69+
>,
70+
InterfaceMethod<
71+
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
72+
/*returnType=*/ "bool",
73+
/*methodName=*/ "hasNoUnsignedWrap",
74+
/*args=*/ (ins),
75+
/*methodBody=*/ [{}],
76+
/*defaultImpl=*/ [{
77+
auto op = cast<ConcreteOp>(this->getOperation());
78+
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
79+
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
80+
}]
81+
>,
82+
InterfaceMethod<
83+
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
84+
/*returnType=*/ "bool",
85+
/*methodName=*/ "hasNoSignedWrap",
86+
/*args=*/ (ins),
87+
/*methodBody=*/ [{}],
88+
/*defaultImpl=*/ [{
89+
auto op = cast<ConcreteOp>(this->getOperation());
90+
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
91+
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
92+
}]
93+
>,
94+
StaticInterfaceMethod<
95+
/*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute
96+
for the operation}],
97+
/*returnType=*/ "StringRef",
98+
/*methodName=*/ "getIntegerOverflowAttrName",
99+
/*args=*/ (ins),
100+
/*methodBody=*/ [{}],
101+
/*defaultImpl=*/ [{
102+
return "overflowFlags";
103+
}]
104+
>
105+
];
106+
}
107+
51108
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
52109
let description = [{
53110
An interface for operations that can carry branch weights metadata. It

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,26 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
5555
$res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
5656
}];
5757
}
58+
class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
59+
list<Trait> traits = []> :
60+
LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
61+
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
62+
dag iofArg = (
63+
ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
64+
let arguments = !con(commonArgs, iofArg);
65+
string mlirBuilder = [{
66+
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
67+
moduleImport.setIntegerOverflowFlagsAttr(inst, op);
68+
$res = op;
69+
}];
70+
let assemblyFormat = [{
71+
$lhs `,` $rhs (`overflow` `` $overflowFlags^)?
72+
custom<LLVMOpAttrs>(attr-dict) `:` type($res)
73+
}];
74+
string llvmBuilder =
75+
"$res = builder.Create" # instName #
76+
"($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
77+
}
5878
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
5979
list<Trait> traits = []> :
6080
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -90,9 +110,11 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
90110
}
91111

92112
// Integer binary operations.
93-
def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>;
94-
def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">;
95-
def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>;
113+
def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add",
114+
[Commutative]>;
115+
def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>;
116+
def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul",
117+
[Commutative]>;
96118
def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
97119
def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
98120
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
@@ -102,7 +124,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
102124
let hasFolder = 1;
103125
}
104126
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
105-
def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
127+
def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> {
106128
let hasFolder = 1;
107129
}
108130
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ class ModuleImport {
172172
/// attributes of LLVMFuncOp `funcOp`.
173173
void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);
174174

175+
/// Sets the integer overflow flags (nsw/nuw) attribute for the imported
176+
/// operation `op` given the original instruction `inst`. Asserts if the
177+
/// operation does not implement the integer overflow flag interface.
178+
void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
179+
Operation *op) const;
180+
175181
/// Sets the fastmath flags attribute for the imported operation `op` given
176182
/// the original instruction `inst`. Asserts if the operation does not
177183
/// implement the fastmath interface.

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,13 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,
6969

7070
static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
7171
DictionaryAttr attrs) {
72-
printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
72+
auto filteredAttrs = processFMFAttr(attrs.getValue());
73+
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
74+
printer.printOptionalAttrDict(
75+
filteredAttrs,
76+
/*elidedAttrs=*/{iface.getIntegerOverflowAttrName()});
77+
else
78+
printer.printOptionalAttrDict(filteredAttrs);
7379
}
7480

7581
/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
619619
}
620620
}
621621

622+
void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
623+
Operation *op) const {
624+
auto iface = cast<IntegerOverflowFlagsInterface>(op);
625+
626+
IntegerOverflowFlags value = {};
627+
value = bitEnumSet(value, IntegerOverflowFlags::nsw, inst->hasNoSignedWrap());
628+
value =
629+
bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());
630+
631+
auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value);
632+
iface->setAttr(iface.getIntegerOverflowAttrName(), attr);
633+
}
634+
622635
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
623636
Operation *op) const {
624637
auto iface = cast<FastmathFlagsInterface>(op);

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
3434
%vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
3535
%typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>
3636

37+
// Integer overflow flags
38+
// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] overflow<nsw> : i32
39+
// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] overflow<nuw> : i32
40+
// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] overflow<nsw, nuw> : i32
41+
// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] overflow<nsw, nuw> : i32
42+
%add_flag = llvm.add %arg0, %arg0 overflow<nsw> : i32
43+
%sub_flag = llvm.sub %arg0, %arg0 overflow<nuw> : i32
44+
%mul_flag = llvm.mul %arg0, %arg0 overflow<nsw, nuw> : i32
45+
%shl_flag = llvm.shl %arg0, %arg0 overflow<nuw, nsw> : i32
46+
3747
// Floating point binary operations.
3848
//
3949
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
2+
3+
; CHECK-LABEL: @intflag_inst
4+
define void @intflag_inst(i64 %arg1, i64 %arg2) {
5+
; CHECK: llvm.add %{{.*}}, %{{.*}} overflow<nsw> : i64
6+
%1 = add nsw i64 %arg1, %arg2
7+
; CHECK: llvm.sub %{{.*}}, %{{.*}} overflow<nuw> : i64
8+
%2 = sub nuw i64 %arg1, %arg2
9+
; CHECK: llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
10+
%3 = mul nsw nuw i64 %arg1, %arg2
11+
; CHECK: llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
12+
%4 = shl nuw nsw i64 %arg1, %arg2
13+
ret void
14+
}

mlir/test/Target/LLVMIR/nsw_nuw.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: define void @intflags_func
4+
llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
5+
// CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
6+
%0 = llvm.add %arg0, %arg1 overflow <nsw> : i64
7+
// CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
8+
%1 = llvm.sub %arg0, %arg1 overflow <nuw> : i64
9+
// CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
10+
%2 = llvm.mul %arg0, %arg1 overflow <nsw, nuw> : i64
11+
// CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
12+
%3 = llvm.shl %arg0, %arg1 overflow <nsw, nuw> : i64
13+
llvm.return
14+
}

0 commit comments

Comments
 (0)