Skip to content

[mlir][arith] Add overflow flags to arith.trunci #144863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.

This op supports `nuw`/`nsw` overflow flags which stands stand for
This op supports `nuw`/`nsw` overflow flags which stands for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.
Expand Down Expand Up @@ -321,7 +321,7 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.

This op supports `nuw`/`nsw` overflow flags which stands stand for
This op supports `nuw`/`nsw` overflow flags which stands for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.
Expand Down Expand Up @@ -367,7 +367,7 @@ def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli",
these is required to be the same type. This type may be an integer scalar type,
a vector whose element type is integer, or a tensor of integers.

This op supports `nuw`/`nsw` overflow flags which stands stand for
This op supports `nuw`/`nsw` overflow flags which stands for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.
Expand Down Expand Up @@ -800,7 +800,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
operand is greater or equal than the bitwidth of the first operand, then the
operation returns poison.

This op supports `nuw`/`nsw` overflow flags which stands stand for
This op supports `nuw`/`nsw` overflow flags which stands for
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
`nsw` flags are present, and an unsigned/signed overflow occurs
(respectively), the result is poison.
Expand Down Expand Up @@ -1271,25 +1271,49 @@ def Arith_ScalingExtFOp
// TruncIOp
//===----------------------------------------------------------------------===//

def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
def Arith_TruncIOp : Op<Arith_Dialect, "trunci",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason for dropping the traits from Arith_Op here? Like ElementwiseMappable.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, that was by accident. Sending a fix...

[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<CastOpInterface>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]> {
let summary = "integer truncation operation";
let description = [{
The integer truncation operation takes an integer input of
width M and an integer destination type of width N. The destination
bit-width must be smaller than the input bit-width (N < M).
The top-most (N - M) bits of the input are discarded.

This op supports `nuw`/`nsw` overflow flags which stands for "No Unsigned
Wrap" and "No Signed Wrap", respectively. If the nuw keyword is present,
and any of the truncated bits are non-zero, the result is a poison value.
If the nsw keyword is present, and any of the truncated bits are not the
same as the top bit of the truncation result, the result is a poison value.

Example:

```mlir
// Scalar truncation.
%1 = arith.constant 21 : i5 // %1 is 0b10101
%2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101
%3 = arith.trunci %1 : i5 to i3 // %3 is 0b101

%5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
// Vector truncation.
%4 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>

// Scalar truncation with overflow flags.
%5 = arith.trunci %a overflow<nsw, nuw> : i32 to i16
```
}];

let arguments = (ins
SignlessFixedWidthIntegerLike:$in,
DefaultValuedAttr<Arith_IntegerOverflowAttr,
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags);
let results = (outs SignlessFixedWidthIntegerLike:$out);
let assemblyFormat = [{
$in (`overflow` `` $overflowFlags^)? attr-dict
`:` type($in) `to` type($out)
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
Expand Down
3 changes: 2 additions & 1 deletion mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,8 @@ using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
arith::AttrConverterConstrainedFPToLLVM>;
using TruncIOpLowering =
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
arith::AttrConvertOverflowToLLVM>;
using UIToFPOpLowering =
VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -378,31 +378,31 @@ def TruncationMatchesShiftAmount :

// trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
def TruncIExtSIToExtSI :
Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x), $overflow),
(Arith_ExtSIOp $x),
[(ValueWiderThan $ext, $tr),
(ValueWiderThan $tr, $x)]>;

// trunci(extui(x)) -> extui(x), when only the zero-extension bits are truncated
def TruncIExtUIToExtUI :
Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x)),
Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x), $overflow),
(Arith_ExtUIOp $x),
[(ValueWiderThan $ext, $tr),
(ValueWiderThan $tr, $x)]>;

// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
def TruncIShrSIToTrunciShrUI :
Pat<(Arith_TruncIOp:$tr
(Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0))),
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))),
(Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow),
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;

// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
def TruncIShrUIMulIToMulSIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
(ConstantLikeMatcher AnyAttr:$c0))),
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
(Arith_MulSIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
Expand All @@ -413,7 +413,7 @@ def TruncIShrUIMulIToMulUIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
(ConstantLikeMatcher AnyAttr:$c0))),
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am about to delete the two patterns above #144844 but I guess you will be faster landing since your PR is already approved :).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can rebase now.

(Arith_MulUIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
Expand Down
2 changes: 2 additions & 0 deletions mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,8 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = llvm.trunc %{{.*}} overflow<nsw, nuw> : i64 to i32
%4 = arith.trunci %arg0 overflow<nsw, nuw> : i64 to i32
return
}

Expand Down
2 changes: 2 additions & 0 deletions mlir/test/Dialect/Arith/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1159,5 +1159,7 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = arith.shli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
// CHECK: %{{.*}} = arith.trunci %{{.*}} overflow<nsw, nuw> : i64 to i32
%4 = arith.trunci %arg0 overflow<nsw, nuw> : i64 to i32
return
}
Loading