Skip to content

[mlir][LLVM] Add nneg flag #115498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 11, 2024
Merged

[mlir][LLVM] Add nneg flag #115498

merged 2 commits into from
Nov 11, 2024

Conversation

lfrenot
Copy link
Contributor

@lfrenot lfrenot commented Nov 8, 2024

This implementation is based on the existing one for the exact flag.

If the nneg flag is set and the argument is negative, the result is a poison value.

@zero9178, @gysit and @Dinistro, could you take a look?

This implementation is based on the existing one for the exact flag.

If the nneg flag is set, and the argument is negative, the result is a poison value.
@llvmbot
Copy link
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: None (lfrenot)

Changes

This implementation is based on the existing one for the exact flag.

If the nneg flag is set and the argument is negative, the result is a poison value.

@zero9178, @gysit and @Dinistro, could you take a look?


Full diff: https://github.com/llvm/llvm-project/pull/115498.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+27)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+19-2)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+5)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+3)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+6)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+13)
  • (added) mlir/test/Target/LLVMIR/Import/nneg.ll (+10)
  • (added) mlir/test/Target/LLVMIR/nneg.mlir (+10)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 12c430df208925..352e2ec91bdbea 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
   ];
 }
 
+def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
+  let description = [{
+    This interface defines an LLVM operation with an nneg flag and
+    provides a uniform API for accessing it.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<[{
+      Get the nneg flag for the operation.
+    }], "bool", "getNonNeg", (ins), [{}], [{
+      return $_op.getProperties().nonNeg;
+    }]>,
+    InterfaceMethod<[{
+      Set the nneg flag for the operation.
+    }], "void", "setNonNeg", (ins "bool":$nonNeg), [{}], [{
+      $_op.getProperties().nonNeg = nonNeg;
+    }]>,
+    StaticInterfaceMethod<[{
+      Get the attribute name of the nonNeg property.
+    }], "StringRef", "getNonNegName", (ins), [{}], [{
+      return "nonNeg";
+    }]>,
+  ];
+}
+
 def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
   let description = [{
     An interface for operations that can carry branch weights metadata. It
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 315af2594047a5..afdba9b9a16b5f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -508,6 +508,23 @@ class LLVM_CastOp<string mnemonic, string instName, Type type,
       $_location, $_resultType, $arg);
   }];
 }
+class LLVM_CastOpWithNNegFlag<string mnemonic, string instName, Type type,
+                  Type resultType, list<Trait> traits = []> :
+    LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<NonNegFlagInterface>], traits)>,
+    LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.getNonNeg());"> {
+  let arguments = (ins type:$arg,  UnitAttr:$nonNeg);
+  let results = (outs resultType:$res);
+  let builders = [LLVM_OneResultOpBuilder];
+  let assemblyFormat = "(`nneg` $nonNeg^)? $arg attr-dict `:` type($arg) `to` type($res)";
+  string llvmInstName = instName;
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>(
+      $_location, $_resultType, $arg);
+    moduleImport.setNonNegFlag(inst, op);
+    $res = op;
+  }];
+}
+
 def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
     LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
   let hasFolder = 1;
@@ -531,7 +548,7 @@ def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt",
                               LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
   let hasVerifier = 1;
 }
-def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
+def LLVM_ZExtOp : LLVM_CastOpWithNNegFlag<"zext", "ZExt",
                               LLVM_ScalarOrVectorOf<AnySignlessInteger>,
                               LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
   let hasFolder = 1;
@@ -543,7 +560,7 @@ def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
 def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "SIToFP",
                                 LLVM_ScalarOrVectorOf<AnySignlessInteger>,
                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
-def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "UIToFP",
+def LLVM_UIToFPOp : LLVM_CastOpWithNNegFlag<"uitofp", "UIToFP",
                                 LLVM_ScalarOrVectorOf<AnySignlessInteger>,
                                 LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
 def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "FPToSI",
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 6c3a500f20e3a9..30164843f63675 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -192,6 +192,11 @@ class ModuleImport {
   /// implement the exact flag interface.
   void setExactFlag(llvm::Instruction *inst, Operation *op) const;
 
+  /// Sets the nneg flag attribute for the imported operation `op` given
+  /// the original instruction `inst`. Asserts if the operation does not
+  /// implement the nneg flag interface.
+  void setNonNegFlag(llvm::Instruction *inst, Operation *op) const;
+
   /// Sets the fastmath flags attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
   /// implement the fastmath interface.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 6b2d8943bf4885..d39edb8020ccb6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -146,6 +146,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
   } else if (auto iface = dyn_cast<ExactFlagInterface>(op)) {
     printer.printOptionalAttrDict(filteredAttrs,
                                   /*elidedAttrs=*/{iface.getIsExactName()});
+  } else if (auto iface = dyn_cast<NonNegFlagInterface>(op)) {
+    printer.printOptionalAttrDict(filteredAttrs,
+                                  /*elidedAttrs=*/{iface.getNonNegName()});
   } else {
     printer.printOptionalAttrDict(filteredAttrs);
   }
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 12145f7a2217df..71d88d3a62f2b9 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -689,6 +689,12 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
   iface.setIsExact(inst->isExact());
 }
 
