Skip to content

Commit e33f13b

Browse files
[mlir][arith] Add overflow flags to arith.trunci (#144863)
LLVM already supports overflow flags on `llvm.trunc` for a while. This commit adds support for these flags to `arith.trunci`.
1 parent 046e2f5 commit e33f13b

File tree

5 files changed

+42
-13
lines changed

5 files changed

+42
-13
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
226226
these is required to be the same type. This type may be an integer scalar type,
227227
a vector whose element type is integer, or a tensor of integers.
228228

229-
This op supports `nuw`/`nsw` overflow flags which stands stand for
229+
This op supports `nuw`/`nsw` overflow flags which stands for
230230
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
231231
`nsw` flags are present, and an unsigned/signed overflow occurs
232232
(respectively), the result is poison.
@@ -321,7 +321,7 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
321321
these is required to be the same type. This type may be an integer scalar type,
322322
a vector whose element type is integer, or a tensor of integers.
323323

324-
This op supports `nuw`/`nsw` overflow flags which stands stand for
324+
This op supports `nuw`/`nsw` overflow flags which stands for
325325
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
326326
`nsw` flags are present, and an unsigned/signed overflow occurs
327327
(respectively), the result is poison.
@@ -367,7 +367,7 @@ def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli",
367367
these is required to be the same type. This type may be an integer scalar type,
368368
a vector whose element type is integer, or a tensor of integers.
369369

370-
This op supports `nuw`/`nsw` overflow flags which stands stand for
370+
This op supports `nuw`/`nsw` overflow flags which stands for
371371
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
372372
`nsw` flags are present, and an unsigned/signed overflow occurs
373373
(respectively), the result is poison.
@@ -800,7 +800,7 @@ def Arith_ShLIOp : Arith_IntBinaryOpWithOverflowFlags<"shli"> {
800800
operand is greater or equal than the bitwidth of the first operand, then the
801801
operation returns poison.
802802

803-
This op supports `nuw`/`nsw` overflow flags which stands stand for
803+
This op supports `nuw`/`nsw` overflow flags which stands for
804804
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
805805
`nsw` flags are present, and an unsigned/signed overflow occurs
806806
(respectively), the result is poison.
@@ -1271,25 +1271,49 @@ def Arith_ScalingExtFOp
12711271
// TruncIOp
12721272
//===----------------------------------------------------------------------===//
12731273

1274-
def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
1274+
def Arith_TruncIOp : Op<Arith_Dialect, "trunci",
1275+
[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
1276+
DeclareOpInterfaceMethods<CastOpInterface>,
1277+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>,
1278+
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]> {
12751279
let summary = "integer truncation operation";
12761280
let description = [{
12771281
The integer truncation operation takes an integer input of
12781282
width M and an integer destination type of width N. The destination
12791283
bit-width must be smaller than the input bit-width (N < M).
12801284
The top-most (N - M) bits of the input are discarded.
12811285

1286+
This op supports `nuw`/`nsw` overflow flags which stands for "No Unsigned
1287+
Wrap" and "No Signed Wrap", respectively. If the nuw keyword is present,
1288+
and any of the truncated bits are non-zero, the result is a poison value.
1289+
If the nsw keyword is present, and any of the truncated bits are not the
1290+
same as the top bit of the truncation result, the result is a poison value.
1291+
12821292
Example:
12831293

12841294
```mlir
1295+
// Scalar truncation.
12851296
%1 = arith.constant 21 : i5 // %1 is 0b10101
12861297
%2 = arith.trunci %1 : i5 to i4 // %2 is 0b0101
12871298
%3 = arith.trunci %1 : i5 to i3 // %3 is 0b101
12881299

1289-
%5 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
1300+
// Vector truncation.
1301+
%4 = arith.trunci %0 : vector<2 x i32> to vector<2 x i16>
1302+
1303+
// Scalar truncation with overflow flags.
1304+
%5 = arith.trunci %a overflow<nsw, nuw> : i32 to i16
12901305
```
12911306
}];
12921307

1308+
let arguments = (ins
1309+
SignlessFixedWidthIntegerLike:$in,
1310+
DefaultValuedAttr<Arith_IntegerOverflowAttr,
1311+
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags);
1312+
let results = (outs SignlessFixedWidthIntegerLike:$out);
1313+
let assemblyFormat = [{
1314+
$in (`overflow` `` $overflowFlags^)? attr-dict
1315+
`:` type($in) `to` type($out)
1316+
}];
12931317
let hasFolder = 1;
12941318
let hasCanonicalizer = 1;
12951319
let hasVerifier = 1;

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern<
163163
arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true,
164164
arith::AttrConverterConstrainedFPToLLVM>;
165165
using TruncIOpLowering =
166-
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp>;
166+
VectorConvertToLLVMPattern<arith::TruncIOp, LLVM::TruncOp,
167+
arith::AttrConvertOverflowToLLVM>;
167168
using UIToFPOpLowering =
168169
VectorConvertToLLVMPattern<arith::UIToFPOp, LLVM::UIToFPOp>;
169170
using XOrIOpLowering = VectorConvertToLLVMPattern<arith::XOrIOp, LLVM::XOrOp>;

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -378,31 +378,31 @@ def TruncationMatchesShiftAmount :
378378

379379
// trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
380380
def TruncIExtSIToExtSI :
381-
Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
381+
Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x), $overflow),
382382
(Arith_ExtSIOp $x),
383383
[(ValueWiderThan $ext, $tr),
384384
(ValueWiderThan $tr, $x)]>;
385385

386386
// trunci(extui(x)) -> extui(x), when only the zero-extension bits are truncated
387387
def TruncIExtUIToExtUI :
388-
Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x)),
388+
Pat<(Arith_TruncIOp:$tr (Arith_ExtUIOp:$ext $x), $overflow),
389389
(Arith_ExtUIOp $x),
390390
[(ValueWiderThan $ext, $tr),
391391
(ValueWiderThan $tr, $x)]>;
392392

