-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][LLVM] Add disjoint flag #115855
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][LLVM] Add disjoint flag #115855
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: None (lfrenot) ChangesThe implementation is mostly based on the one existing for the exact flag. disjoint means that for each bit, that bit is zero in at least one of the inputs. This allows the Or to be treated as an Add since no carry can occur from any bit. If the disjoint keyword is present, the result value of the or is a poison value if both inputs have a one in the same bit position. For vectors, only the element containing the bit is poison. Full diff: https://github.com/llvm/llvm-project/pull/115855.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index 352e2ec91bdbea..2699a0ed14d4b3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
];
}
+def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
+ let description = [{
+ This interface defines an LLVM operation with an disjoint flag and
+ provides a uniform API for accessing it.
+ }];
+
+ let cppNamespace = "::mlir::LLVM";
+
+ let methods = [
+ InterfaceMethod<[{
+ Get the disjoint flag for the operation.
+ }], "bool", "getIsDisjoint", (ins), [{}], [{
+ return $_op.getProperties().isDisjoint;
+ }]>,
+ InterfaceMethod<[{
+ Set the disjoint flag for the operation.
+ }], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
+ $_op.getProperties().isDisjoint = isDisjoint;
+ }]>,
+ StaticInterfaceMethod<[{
+ Get the attribute name of the isDisjoint property.
+ }], "StringRef", "getIsDisjointName", (ins), [{}], [{
+ return "isDisjoint";
+ }]>,
+ ];
+}
+
def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
let description = [{
This interface defines an LLVM operation with an nneg flag and
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 34f3e4b33b8295..3a3311d8469dfd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
}
+class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
+ list<Trait> traits = []> :
+ LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
+ !listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
+ let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));
+
+ string mlirBuilder = [{
+ auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
+ moduleImport.setDisjointFlag(inst, op);
+ $res = op;
+ }];
+ let assemblyFormat = [{
+ (`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
+ }];
+ string llvmBuilder =
+ [{auto inst = builder.Create}] # instName #
+ [{($lhs, $rhs, /*Name=*/"");
+ moduleTranslation.setDisjointFlag(op, inst);
+ $res = inst;}];
+}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
-def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
+def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
let hasFolder = 1;
}
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 30164843f63675..eea0647895b01b 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -192,6 +192,11 @@ class ModuleImport {
/// implement the exact flag interface.
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
+ /// Sets the disjoint flag attribute for the imported operation `op`
+ /// given the original instruction `inst`. Asserts if the operation does
+ /// not implement the disjoint flag interface.
+ void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;
+
/// Sets the nneg flag attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the nneg flag interface.
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
index ffeeeae57ae952..0b14a665337d58 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
@@ -167,6 +167,11 @@ class ModuleTranslation {
/// attribute.
void setLoopMetadata(Operation *op, llvm::Instruction *inst);
+ /// Sets the disjoint flag attribute for the exported instruction `inst`
+ /// given the original operation `op`. Asserts if the operation does
+ /// not implement the disjoint flag interface.
+ void setDisjointFlag(Operation *op, llvm::Value *inst);
+
/// Converts the type from MLIR LLVM dialect to LLVM.
llvm::Type *convertType(Type type);
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 71d88d3a62f2b9..5592cc7f5df8f1 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -689,6 +689,14 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
iface.setIsExact(inst->isExact());
}
+void ModuleImport::setDisjointFlag(llvm::Instruction *inst, Operation *op) const {
+ auto iface = cast<DisjointFlagInterface>(op);
+
+ auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
+
+ iface.setIsDisjoint(inst_disjoint->isDisjoint());
+}
+
void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
auto iface = cast<NonNegFlagInterface>(op);
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ceb8ba3b33818b..bbf567f8cf8d4c 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1898,6 +1898,14 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
}
+void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) {
+ auto iface = cast<DisjointFlagInterface>(op);
+
+ auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
+
+ inst_disjoint->setIsDisjoint(iface.getIsDisjoint());
+}
+
llvm::Type *ModuleTranslation::convertType(Type type) {
return typeTranslator.translateType(type);
}
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index aa558bad2299ce..06f7b2d9f586fd 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
+// Integer disjoint flag.
+// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
+ %or_flag = llvm.or disjoint %arg0, %arg0 : i32
+
// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
diff --git a/mlir/test/Target/LLVMIR/Import/disjoint.ll b/mlir/test/Target/LLVMIR/Import/disjoint.ll
new file mode 100644
index 00000000000000..36091c09043525
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/Import/disjoint.ll
@@ -0,0 +1,8 @@
+; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
+
+; CHECK-LABEL: @disjointflag_inst
+define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
+ ; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
+ %1 = or disjoint i64 %arg1, %arg2
+ ret void
+}
diff --git a/mlir/test/Target/LLVMIR/disjoint.mlir b/mlir/test/Target/LLVMIR/disjoint.mlir
new file mode 100644
index 00000000000000..1f5a42e608ba40
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/disjoint.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define void @disjointflag_func
+llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
+ // CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
+ %0 = llvm.or disjoint %arg0, %arg1 : i64
+ llvm.return
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
b1ae98a
to
ce0b93b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Amazing how many flags LLVM has :).
Thanks LGTM modulo nits.
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> { | |||
]; | |||
} | |||
|
|||
def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> { | |||
let description = [{ | |||
This interface defines an LLVM operation with an disjoint flag and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This interface defines an LLVM operation with an disjoint flag and | |
This interface defines an LLVM operation with a disjoint flag and |
nit:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
auto iface = cast<DisjointFlagInterface>(op); | ||
|
||
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst); | ||
|
||
iface.setIsDisjoint(inst_disjoint->isDisjoint()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto iface = cast<DisjointFlagInterface>(op); | |
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst); | |
iface.setIsDisjoint(inst_disjoint->isDisjoint()); | |
auto iface = cast<DisjointFlagInterface>(op); | |
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst); | |
iface.setIsDisjoint(inst_disjoint->isDisjoint()); |
nit: I think it is fine to drop the newlines here. At the very least the one between the two cast instructions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
auto iface = cast<DisjointFlagInterface>(op); | ||
|
||
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst); | ||
|
||
inst_disjoint->setIsDisjoint(iface.getIsDisjoint()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto iface = cast<DisjointFlagInterface>(op); | |
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst); | |
inst_disjoint->setIsDisjoint(iface.getIsDisjoint()); | |
auto iface = cast<DisjointFlagInterface>(op); | |
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst); | |
inst_disjoint->setIsDisjoint(iface.getIsDisjoint()); |
nit: here as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
string llvmBuilder = | ||
[{auto inst = builder.Create}] # instName # | ||
[{($lhs, $rhs, /*Name=*/""); | ||
moduleTranslation.setDisjointFlag(op, inst); | ||
$res = inst;}]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
string llvmBuilder = | |
[{auto inst = builder.Create}] # instName # | |
[{($lhs, $rhs, /*Name=*/""); | |
moduleTranslation.setDisjointFlag(op, inst); | |
$res = inst;}]; | |
string llvmBuilder = [{ | |
auto inst = builder.Create}] # instName # [{($lhs, $rhs, /*Name=*/""); | |
moduleTranslation.setDisjointFlag(op, inst); | |
$res = inst; | |
}]; |
nit: an attempt to format tablegen a bit nicer, feel free to pick it if it works or ignore it otherwise.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to work, done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing the comment! Just saw two small things while looking at your changes. Once fixed, I will wait for the build bots and land.
@@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op, | |||
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD); | |||
} | |||
|
|||
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) { | |
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Instruction *inst) { |
just saw this. Can this be llvm::Instruction as well. Otherwise we should probably rename inst to value or so?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I try using an instruction here, I get error: invalid conversion from ‘llvm::Value*’ to ‘llvm::Instruction*’
so I`ll rename inst to value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op, | |||
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD); | |||
} | |||
|
|||
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) { | |||
auto iface = cast<DisjointFlagInterface>(op); | |||
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst); | |
auto disjointInst = cast<llvm::PossiblyDisjointInst>(inst); |
nit: missed this before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you rename inst_disjoint in the other function as well? Sorry I forgot to mark that as well. Feel free to fix this things proactively.
Regarding the llvm::Value. I think you could use llvm::Instruction and it would cast to a llvm::Value right? That maybe preferable over using a value since it is in line with all the other helper functions we have. If it doesn't work the comment on the function needs to be fixed by replacing inst -> value.
Sorry for being a bit pedantic :).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I forgot about the other function.
As for the Value issue, I may have been unclear, but the error I get if I use Instruction is from the llvm builder
/home/leon_frenot/fork-llvm/build/tools/mlir/include/mlir/Dialect/LLVMIR/LLVMConversions.inc:279:43: error: invalid conversion from ‘llvm::Value*’ to ‘llvm::Instruction*’ [-fpermissive]
279 | moduleTranslation.setDisjointFlag(op, inst);
| ^~~~
| |
| llvm::Value*
So the issue comes from builder.CreateOr returning a Value instead of an Instruction.
For now I'll simply fix the comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense!
It looks like there is actually some issue with a flang test (see red CI runs). I am not sure if this is new or if this has been there before the last nit changes. |
Apparently, it's been here since the first commit: https://buildkite.com/llvm-project/github-pull-requests/builds/118691 |
This should fix the flang tests @gysit |
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) { | ||
auto iface = cast<DisjointFlagInterface>(op); | ||
if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(value)) | ||
disjointInst->setIsDisjoint(iface.getIsDisjoint()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What operation produces the error on flang? I would expect this function is only called for llvm.or operation? Shouldn't the generated operation also implement the interface? Or may the llvm instruction fold somehow / be a constant?
I am not agains the solution but it would be nice to understand what the root cause is and then add a comment why the cast may fail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've looked into it, and yes, what happens is that the instruction gets folded into a constant by CreateOr.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for investigating.
The implementation is mostly based on the one existing for the exact flag.
disjoint means that for each bit, that bit is zero in at least one of the inputs. This allows the Or to be treated as an Add since no carry can occur from any bit. If the disjoint keyword is present, the result value of the or is a poison value if both inputs have a one in the same bit position. For vectors, only the element containing the bit is poison.