Skip to content

Commit 40afff7

Browse files
authored
[mlir][LLVM] Add disjoint flag (#115855)
The implementation is mostly based on the one existing for the exact flag. disjoint means that for each bit, that bit is zero in at least one of the inputs. This allows the Or to be treated as an Add since no carry can occur from any bit. If the disjoint keyword is present, the result value of the or is a [poison value](https://llvm.org/docs/LangRef.html#poisonvalues) if both inputs have a one in the same bit position. For vectors, only the element containing the bit is poison.
1 parent 6d05831 commit 40afff7

File tree

9 files changed

+94
-1
lines changed

9 files changed

+94
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
114114
];
115115
}
116116

117+
def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
118+
let description = [{
119+
This interface defines an LLVM operation with a disjoint flag and
120+
provides a uniform API for accessing it.
121+
}];
122+
123+
let cppNamespace = "::mlir::LLVM";
124+
125+
let methods = [
126+
InterfaceMethod<[{
127+
Get the disjoint flag for the operation.
128+
}], "bool", "getIsDisjoint", (ins), [{}], [{
129+
return $_op.getProperties().isDisjoint;
130+
}]>,
131+
InterfaceMethod<[{
132+
Set the disjoint flag for the operation.
133+
}], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
134+
$_op.getProperties().isDisjoint = isDisjoint;
135+
}]>,
136+
StaticInterfaceMethod<[{
137+
Get the attribute name of the isDisjoint property.
138+
}], "StringRef", "getIsDisjointName", (ins), [{}], [{
139+
return "isDisjoint";
140+
}]>,
141+
];
142+
}
143+
117144
def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
118145
let description = [{
119146
This interface defines an LLVM operation with an nneg flag and

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
9393
"$res = builder.Create" # instName #
9494
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
9595
}
96+
class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
97+
list<Trait> traits = []> :
98+
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
99+
!listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
100+
let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));
101+
102+
string mlirBuilder = [{
103+
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
104+
moduleImport.setDisjointFlag(inst, op);
105+
$res = op;
106+
}];
107+
let assemblyFormat = [{
108+
(`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
109+
}];
110+
string llvmBuilder = [{
111+
auto inst = builder.Create}] # instName # [{($lhs, $rhs, /*Name=*/"");
112+
moduleTranslation.setDisjointFlag(op, inst);
113+
$res = inst;
114+
}];
115+
}
96116
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
97117
list<Trait> traits = []> :
98118
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
138158
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
139159
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
140160
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
141-
def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
161+
def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
142162
let hasFolder = 1;
143163
}
144164
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ class ModuleImport {
192192
/// implement the exact flag interface.
193193
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
194194

195+
/// Sets the disjoint flag attribute for the imported operation `op`
196+
/// given the original instruction `inst`. Asserts if the operation does
197+
/// not implement the disjoint flag interface.
198+
void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;
199+
195200
/// Sets the nneg flag attribute for the imported operation `op` given
196201
/// the original instruction `inst`. Asserts if the operation does not
197202
/// implement the nneg flag interface.

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,12 @@ class ModuleTranslation {
167167
/// attribute.
168168
void setLoopMetadata(Operation *op, llvm::Instruction *inst);
169169

170+
/// Sets the disjoint flag attribute for the exported instruction `value`
171+
/// given the original operation `op`. Asserts if the operation does
172+
/// not implement the disjoint flag interface, and asserts if the value
173+
/// is an instruction that implements the disjoint flag.
174+
void setDisjointFlag(Operation *op, llvm::Value *value);
175+
170176
/// Converts the type from MLIR LLVM dialect to LLVM.
171177
llvm::Type *convertType(Type type);
172178

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,14 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
689689
iface.setIsExact(inst->isExact());
690690
}
691691

692+
void ModuleImport::setDisjointFlag(llvm::Instruction *inst,
693+
Operation *op) const {
694+
auto iface = cast<DisjointFlagInterface>(op);
695+
auto instDisjoint = cast<llvm::PossiblyDisjointInst>(inst);
696+
697+
iface.setIsDisjoint(instDisjoint->isDisjoint());
698+
}
699+
692700
void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
693701
auto iface = cast<NonNegFlagInterface>(op);
694702

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
18981898
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
18991899
}
19001900

1901+
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) {
1902+
auto iface = cast<DisjointFlagInterface>(op);
1903+
// We do a dyn_cast here in case the value got folded into a constant.
1904+
if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(value))
1905+
disjointInst->setIsDisjoint(iface.getIsDisjoint());
1906+
}
1907+
19011908
llvm::Type *ModuleTranslation::convertType(Type type) {
19021909
return typeTranslator.translateType(type);
19031910
}

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
5959
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
6060
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
6161

62+
// Integer disjoint flag.
63+
// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
64+
%or_flag = llvm.or disjoint %arg0, %arg0 : i32
65+
6266
// Floating point binary operations.
6367
//
6468
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
2+
3+
; CHECK-LABEL: @disjointflag_inst
4+
define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
5+
; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
6+
%1 = or disjoint i64 %arg1, %arg2
7+
ret void
8+
}

mlir/test/Target/LLVMIR/disjoint.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: define void @disjointflag_func
4+
llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
5+
// CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
6+
%0 = llvm.or disjoint %arg0, %arg1 : i64
7+
llvm.return
8+
}

0 commit comments

Comments
 (0)