+void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
+  auto iface = cast<NonNegFlagInterface>(op);
+
+  iface.setNonNeg(inst->hasNonNeg());
+}
+
 void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
                                         Operation *op) const {
   auto iface = cast<FastmathFlagsInterface>(op);
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 682780c5f0a7df..f96f7d7fecae5a 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -325,6 +325,19 @@ func.func @casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
   llvm.return
 }
 
+// CHECK-LABEL: @nneg_casts
+// CHECK-SAME: (%[[I32:.*]]: i32, %[[I64:.*]]: i64, %[[V4I32:.*]]: vector<4xi32>, %[[V4I64:.*]]: vector<4xi64>, %[[PTR:.*]]: !llvm.ptr)
+func.func @nneg_casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
+                %arg3: vector<4xi64>, %arg4: !llvm.ptr) {
+// CHECK:  = llvm.zext nneg %[[I32]] : i32 to i64
+  %0 = llvm.zext nneg %arg0 : i32 to i64
+// CHECK:  = llvm.zext nneg %[[V4I32]] : vector<4xi32> to vector<4xi64>
+  %4 = llvm.zext nneg %arg2 : vector<4xi32> to vector<4xi64>
+// CHECK: %[[FLOAT:.*]] = llvm.uitofp nneg %[[I32]] : i32 to f32
+  %7 = llvm.uitofp nneg %arg0 : i32 to f32
+  llvm.return
+}
+
 // CHECK-LABEL: @vect
 func.func @vect(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32, %arg3: !llvm.vec<2 x ptr>) {
 // CHECK:  = llvm.extractelement {{.*}} : vector<4xf32>
diff --git a/mlir/test/Target/LLVMIR/Import/nneg.ll b/mlir/test/Target/LLVMIR/Import/nneg.ll
new file mode 100644
index 00000000000000..07756b9f706bdb
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/nneg.ll
@@ -0,0 +1,10 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @nnegflag_inst
+define void @nnegflag_inst(i32 %arg1) {
+  ; CHECK: llvm.zext nneg %{{.*}} : i32 to i64
+  %1 = zext nneg i32 %arg1 to i64
+  ; CHECK: llvm.uitofp nneg %{{.*}} : i32 to f32
+  %2 = uitofp nneg i32 %arg1 to float
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/nneg.mlir b/mlir/test/Target/LLVMIR/nneg.mlir
new file mode 100644
index 00000000000000..8afa765a510e24
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nneg.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @nnegflag_func
+llvm.func @nnegflag_func(%arg0: i32) {
+  // CHECK: %{{.*}} = zext nneg i32 %{{.*}} to i64
+  %0 = llvm.zext nneg %arg0 : i32 to i64
+  // CHECK: %{{.*}} = uitofp nneg i32 %{{.*}} to float
+  %1 = llvm.uitofp nneg %arg0 : i32 to f32
+  llvm.return
+}

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

LGTM modulo nit comments!

Type resultType, list<Trait> traits = []> :
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<NonNegFlagInterface>], traits)>,
LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.getNonNeg());"> {
let arguments = (ins type:$arg, UnitAttr:$nonNeg);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let arguments = (ins type:$arg, UnitAttr:$nonNeg);
let arguments = (ins type:$arg, UnitAttr:$nonNeg);

ultra nit:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

%0 = llvm.zext nneg %arg0 : i32 to i64
// CHECK: = llvm.zext nneg %[[V4I32]] : vector<4xi32> to vector<4xi64>
%4 = llvm.zext nneg %arg2 : vector<4xi32> to vector<4xi64>
// CHECK: %[[FLOAT:.*]] = llvm.uitofp nneg %[[I32]] : i32 to f32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// CHECK: %[[FLOAT:.*]] = llvm.uitofp nneg %[[I32]] : i32 to f32
// CHECK: = llvm.uitofp nneg %[[I32]] : i32 to f32

nit: I would drop the variable since it has no uses in this test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@@ -146,6 +146,9 @@ static void printLLVMOpAttrs(OpAsmPrinter &printer, Operation *op,
} else if (auto iface = dyn_cast<ExactFlagInterface>(op)) {
printer.printOptionalAttrDict(filteredAttrs,
/*elidedAttrs=*/{iface.getIsExactName()});
} else if (auto iface = dyn_cast<NonNegFlagInterface>(op)) {
printer.printOptionalAttrDict(filteredAttrs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect this is not needed since the operation does not use custom<LLVMOpAttrs>(attr-dict). It seems like this print/parseLLVMOpAttrs should only be used for the operations that use fast math flags. That can be cleaned up in a separate PR though.

I brief I would suggest to remove the NonNegFlagInterface here, except it is necessary for a reason I do not see right now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

@gysit gysit merged commit 89aaf2c into llvm:main Nov 11, 2024
8 checks passed
@lfrenot lfrenot deleted the mlir-llvm-nneg branch November 11, 2024 14:09
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
This implementation is based on the existing one for the exact flag.

If the nneg flag is set and the argument is negative, the result is a
poison value.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants