Skip to content

[mlir][LLVM] Add disjoint flag #115855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 15, 2024
Merged

[mlir][LLVM] Add disjoint flag #115855

merged 6 commits into from
Nov 15, 2024

Conversation

lfrenot
Copy link
Contributor

@lfrenot lfrenot commented Nov 12, 2024

The implementation is mostly based on the one existing for the exact flag.

disjoint means that for each bit, that bit is zero in at least one of the inputs. This allows the Or to be treated as an Add since no carry can occur from any bit. If the disjoint keyword is present, the result value of the or is a poison value if both inputs have a one in the same bit position. For vectors, only the element containing the bit is poison.

@lfrenot lfrenot marked this pull request as ready for review November 12, 2024 11:21
@lfrenot
Copy link
Contributor Author

lfrenot commented Nov 12, 2024

@zero9178, @gysit and @Dinistro, could you take a look?

@llvmbot
Copy link
Member

llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: None (lfrenot)

Changes

The implementation is mostly based on the one existing for the exact flag.

disjoint means that for each bit, that bit is zero in at least one of the inputs. This allows the Or to be treated as an Add since no carry can occur from any bit. If the disjoint keyword is present, the result value of the or is a poison value if both inputs have a one in the same bit position. For vectors, only the element containing the bit is poison.


Full diff: https://github.com/llvm/llvm-project/pull/115855.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+27)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+21-1)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+5)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h (+5)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+8)
  • (modified) mlir/lib/Target/LLVMIR/ModuleTranslation.cpp (+8)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+4)
  • (added) mlir/test/Target/LLVMIR/Import/disjoint.ll (+8)
  • (added) mlir/test/Target/LLVMIR/disjoint.mlir (+8)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 352e2ec91bdbea..2699a0ed14d4b3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
   ];
 }
 
