Skip to content

Commit 89aaf2c

Browse files
authored
[mlir][LLVM] Add nneg flag (#115498)
This implementation is based on the existing one for the exact flag. If the nneg flag is set and the argument is negative, the result is a poison value.
1 parent 5e7662e commit 89aaf2c

File tree

7 files changed

+90
-2
lines changed

7 files changed

+90
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,33 @@ def ExactFlagInterface : OpInterface<"ExactFlagInterface"> {
114114
];
115115
}
116116

117+
def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
118+
let description = [{
119+
This interface defines an LLVM operation with an nneg flag and
120+
provides a uniform API for accessing it.
121+
}];
122+
123+
let cppNamespace = "::mlir::LLVM";
124+
125+
let methods = [
126+
InterfaceMethod<[{
127+
Get the nneg flag for the operation.
128+
}], "bool", "getNonNeg", (ins), [{}], [{
129+
return $_op.getProperties().nonNeg;
130+
}]>,
131+
InterfaceMethod<[{
132+
Set the nneg flag for the operation.
133+
}], "void", "setNonNeg", (ins "bool":$nonNeg), [{}], [{
134+
$_op.getProperties().nonNeg = nonNeg;
135+
}]>,
136+
StaticInterfaceMethod<[{
137+
Get the attribute name of the nonNeg property.
138+
}], "StringRef", "getNonNegName", (ins), [{}], [{
139+
return "nonNeg";
140+
}]>,
141+
];
142+
}
143+
117144
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
118145
let description = [{
119146
An interface for operations that can carry branch weights metadata. It

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,23 @@ class LLVM_CastOp<string mnemonic, string instName, Type type,
508508
$_location, $_resultType, $arg);
509509
}];
510510
}
511+
class LLVM_CastOpWithNNegFlag<string mnemonic, string instName, Type type,
512+
Type resultType, list<Trait> traits = []> :
513+
LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<NonNegFlagInterface>], traits)>,
514+
LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.getNonNeg());"> {
515+
let arguments = (ins type:$arg, UnitAttr:$nonNeg);
516+
let results = (outs resultType:$res);
517+
let builders = [LLVM_OneResultOpBuilder];
518+
let assemblyFormat = "(`nneg` $nonNeg^)? $arg attr-dict `:` type($arg) `to` type($res)";
519+
string llvmInstName = instName;
520+
string mlirBuilder = [{
521+
auto op = $_builder.create<$_qualCppClassName>(
522+
$_location, $_resultType, $arg);
523+
moduleImport.setNonNegFlag(inst, op);
524+
$res = op;
525+
}];
526+
}
527+
511528
def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
512529
LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
513530
let hasFolder = 1;
@@ -531,7 +548,7 @@ def LLVM_SExtOp : LLVM_CastOp<"sext", "SExt",
531548
LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
532549
let hasVerifier = 1;
533550
}
534-
def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
551+
def LLVM_ZExtOp : LLVM_CastOpWithNNegFlag<"zext", "ZExt",
535552
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
536553
LLVM_ScalarOrVectorOf<AnySignlessInteger>> {
537554
let hasFolder = 1;
@@ -543,7 +560,7 @@ def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
543560
def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "SIToFP",
544561
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
545562
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
546-
def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "UIToFP",
563+
def LLVM_UIToFPOp : LLVM_CastOpWithNNegFlag<"uitofp", "UIToFP",
547564
LLVM_ScalarOrVectorOf<AnySignlessInteger>,
548565
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
549566
def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "FPToSI",

mlir/include/mlir/Target/LLVMIR/ModuleImport.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ class ModuleImport {
192192
/// implement the exact flag interface.
193193
void setExactFlag(llvm::Instruction *inst, Operation *op) const;
194194

195+
/// Sets the nneg flag attribute for the imported operation `op` given
196+
/// the original instruction `inst`. Asserts if the operation does not
197+
/// implement the nneg flag interface.
198+
void setNonNegFlag(llvm::Instruction *inst, Operation *op) const;
199+
195200
/// Sets the fastmath flags attribute for the imported operation `op` given
196201
/// the original instruction `inst`. Asserts if the operation does not
197202
/// implement the fastmath interface.

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,12 @@ void ModuleImport::setExactFlag(llvm::Instruction *inst, Operation *op) const {
689689
iface.setIsExact(inst->isExact());
690690
}
691691

692+
void ModuleImport::setNonNegFlag(llvm::Instruction *inst, Operation *op) const {
693+
auto iface = cast<NonNegFlagInterface>(op);
694+
695+
iface.setNonNeg(inst->hasNonNeg());
696+
}
697+
692698
void ModuleImport::setFastmathFlagsAttr(llvm::Instruction *inst,
693699
Operation *op) const {
694700
auto iface = cast<FastmathFlagsInterface>(op);

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,19 @@ func.func @casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
325325
llvm.return
326326
}
327327

328+
// CHECK-LABEL: @nneg_casts
329+
// CHECK-SAME: (%[[I32:.*]]: i32, %[[I64:.*]]: i64, %[[V4I32:.*]]: vector<4xi32>, %[[V4I64:.*]]: vector<4xi64>, %[[PTR:.*]]: !llvm.ptr)
330+
func.func @nneg_casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
331+
%arg3: vector<4xi64>, %arg4: !llvm.ptr) {
332+
// CHECK: = llvm.zext nneg %[[I32]] : i32 to i64
333+
%0 = llvm.zext nneg %arg0 : i32 to i64
334+
// CHECK: = llvm.zext nneg %[[V4I32]] : vector<4xi32> to vector<4xi64>
335+
%4 = llvm.zext nneg %arg2 : vector<4xi32> to vector<4xi64>
336+
// CHECK: = llvm.uitofp nneg %[[I32]] : i32 to f32
337+
%7 = llvm.uitofp nneg %arg0 : i32 to f32
338+
llvm.return
339+
}
340+
328341
// CHECK-LABEL: @vect
329342
func.func @vect(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32, %arg3: !llvm.vec<2 x ptr>) {
330343
// CHECK: = llvm.extractelement {{.*}} : vector<4xf32>
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s
2+
3+
; CHECK-LABEL: @nnegflag_inst
4+
define void @nnegflag_inst(i32 %arg1) {
5+
; CHECK: llvm.zext nneg %{{.*}} : i32 to i64
6+
%1 = zext nneg i32 %arg1 to i64
7+
; CHECK: llvm.uitofp nneg %{{.*}} : i32 to f32
8+
%2 = uitofp nneg i32 %arg1 to float
9+
ret void
10+
}

mlir/test/Target/LLVMIR/nneg.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
// CHECK-LABEL: define void @nnegflag_func
4+
llvm.func @nnegflag_func(%arg0: i32) {
5+
// CHECK: %{{.*}} = zext nneg i32 %{{.*}} to i64
6+
%0 = llvm.zext nneg %arg0 : i32 to i64
7+
// CHECK: %{{.*}} = uitofp nneg i32 %{{.*}} to float
8+
%1 = llvm.uitofp nneg %arg0 : i32 to f32
9+
llvm.return
10+
}

0 commit comments

Comments
 (0)