393393
// trunci(shrsi(x, c)) -> trunci(shrui(x, c))
394394
def TruncIShrSIToTrunciShrUI :
395395
Pat<(Arith_TruncIOp:$tr
396-
(Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0))),
397-
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0)))),
396+
(Arith_ShRSIOp $x, (ConstantLikeMatcher TypedAttrInterface:$c0)), $overflow),
397+
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow),
398398
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
399399

400400
// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
401401
def TruncIShrUIMulIToMulSIExtended :
402402
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
403403
(Arith_MulIOp:$mul
404404
(Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
405-
(ConstantLikeMatcher AnyAttr:$c0))),
405+
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
406406
(Arith_MulSIExtendedOp:$res__1 $x, $y),
407407
[(ValuesWithSameType $tr, $x, $y),
408408
(ValueWiderThan $mul, $x),
@@ -413,7 +413,7 @@ def TruncIShrUIMulIToMulUIExtended :
413413
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
414414
(Arith_MulIOp:$mul
415415
(Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
416-
(ConstantLikeMatcher AnyAttr:$c0))),
416+
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
417417
(Arith_MulUIExtendedOp:$res__1 $x, $y),
418418
[(ValuesWithSameType $tr, $x, $y),
419419
(ValueWiderThan $mul, $x),

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,8 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
731731
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
732732
// CHECK: %{{.*}} = llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
733733
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
734+
// CHECK: %{{.*}} = llvm.trunc %{{.*}} overflow<nsw, nuw> : i64 to i32
735+
%4 = arith.trunci %arg0 overflow<nsw, nuw> : i64 to i32
734736
return
735737
}
736738

mlir/test/Dialect/Arith/ops.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,5 +1159,7 @@ func.func @intflags_func(%arg0: i64, %arg1: i64) {
11591159
%2 = arith.muli %arg0, %arg1 overflow<nsw, nuw> : i64
11601160
// CHECK: %{{.*}} = arith.shli %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
11611161
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
1162+
// CHECK: %{{.*}} = arith.trunci %{{.*}} overflow<nsw, nuw> : i64 to i32
1163+
%4 = arith.trunci %arg0 overflow<nsw, nuw> : i64 to i32
11621164
return
11631165
}

0 commit comments

Comments
 (0)