+def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
+  let description = [{
+    This interface defines an LLVM operation with an disjoint flag and
+    provides a uniform API for accessing it.
+  }];
+
+  let cppNamespace = "::mlir::LLVM";
+
+  let methods = [
+    InterfaceMethod<[{
+      Get the disjoint flag for the operation.
+    }], "bool", "getIsDisjoint", (ins), [{}], [{
+      return $_op.getProperties().isDisjoint;
+    }]>,
+    InterfaceMethod<[{
+      Set the disjoint flag for the operation.
+    }], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
+      $_op.getProperties().isDisjoint = isDisjoint;
+    }]>,
+    StaticInterfaceMethod<[{
+      Get the attribute name of the isDisjoint property.
+    }], "StringRef", "getIsDisjointName", (ins), [{}], [{
+      return "isDisjoint";
+    }]>,
+  ];
+}
+
 def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
   let description = [{
     This interface defines an LLVM operation with an nneg flag and
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 34f3e4b33b8295..3a3311d8469dfd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
     "$res = builder.Create" # instName #
     "($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
 }
+class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
+                                   list<Trait> traits = []> :
+    LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
+    !listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
+  let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));
+
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+    moduleImport.setDisjointFlag(inst, op);
+    $res = op;
+  }];
+  let assemblyFormat = [{
+    (`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
+  }];
+  string llvmBuilder =
+    [{auto inst = builder.Create}] # instName #
+    [{($lhs, $rhs, /*Name=*/"");
+    moduleTranslation.setDisjointFlag(op, inst);
+    $res = inst;}];
+}
 class LLVM_FloatArithmeticOp<string mnemonic, string instName,
                              list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
 def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
 def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
 def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
-def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
+def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
   let hasFolder = 1;
 }
 def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 30164843f63675..eea0647895b01b 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -192,6 +192,11 @@ class ModuleImport {
   /// implement the exact flag interface.
   void setExactFlag(llvm::Instruction *inst, Operation *op) const;
 
+  /// Sets the disjoint flag attribute for the imported operation `op`
+  /// given the original instruction `inst`. Asserts if the operation does
+  /// not implement the disjoint flag interface.
+  void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;
+
   /// Sets the nneg flag attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
   /// implement the nneg flag interface.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0b14a665337d58 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -167,6 +167,11 @@ class ModuleTranslation {
   /// attribute.
   void setLoopMetadata(Operation *op, llvm::Instruction *inst);
 
+  /// Sets the disjoint flag attribute for the exported instruction `inst`
+  /// given the original operation `op`. Asserts if the operation does
+  /// not implement the disjoint flag interface.
+  void setDisjointFlag(Operation *op, llvm::Value *inst);
+
   /// Converts the type from MLIR LLVM dialect to LLVM.
   llvm::Type *convertType(Type type);
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 71d88d3a62f2b9..5592cc7f5df8f1 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -689,6 +689,14 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
   iface.setIsExact(inst->isExact());
 }
 
+void ModuleImport::setDisjointFlag(llvm::Instruction *inst, Operation *op) const {
+  auto iface = cast<DisjointFlagInterface>(op);
+
+  auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
+
+  iface.setIsDisjoint(inst_disjoint->isDisjoint());
+}
+
 void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
   auto iface = cast<NonNegFlagInterface>(op);
 
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..bbf567f8cf8d4c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1898,6 +1898,14 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
   inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
 }
 
+void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) {
+  auto iface = cast<DisjointFlagInterface>(op);
+
+  auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
+
+  inst_disjoint->setIsDisjoint(iface.getIsDisjoint());
+}
+
 llvm::Type *ModuleTranslation::convertType(Type type) {
   return typeTranslator.translateType(type);
 }
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index aa558bad2299ce..06f7b2d9f586fd 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
   %ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
   %lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
 
+// Integer disjoint flag.
+// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
+  %or_flag = llvm.or disjoint %arg0, %arg0 : i32
+
 // Floating point binary operations.
 //
 // CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/disjoint.ll b/mlir/test/Target/LLVMIR/Import/disjoint.ll
new file mode 100644
index 00000000000000..36091c09043525
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/disjoint.ll
@@ -0,0 +1,8 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @disjointflag_inst
+define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
+  ; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
+  %1 = or disjoint i64 %arg1, %arg2
+  ret void
+}
diff --git a/mlir/test/Target/LLVMIR/disjoint.mlir b/mlir/test/Target/LLVMIR/disjoint.mlir
new file mode 100644
index 00000000000000..1f5a42e608ba40
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/disjoint.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @disjointflag_func
+llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
+  // CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
+  %0 = llvm.or disjoint %arg0, %arg1 : i64
+  llvm.return
+}

Copy link

github-actions bot commented Nov 12, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing how many flags LLVM has :).

Thanks LGTM modulo nits.

@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
];
}

def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
let description = [{
This interface defines an LLVM operation with an disjoint flag and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This interface defines an LLVM operation with an disjoint flag and
This interface defines an LLVM operation with a disjoint flag and

nit:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 694 to 698
auto iface = cast<DisjointFlagInterface>(op);

auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);

iface.setIsDisjoint(inst_disjoint->isDisjoint());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto iface = cast<DisjointFlagInterface>(op);
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
iface.setIsDisjoint(inst_disjoint->isDisjoint());
auto iface = cast<DisjointFlagInterface>(op);
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
iface.setIsDisjoint(inst_disjoint->isDisjoint());

nit: I think it is fine to drop the newlines here. At the very least the one between the two cast instructions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 1902 to 1906
auto iface = cast<DisjointFlagInterface>(op);

auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);

inst_disjoint->setIsDisjoint(iface.getIsDisjoint());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto iface = cast<DisjointFlagInterface>(op);
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
inst_disjoint->setIsDisjoint(iface.getIsDisjoint());
auto iface = cast<DisjointFlagInterface>(op);
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
inst_disjoint->setIsDisjoint(iface.getIsDisjoint());

nit: here as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines 110 to 114
string llvmBuilder =
[{auto inst = builder.Create}] # instName #
[{($lhs, $rhs, /*Name=*/"");
moduleTranslation.setDisjointFlag(op, inst);
$res = inst;}];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
string llvmBuilder =
[{auto inst = builder.Create}] # instName #
[{($lhs, $rhs, /*Name=*/"");
moduleTranslation.setDisjointFlag(op, inst);
$res = inst;}];
string llvmBuilder = [{
auto inst = builder.Create}] # instName # [{($lhs, $rhs, /*Name=*/"");
moduleTranslation.setDisjointFlag(op, inst);
$res = inst;
}];

nit: an attempt to format tablegen a bit nicer, feel free to pick it if it works or ignore it otherwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to work, done.

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing the comment! Just saw two small things while looking at your changes. Once fixed, I will wait for the build bots and land.

@@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
}

void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) {
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Instruction *inst) {

just saw this. Can this be llvm::Instruction as well. Otherwise we should probably rename inst to value or so?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I try using an instruction here, I get error: invalid conversion from ‘llvm::Value*’ to ‘llvm::Instruction*’ so I`ll rename inst to value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
}

void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) {
auto iface = cast<DisjointFlagInterface>(op);
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
auto disjointInst = cast<llvm::PossiblyDisjointInst>(inst);

nit: missed this before.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you rename inst_disjoint in the other function as well? Sorry I forgot to mark that as well. Feel free to fix this things proactively.

Regarding the llvm::Value. I think you could use llvm::Instruction and it would cast to a llvm::Value right? That maybe preferable over using a value since it is in line with all the other helper functions we have. If it doesn't work the comment on the function needs to be fixed by replacing inst -> value.

Sorry for being a bit pedantic :).

Copy link
Contributor Author

@lfrenot lfrenot Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I forgot about the other function.

As for the Value issue, I may have been unclear, but the error I get if I use Instruction is from the llvm builder

/home/leon_frenot/fork-llvm/build/tools/mlir/include/mlir/Dialect/LLVMIR/LLVMConversions.inc:279:43: error: invalid conversion from ‘llvm::Value*’ to ‘llvm::Instruction*’ [-fpermissive]
  279 |     moduleTranslation.setDisjointFlag(op, inst);
      |                                           ^~~~
      |                                           |
      |                                           llvm::Value*

So the issue comes from builder.CreateOr returning a Value instead of an Instruction.

For now I'll simply fix the comment

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense!

@gysit
Copy link
Contributor

gysit commented Nov 12, 2024

It looks like there is actually some issue with a flang test (see red CI runs). I am not sure if this is new or if this has been there before the last nit changes.

@lfrenot
Copy link
Contributor Author

lfrenot commented Nov 12, 2024

Apparently, it's been here since the first commit: https://buildkite.com/llvm-project/github-pull-requests/builds/118691
I'll look into it, there seems to be an issue with the case in moduleTranslate

@lfrenot
Copy link
Contributor Author

lfrenot commented Nov 13, 2024

This should fix the flang tests @gysit
But I'm not sure it is a proper way to do it

void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) {
auto iface = cast<DisjointFlagInterface>(op);
if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(value))
disjointInst->setIsDisjoint(iface.getIsDisjoint());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What operation produces the error on flang? I would expect this function is only called for llvm.or operation? Shouldn't the generated operation also implement the interface? Or may the llvm instruction fold somehow / be a constant?

I am not agains the solution but it would be nice to understand what the root cause is and then add a comment why the cast may fail.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've looked into it, and yes, what happens is that the instruction gets folded into a constant by CreateOr.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for investigating.

@gysit gysit merged commit 40afff7 into llvm:main Nov 15, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants