Skip to content

Commit ce0b93b

Browse files
committed
Add disjoint flag
1 parent e385e0d commit ce0b93b

File tree

9 files changed

+95
-1
lines changed

9 files changed

+95
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
114114
];
115115
}
116116

117+
def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
118+
let description = [{
119+
This interface defines an LLVM operation with an disjoint flag and
120+
provides a uniform API for accessing it.
121+
}];
122+
123+
let cppNamespace = "::mlir::LLVM";
124+
125+
let methods = [
126+
InterfaceMethod<[{
127+
Get the disjoint flag for the operation.
128+
}], "bool", "getIsDisjoint", (ins), [{}], [{
129+
return $_op.getProperties().isDisjoint;
130+
}]>,
131+
InterfaceMethod<[{
132+
Set the disjoint flag for the operation.
133+
}], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
134+
$_op.getProperties().isDisjoint = isDisjoint;
135+
}]>,
136+
StaticInterfaceMethod<[{
137+
Get the attribute name of the isDisjoint property.
138+
}], "StringRef", "getIsDisjointName", (ins), [{}], [{
139+
return "isDisjoint";
140+
}]>,
141+
];
142+
}
143+
117144
def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
118145
let description = [{
119146
This interface defines an LLVM operation with an nneg flag and

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
9393
"$res = builder.Create" # instName #
9494
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
9595
}
96+
class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
97+
list<Trait> traits = []> :
98+
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
99+
!listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
100+
let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));
101+
102+
string mlirBuilder = [{
103+
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
104+
moduleImport.setDisjointFlag(inst, op);
105+
$res = op;
106+
}];
107+
let assemblyFormat = [{
108+
(`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
109+
}];
110+
string llvmBuilder =
111+
[{auto inst = builder.Create}] # instName #
112+
[{($lhs, $rhs, /*Name=*/"");
113+
moduleTranslation.setDisjointFlag(op, inst);
114+
$res = inst;}];
115+
}
96116
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
97117
list<Trait> traits = []> :
98118
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
@@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
138158
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
139159
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
140160
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
141-
def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
161+
def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
142162
let hasFolder = 1;
143163
}
144164
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ class ModuleImport {
192192
/// implement the exact flag interface.
193193
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
194194

195+
/// Sets the disjoint flag attribute for the imported operation `op`
196+
/// given the original instruction `inst`. Asserts if the operation does
197+
/// not implement the disjoint flag interface.
198+
void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;
199+
195200
/// Sets the nneg flag attribute for the imported operation `op` given
196201
/// the original instruction `inst`. Asserts if the operation does not
197202
/// implement the nneg flag interface.

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ class ModuleTranslation {
167167
/// attribute.
168168
void setLoopMetadata(Operation *op, llvm::Instruction *inst);
169169

170+
/// Sets the disjoint flag attribute for the exported instruction `inst`
171+
/// given the original operation `op`. Asserts if the operation does
172+
/// not implement the disjoint flag interface.
173+
void setDisjointFlag(Operation *op, llvm::Value *inst);
174+
170175
/// Converts the type from MLIR LLVM dialect to LLVM.
171176
llvm::Type *convertType(Type type);
172177

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,15 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
689689
iface.setIsExact(inst->isExact());
690690
}
691691

692+
void ModuleImport::setDisjointFlag(llvm::Instruction *inst,
693+
Operation *op) const {
694+
auto iface = cast<DisjointFlagInterface>(op);
695+
696+
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
697+
698+
iface.setIsDisjoint(inst_disjoint->isDisjoint());
699+
}
700+
692701
void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
693702
auto iface = cast<NonNegFlagInterface>(op);
694703

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,6 +1898,14 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
18981898
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
18991899
}
19001900

1901+
void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *inst) {
1902+
auto iface = cast<DisjointFlagInterface>(op);
1903+
1904+
auto inst_disjoint = cast<llvm::PossiblyDisjointInst>(inst);
1905+
1906+
inst_disjoint->setIsDisjoint(iface.getIsDisjoint());
1907+
}
1908+
19011909
llvm::Type *ModuleTranslation::convertType(Type type) {
19021910
return typeTranslator.translateType(type);
19031911
}

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
5959
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
6060
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32
6161

62+
// Integer disjoint flag.
63+
// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
64+
%or_flag = llvm.or disjoint %arg0, %arg0 : i32
65+
6266
// Floating point binary operations.
6367
//
6468
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
2+
3+
; CHECK-LABEL: @disjointflag_inst
4+
define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
5+
; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
6+
%1 = or disjoint i64 %arg1, %arg2
7+
ret void
8+
}

mlir/test/Target/LLVMIR/disjoint.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: define void @disjointflag_func
4+
llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
5+
// CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
6+
%0 = llvm.or disjoint %arg0, %arg1 : i64
7+
llvm.return
8+
}

0 commit comments

Comments
 (0)