-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][LLVM] Add nneg flag #115498
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][LLVM] Add nneg flag #115498
Conversation
This implementation is based on the existing one for the exact flag. If the nneg flag is set, and the argument is negative, the result is a poison value.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: None (lfrenot) ChangesThis implementation is based on the existing one for the exact flag. If the nneg flag is set and the argument is negative, the result is a poison value. @zero9178, @gysit and @Dinistro, could you take a look? Full diff: https://github.com/llvm/llvm-project/pull/115498.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 12c430df208925..352e2ec91bdbea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
];
}
+def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
+ let description = [{
+ This interface defines an LLVM operation with an nneg flag and
+ provides a uniform API for accessing it.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<[{
+ Get the nneg flag for the operation.
+ }], "bool", "getNonNeg", (ins), [{}], [{
+ return $_op.getProperties().nonNeg;
+ }]>,
+ InterfaceMethod<[{
+ Set the nneg flag for the operation.
+ }], "void", "setNonNeg", (ins "bool":$nonNeg), [{}], [{
+ $_op.getProperties().nonNeg = nonNeg;
+ }]>,
+ StaticInterfaceMethod<[{
+ Get the attribute name of the nonNeg property.
+ }], "StringRef", "getNonNegName", (ins), [{}], [{
+ return "nonNeg";
+ }]>,
+ ];
+}
+
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
let description = [{
An interface for operations that can carry branch weights metadata. It
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 315af2594047a5..afdba9b9a16b5f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -508,6 +508,23 @@ class LLVM_CastOp<string mnemonic, string instName, Type type,
$_location, $_resultType, $arg);
}];
}
+class LLVM_CastOpWithNNegFlag<string mnemonic, string instName, Type type,
+ Type resultType, list<Trait> traits = []> :
+ LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<NonNegFlagInterface>], traits)>,
+ LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.getNonNeg());"> {
+ let arguments = (ins type:$arg, UnitAttr:$nonNeg);
+ let results = (outs resultType:$res);
+ let builders = [LLVM_OneResultOpBuilder];
+ let assemblyFormat = "(`nneg` $nonNeg^)? $arg attr-dict `:` type($arg) `to` type($res)";
+ string llvmInstName = instName;
+ string mlirBuilder = [{
+ auto op = $_builder.create<$_qualCppClassName>(
+ $_location, $_resultType, $arg);
+ moduleImport.setNonNegFlag(inst, op);
+ $res = op;
+ }];
+}
+
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
let hasFolder = 1;
@@ -531,7 +548,7 @@ def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt",
LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
let hasVerifier = 1;
}
-def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
+def LLVM_ZExtOp : LLVM_CastOpWithNNegFlag<"zext", "ZExt",
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
let hasFolder = 1;
@@ -543,7 +560,7 @@ def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "SIToFP",
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
-def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "UIToFP",
+def LLVM_UIToFPOp : LLVM_CastOpWithNNegFlag<"uitofp", "UIToFP",
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "FPToSI",
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 6c3a500f20e3a9..30164843f63675 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -192,6 +192,11 @@ class ModuleImport {
/// implement the exact flag interface.
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
+ /// Sets the nneg flag attribute for the imported operation `op` given
+ /// the original instruction `inst`. Asserts if the operation does not
+ /// implement the nneg flag interface.
+ void setNonNegFlag(llvm::Instruction *inst, Operation *op) const;
+
/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the fastmath interface.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 6b2d8943bf4885..d39edb8020ccb6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -146,6 +146,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
} else if (auto iface = dyn_cast<ExactFlagInterface>(op)) {
printer.printOptionalAttrDict(filteredAttrs,
/*elidedAttrs=*/{iface.getIsExactName()});
+ } else if (auto iface = dyn_cast<NonNegFlagInterface>(op)) {
+ printer.printOptionalAttrDict(filteredAttrs,
+ /*elidedAttrs=*/{iface.getNonNegName()});
} else {
printer.printOptionalAttrDict(filteredAttrs);
}
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 12145f7a2217df..71d88d3a62f2b9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -689,6 +689,12 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
iface.setIsExact(inst->isExact());
}
+void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
+ auto iface = cast<NonNegFlagInterface>(op);
+
+ iface.setNonNeg(inst->hasNonNeg());
+}
+
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<FastmathFlagsInterface>(op);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 682780c5f0a7df..f96f7d7fecae5a 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -325,6 +325,19 @@ func.func @casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
llvm.return
}
+// CHECK-LABEL: @nneg_casts
+// CHECK-SAME: (%[[I32:.*]]: i32, %[[I64:.*]]: i64, %[[V4I32:.*]]: vector<4xi32>, %[[V4I64:.*]]: vector<4xi64>, %[[PTR:.*]]: !llvm.ptr)
+func.func @nneg_casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
+ %arg3: vector<4xi64>, %arg4: !llvm.ptr) {
+// CHECK: = llvm.zext nneg %[[I32]] : i32 to i64
+ %0 = llvm.zext nneg %arg0 : i32 to i64
+// CHECK: = llvm.zext nneg %[[V4I32]] : vector<4xi32> to vector<4xi64>
+ %4 = llvm.zext nneg %arg2 : vector<4xi32> to vector<4xi64>
+// CHECK: %[[FLOAT:.*]] = llvm.uitofp nneg %[[I32]] : i32 to f32
+ %7 = llvm.uitofp nneg %arg0 : i32 to f32
+ llvm.return
+}
+
// CHECK-LABEL: @vect
func.func @vect(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32, %arg3: !llvm.vec<2 x ptr>) {
// CHECK: = llvm.extractelement {{.*}} : vector<4xf32>
diff --git a/mlir/test/Target/LLVMIR/Import/nneg.ll b/mlir/test/Target/LLVMIR/Import/nneg.ll
new file mode 100644
index 00000000000000..07756b9f706bdb
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nneg.ll
@@ -0,0 +1,10 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @nnegflag_inst
+define void @nnegflag_inst(i32 %arg1) {
+ ; CHECK: llvm.zext nneg %{{.*}} : i32 to i64
+ %1 = zext nneg i32 %arg1 to i64
+ ; CHECK: llvm.uitofp nneg %{{.*}} : i32 to f32
+ %2 = uitofp nneg i32 %arg1 to float
+ ret void
+}
diff --git a/mlir/test/Target/LLVMIR/nneg.mlir b/mlir/test/Target/LLVMIR/nneg.mlir
new file mode 100644
index 00000000000000..8afa765a510e24
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nneg.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @nnegflag_func
+llvm.func @nnegflag_func(%arg0: i32) {
+ // CHECK: %{{.*}} = zext nneg i32 %{{.*}} to i64
+ %0 = llvm.zext nneg %arg0 : i32 to i64
+ // CHECK: %{{.*}} = uitofp nneg i32 %{{.*}} to float
+ %1 = llvm.uitofp nneg %arg0 : i32 to f32
+ llvm.return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
LGTM modulo nit comments!
Type resultType, list<Trait> traits = []> : | ||
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<NonNegFlagInterface>], traits)>, | ||
LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.getNonNeg());"> { | ||
let arguments = (ins type:$arg, UnitAttr:$nonNeg); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let arguments = (ins type:$arg, UnitAttr:$nonNeg); | |
let arguments = (ins type:$arg, UnitAttr:$nonNeg); |
ultra nit:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
%0 = llvm.zext nneg %arg0 : i32 to i64 | ||
// CHECK: = llvm.zext nneg %[[V4I32]] : vector<4xi32> to vector<4xi64> | ||
%4 = llvm.zext nneg %arg2 : vector<4xi32> to vector<4xi64> | ||
// CHECK: %[[FLOAT:.*]] = llvm.uitofp nneg %[[I32]] : i32 to f32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// CHECK: %[[FLOAT:.*]] = llvm.uitofp nneg %[[I32]] : i32 to f32 | |
// CHECK: = llvm.uitofp nneg %[[I32]] : i32 to f32 |
nit: I would drop the variable since it has no uses in this test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
@@ -146,6 +146,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op, | |||
} else if (auto iface = dyn_cast<ExactFlagInterface>(op)) { | |||
printer.printOptionalAttrDict(filteredAttrs, | |||
/*elidedAttrs=*/{iface.getIsExactName()}); | |||
} else if (auto iface = dyn_cast<NonNegFlagInterface>(op)) { | |||
printer.printOptionalAttrDict(filteredAttrs, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect this is not needed since the operation does not use custom<LLVMOpAttrs>(attr-dict)
. It seems like this print/parseLLVMOpAttrs should only be used for the operations that use fast math flags. That can be cleaned up in a separate PR though.
I brief I would suggest to remove the NonNegFlagInterface here, except it is necessary for a reason I do not see right now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
This implementation is based on the existing one for the exact flag. If the nneg flag is set and the argument is negative, the result is a poison value.
This implementation is based on the existing one for the exact flag.
If the nneg flag is set and the argument is negative, the result is a poison value.
@zero9178, @gysit and @Dinistro, could you take a look?