Skip to content

[mlir][LLVM] Add nsw and nuw flags #74508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,29 @@ def DISubprogramFlags : I32BitEnumAttr<
let printBitEnumPrimaryGroups = 1;
}

//===----------------------------------------------------------------------===//
// IntegerOverflowFlags
//===----------------------------------------------------------------------===//

def IOFnone : I32BitEnumAttrCaseNone<"none">;
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;

def IntegerOverflowFlags : I32BitEnumAttr<
"IntegerOverflowFlags",
"LLVM integer overflow flags",
[IOFnone, IOFnsw, IOFnuw]> {
let separator = ", ";
let cppNamespace = "::mlir::LLVM";
let genSpecializedAttr = 0;
let printBitEnumPrimaryGroups = 1;
}

def LLVM_IntegerOverflowFlagsAttr :
EnumAttr<LLVM_Dialect, IntegerOverflowFlags, "overflow"> {
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// FastmathFlags
//===----------------------------------------------------------------------===//
Expand Down
57 changes: 57 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,63 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
];
}

def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
let description = [{
Access to op integer overflow flags.
}];

let cppNamespace = "::mlir::LLVM";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
/*returnType=*/ "IntegerOverflowFlagsAttr",
/*methodName=*/ "getOverflowAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getOverflowFlagsAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoUnsignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoSignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getIntegerOverflowAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "overflowFlags";
}]
>
];
}

def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
Expand Down
30 changes: 26 additions & 4 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,26 @@ class LLVM_IntArithmeticOp<string mnemonic, string instName,
$res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
}];
}
class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnyInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
dag iofArg = (
ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
let arguments = !con(commonArgs, iofArg);
string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
moduleImport.setIntegerOverflowFlagsAttr(inst, op);
$res = op;
}];
let assemblyFormat = [{
$lhs `,` $rhs (`overflow` `` $overflowFlags^)?
custom<LLVMOpAttrs>(attr-dict) `:` type($res)
}];
string llvmBuilder =
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());";
}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
Expand Down Expand Up @@ -90,9 +110,11 @@ class LLVM_UnaryFloatArithmeticOp<Type type, string mnemonic,
}

// Integer binary operations.
def LLVM_AddOp : LLVM_IntArithmeticOp<"add", "Add", [Commutative]>;
def LLVM_SubOp : LLVM_IntArithmeticOp<"sub", "Sub">;
def LLVM_MulOp : LLVM_IntArithmeticOp<"mul", "Mul", [Commutative]>;
def LLVM_AddOp : LLVM_IntArithmeticOpWithOverflowFlag<"add", "Add",
[Commutative]>;
def LLVM_SubOp : LLVM_IntArithmeticOpWithOverflowFlag<"sub", "Sub", []>;
def LLVM_MulOp : LLVM_IntArithmeticOpWithOverflowFlag<"mul", "Mul",
[Commutative]>;
def LLVM_UDivOp : LLVM_IntArithmeticOp<"udiv", "UDiv">;
def LLVM_SDivOp : LLVM_IntArithmeticOp<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
Expand All @@ -102,7 +124,7 @@ def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
let hasFolder = 1;
}
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
def LLVM_ShlOp : LLVM_IntArithmeticOp<"shl", "Shl"> {
def LLVM_ShlOp : LLVM_IntArithmeticOpWithOverflowFlag<"shl", "Shl", []> {
let hasFolder = 1;
}
def LLVM_LShrOp : LLVM_IntArithmeticOp<"lshr", "LShr">;
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,12 @@ class ModuleImport {
/// attributes of LLVMFuncOp `funcOp`.
void processFunctionAttributes(llvm::Function *func, LLVMFuncOp funcOp);

/// Sets the integer overflow flags (nsw/nuw) attribute for the imported
/// operation `op` given the original instruction `inst`. Asserts if the
/// operation does not implement the integer overflow flag interface.
void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
Operation *op) const;

/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.
Expand Down
8 changes: 7 additions & 1 deletion mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,13 @@ static ParseResult parseLLVMOpAttrs(OpAsmParser &parser,

static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
DictionaryAttr attrs) {
printer.printOptionalAttrDict(processFMFAttr(attrs.getValue()));
auto filteredAttrs = processFMFAttr(attrs.getValue());
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
printer.printOptionalAttrDict(
filteredAttrs,
/*elidedAttrs=*/{iface.getIntegerOverflowAttrName()});
else
printer.printOptionalAttrDict(filteredAttrs);
}

/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,19 @@ void ModuleImport::setNonDebugMetadataAttrs(llvm::Instruction *inst,
}
}

void ModuleImport::setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<IntegerOverflowFlagsInterface>(op);

IntegerOverflowFlags value = {};
value = bitEnumSet(value, IntegerOverflowFlags::nsw, inst->hasNoSignedWrap());
value =
bitEnumSet(value, IntegerOverflowFlags::nuw, inst->hasNoUnsignedWrap());

auto attr = IntegerOverflowFlagsAttr::get(op->getContext(), value);
iface->setAttr(iface.getIntegerOverflowAttrName(), attr);
}

void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ func.func @ops(%arg0: i32, %arg1: f32,
%vptrcmp = llvm.icmp "ne" %arg5, %arg5 : !llvm.vec<2 x ptr>
%typecheck_vptrcmp = llvm.add %vptrcmp, %vptrcmp : vector<2 x i1>

// Integer overflow flags
// CHECK: {{.*}} = llvm.add %[[I32]], %[[I32]] overflow<nsw> : i32
// CHECK: {{.*}} = llvm.sub %[[I32]], %[[I32]] overflow<nuw> : i32
// CHECK: {{.*}} = llvm.mul %[[I32]], %[[I32]] overflow<nsw, nuw> : i32
// CHECK: {{.*}} = llvm.shl %[[I32]], %[[I32]] overflow<nsw, nuw> : i32
%add_flag = llvm.add %arg0, %arg0 overflow<nsw> : i32
%sub_flag = llvm.sub %arg0, %arg0 overflow<nuw> : i32
%mul_flag = llvm.mul %arg0, %arg0 overflow<nsw, nuw> : i32
%shl_flag = llvm.shl %arg0, %arg0 overflow<nuw, nsw> : i32

// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s

; CHECK-LABEL: @intflag_inst
define void @intflag_inst(i64 %arg1, i64 %arg2) {
; CHECK: llvm.add %{{.*}}, %{{.*}} overflow<nsw> : i64
%1 = add nsw i64 %arg1, %arg2
; CHECK: llvm.sub %{{.*}}, %{{.*}} overflow<nuw> : i64
%2 = sub nuw i64 %arg1, %arg2
; CHECK: llvm.mul %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%3 = mul nsw nuw i64 %arg1, %arg2
; CHECK: llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%4 = shl nuw nsw i64 %arg1, %arg2
ret void
}
14 changes: 14 additions & 0 deletions mlir/test/Target/LLVMIR/nsw_nuw.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// CHECK-LABEL: define void @intflags_func
llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = add nsw i64 %{{.*}}, %{{.*}}
%0 = llvm.add %arg0, %arg1 overflow <nsw> : i64
// CHECK: %{{.*}} = sub nuw i64 %{{.*}}, %{{.*}}
%1 = llvm.sub %arg0, %arg1 overflow <nuw> : i64
// CHECK: %{{.*}} = mul nuw nsw i64 %{{.*}}, %{{.*}}
%2 = llvm.mul %arg0, %arg1 overflow <nsw, nuw> : i64
// CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
%3 = llvm.shl %arg0, %arg1 overflow <nsw, nuw> : i64
llvm.return
}