Skip to content

[mlir][LLVM] Add disjoint flag #115855

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
];
}

def DisjointFlagInterface : OpInterface<"DisjointFlagInterface"> {
let description = [{
This interface defines an LLVM operation with a disjoint flag and
provides a uniform API for accessing it.
}];

let cppNamespace = "::mlir::LLVM";

let methods = [
InterfaceMethod<[{
Get the disjoint flag for the operation.
}], "bool", "getIsDisjoint", (ins), [{}], [{
return $_op.getProperties().isDisjoint;
}]>,
InterfaceMethod<[{
Set the disjoint flag for the operation.
}], "void", "setIsDisjoint", (ins "bool":$isDisjoint), [{}], [{
$_op.getProperties().isDisjoint = isDisjoint;
}]>,
StaticInterfaceMethod<[{
Get the attribute name of the isDisjoint property.
}], "StringRef", "getIsDisjointName", (ins), [{}], [{
return "isDisjoint";
}]>,
];
}

def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
let description = [{
This interface defines an LLVM operation with an nneg flag and
Expand Down
22 changes: 21 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ class LLVM_IntArithmeticOpWithExactFlag<string mnemonic, string instName,
"$res = builder.Create" # instName #
"($lhs, $rhs, /*Name=*/\"\", op.getIsExact());";
}
class LLVM_IntArithmeticOpWithDisjointFlag<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<DisjointFlagInterface>], traits)> {
let arguments = !con(commonArgs, (ins UnitAttr:$isDisjoint));

string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
moduleImport.setDisjointFlag(inst, op);
$res = op;
}];
let assemblyFormat = [{
(`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)
}];
string llvmBuilder = [{
auto inst = builder.Create}] # instName # [{($lhs, $rhs, /*Name=*/"");
moduleTranslation.setDisjointFlag(op, inst);
$res = inst;
}];
}
class LLVM_FloatArithmeticOp<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyFloat, mnemonic, instName,
Expand Down Expand Up @@ -138,7 +158,7 @@ def LLVM_SDivOp : LLVM_IntArithmeticOpWithExactFlag<"sdiv", "SDiv">;
def LLVM_URemOp : LLVM_IntArithmeticOp<"urem", "URem">;
def LLVM_SRemOp : LLVM_IntArithmeticOp<"srem", "SRem">;
def LLVM_AndOp : LLVM_IntArithmeticOp<"and", "And">;
def LLVM_OrOp : LLVM_IntArithmeticOp<"or", "Or"> {
def LLVM_OrOp : LLVM_IntArithmeticOpWithDisjointFlag<"or", "Or"> {
let hasFolder = 1;
}
def LLVM_XOrOp : LLVM_IntArithmeticOp<"xor", "Xor">;
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ class ModuleImport {
/// implement the exact flag interface.
void setExactFlag(llvm::Instruction *inst, Operation *op) const;

/// Sets the disjoint flag attribute for the imported operation `op`
/// given the original instruction `inst`. Asserts if the operation does
/// not implement the disjoint flag interface.
void setDisjointFlag(llvm::Instruction *inst, Operation *op) const;

/// Sets the nneg flag attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
/// implement the nneg flag interface.
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ class ModuleTranslation {
/// attribute.
void setLoopMetadata(Operation *op, llvm::Instruction *inst);

/// Sets the disjoint flag attribute for the exported instruction `value`
/// given the original operation `op`. Asserts if the operation does
/// not implement the disjoint flag interface, and asserts if the value
/// is an instruction that implements the disjoint flag.
void setDisjointFlag(Operation *op, llvm::Value *value);

/// Converts the type from MLIR LLVM dialect to LLVM.
llvm::Type *convertType(Type type);

Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,14 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
iface.setIsExact(inst->isExact());
}

void ModuleImport::setDisjointFlag(llvm::Instruction *inst,
Operation *op) const {
auto iface = cast<DisjointFlagInterface>(op);
auto instDisjoint = cast<llvm::PossiblyDisjointInst>(inst);

iface.setIsDisjoint(instDisjoint->isDisjoint());
}

void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
auto iface = cast<NonNegFlagInterface>(op);

Expand Down
7 changes: 7 additions & 0 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,13 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
}

void ModuleTranslation::setDisjointFlag(Operation *op, llvm::Value *value) {
auto iface = cast<DisjointFlagInterface>(op);
// We do a dyn_cast here in case the value got folded into a constant.
if (auto disjointInst = dyn_cast<llvm::PossiblyDisjointInst>(value))
disjointInst->setIsDisjoint(iface.getIsDisjoint());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What operation produces the error on flang? I would expect this function is only called for llvm.or operation? Shouldn't the generated operation also implement the interface? Or may the llvm instruction fold somehow / be a constant?

I am not agains the solution but it would be nice to understand what the root cause is and then add a comment why the cast may fail.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've looked into it, and yes, what happens is that the instruction gets folded into a constant by CreateOr.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for investigating.

}

llvm::Type *ModuleTranslation::convertType(Type type) {
return typeTranslator.translateType(type);
}
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ func.func @ops(%arg0: i32, %arg1: f32,
%ashr_flag = llvm.ashr exact %arg0, %arg0 : i32
%lshr_flag = llvm.lshr exact %arg0, %arg0 : i32

// Integer disjoint flag.
// CHECK: {{.*}} = llvm.or disjoint %[[I32]], %[[I32]] : i32
%or_flag = llvm.or disjoint %arg0, %arg0 : i32

// Floating point binary operations.
//
// CHECK: {{.*}} = llvm.fadd %[[FLOAT]], %[[FLOAT]] : f32
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Target/LLVMIR/Import/disjoint.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s

; CHECK-LABEL: @disjointflag_inst
define void @disjointflag_inst(i64 %arg1, i64 %arg2) {
; CHECK: llvm.or disjoint %{{.*}}, %{{.*}} : i64
%1 = or disjoint i64 %arg1, %arg2
ret void
}
8 changes: 8 additions & 0 deletions mlir/test/Target/LLVMIR/disjoint.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

// CHECK-LABEL: define void @disjointflag_func
llvm.func @disjointflag_func(%arg0: i64, %arg1: i64) {
// CHECK: %{{.*}} = or disjoint i64 %{{.*}}, %{{.*}}
%0 = llvm.or disjoint %arg0, %arg1 : i64
llvm.return
}
Loading