Skip to content

Commit a1a5594

Browse files
authored
[mlir][arith] add wide integer emulation support for subi (llvm#133248)
Adds wide integer emulation support for the `arith.subi` op. `(i2N, i2N) -> (i2N)` ops are emulated as `(vector<2xiN>, vector<2xiN>) -> (vector<2xiN>)`, just like the other emulation patterns. The emulation uses the following scheme:
```
resLow = lhsLow - rhsLow;
// carry = 1 if rhsLow > lhsLow (unsigned comparison), otherwise 0
resHigh = lhsHigh - carry - rhsHigh;
```
Signed-off-by: Ege Beysel <[email protected]>
1 parent 2fb53f5 commit a1a5594

File tree

3 files changed

+196
-19
lines changed

3 files changed

+196
-19
lines changed

mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,46 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> {
866866
}
867867
};
868868

869+
//===----------------------------------------------------------------------===//
870+
// ConvertSubI
871+
//===----------------------------------------------------------------------===//
872+
873+
struct ConvertSubI final : OpConversionPattern<arith::SubIOp> {
874+
using OpConversionPattern::OpConversionPattern;
875+
876+
LogicalResult
877+
matchAndRewrite(arith::SubIOp op, OpAdaptor adaptor,
878+
ConversionPatternRewriter &rewriter) const override {
879+
Location loc = op->getLoc();
880+
auto newTy = getTypeConverter()->convertType<VectorType>(op.getType());
881+
if (!newTy)
882+
return rewriter.notifyMatchFailure(
883+
loc, llvm::formatv("unsupported type: {}", op.getType()));
884+
885+
Type newElemTy = reduceInnermostDim(newTy);
886+
887+
auto [lhsElem0, lhsElem1] =
888+
extractLastDimHalves(rewriter, loc, adaptor.getLhs());
889+
auto [rhsElem0, rhsElem1] =
890+
extractLastDimHalves(rewriter, loc, adaptor.getRhs());
891+
892+
// Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY] where
893+
// CARRY is 1 or 0.
894+
Value low = rewriter.create<arith::SubIOp>(loc, lhsElem0, rhsElem0);
895+
// We have a carry if lhsElem0 < rhsElem0.
896+
Value carry0 = rewriter.create<arith::CmpIOp>(
897+
loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0);
898+
Value carryVal = rewriter.create<arith::ExtUIOp>(loc, newElemTy, carry0);
899+
900+
Value high0 = rewriter.create<arith::SubIOp>(loc, lhsElem1, carryVal);
901+
Value high = rewriter.create<arith::SubIOp>(loc, high0, rhsElem1);
902+
903+
Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high});
904+
rewriter.replaceOp(op, resultVec);
905+
return success();
906+
}
907+
};
908+
869909
//===----------------------------------------------------------------------===//
870910
// ConvertSIToFP
871911
//===----------------------------------------------------------------------===//
@@ -885,22 +925,16 @@ struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> {
885925
return rewriter.notifyMatchFailure(
886926
loc, llvm::formatv("unsupported type: {0}", oldTy));
887927

888-
unsigned oldBitWidth = getElementTypeOrSelf(oldTy).getIntOrFloatBitWidth();
889928
Value zeroCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 0);
890-
Value oneCst = createScalarOrSplatConstant(rewriter, loc, oldTy, 1);
891-
Value allOnesCst = createScalarOrSplatConstant(
892-
rewriter, loc, oldTy, APInt::getAllOnes(oldBitWidth));
893929

894930
// To avoid operating on very large unsigned numbers, perform the
895931
// conversion on the absolute value. Then, decide whether to negate the
896-
// result or not based on that sign bit. We assume two's complement and
897-
// implement negation by flipping all bits and adding 1.
898-
// Note that this relies on the the other conversion patterns to legalize
899-
// created ops and narrow the bit widths.
932+
// result or not based on that sign bit. We implement negation by
933+
// subtracting from zero. Note that this relies on the other conversion
934+
// patterns to legalize created ops and narrow the bit widths.
900935
Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
901936
in, zeroCst);
902-
Value bitwiseNeg = rewriter.create<arith::XOrIOp>(loc, in, allOnesCst);
903-
Value neg = rewriter.create<arith::AddIOp>(loc, bitwiseNeg, oneCst);
937+
Value neg = rewriter.create<arith::SubIOp>(loc, zeroCst, in);
904938
Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in);
905939

906940
Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs);
@@ -1139,7 +1173,7 @@ void arith::populateArithWideIntEmulationPatterns(
11391173
ConvertMaxMin<arith::MaxUIOp, arith::CmpIPredicate::ugt>,
11401174
ConvertMaxMin<arith::MaxSIOp, arith::CmpIPredicate::sgt>,
11411175
ConvertMaxMin<arith::MinUIOp, arith::CmpIPredicate::ult>,
1142-
ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>,
1176+
ConvertMaxMin<arith::MinSIOp, arith::CmpIPredicate::slt>, ConvertSubI,
11431177
// Bitwise binary ops.
11441178
ConvertBitwiseBinary<arith::AndIOp>, ConvertBitwiseBinary<arith::OrIOp>,
11451179
ConvertBitwiseBinary<arith::XOrIOp>,

mlir/test/Dialect/Arith/emulate-wide-int.mlir

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,44 @@ func.func @addi_vector_a_b(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi
130130
return %x : vector<4xi64>
131131
}
132132

133+
// CHECK-LABEL: func @subi_scalar
134+
// CHECK-SAME: ([[ARG0:%.+]]: vector<2xi32>, [[ARG1:%.+]]: vector<2xi32>) -> vector<2xi32>
135+
// CHECK-NEXT: [[LOW0:%.+]] = vector.extract [[ARG0]][0] : i32 from vector<2xi32>
136+
// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract [[ARG0]][1] : i32 from vector<2xi32>
137+
// CHECK-NEXT: [[LOW1:%.+]] = vector.extract [[ARG1]][0] : i32 from vector<2xi32>
138+
// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract [[ARG1]][1] : i32 from vector<2xi32>
139+
// CHECK-NEXT: [[SUB_L:%.+]] = arith.subi [[LOW0]], [[LOW1]] : i32
140+
// CHECK-NEXT: [[ULT:%.+]] = arith.cmpi ult, [[LOW0]], [[LOW1]] : i32
141+
// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[ULT]] : i1 to i32
142+
// CHECK-NEXT: [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : i32
143+
// CHECK-NEXT: [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : i32
144+
// CHECK: [[INS0:%.+]] = vector.insert [[SUB_L]], {{%.+}} [0] : i32 into vector<2xi32>
145+
// CHECK-NEXT: [[INS1:%.+]] = vector.insert [[SUB_H1]], [[INS0]] [1] : i32 into vector<2xi32>
146+
// CHECK-NEXT: return [[INS1]] : vector<2xi32>
147+
func.func @subi_scalar(%a : i64, %b : i64) -> i64 {
148+
%x = arith.subi %a, %b : i64
149+
return %x : i64
150+
}
151+
152+
// CHECK-LABEL: func @subi_vector
153+
// CHECK-SAME: ([[ARG0:%.+]]: vector<4x2xi32>, [[ARG1:%.+]]: vector<4x2xi32>) -> vector<4x2xi32>
154+
// CHECK-NEXT: [[LOW0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
155+
// CHECK-NEXT: [[HIGH0:%.+]] = vector.extract_strided_slice [[ARG0]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
156+
// CHECK-NEXT: [[LOW1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 0], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
157+
// CHECK-NEXT: [[HIGH1:%.+]] = vector.extract_strided_slice [[ARG1]] {offsets = [0, 1], sizes = [4, 1], strides = [1, 1]} : vector<4x2xi32> to vector<4x1xi32>
158+
// CHECK-NEXT: [[SUB_L:%.+]] = arith.subi [[LOW0]], [[LOW1]] : vector<4x1xi32>
159+
// CHECK-NEXT: [[ULT:%.+]] = arith.cmpi ult, [[LOW0]], [[LOW1]] : vector<4x1xi32>
160+
// CHECK-NEXT: [[CARRY:%.+]] = arith.extui [[ULT]] : vector<4x1xi1> to vector<4x1xi32>
161+
// CHECK-NEXT: [[SUB_H0:%.+]] = arith.subi [[HIGH0]], [[CARRY]] : vector<4x1xi32>
162+
// CHECK-NEXT: [[SUB_H1:%.+]] = arith.subi [[SUB_H0]], [[HIGH1]] : vector<4x1xi32>
163+
// CHECK: [[INS0:%.+]] = vector.insert_strided_slice [[SUB_L]], {{%.+}} {offsets = [0, 0], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
164+
// CHECK-NEXT: [[INS1:%.+]] = vector.insert_strided_slice [[SUB_H1]], [[INS0]] {offsets = [0, 1], strides = [1, 1]} : vector<4x1xi32> into vector<4x2xi32>
165+
// CHECK-NEXT: return [[INS1]] : vector<4x2xi32>
166+
func.func @subi_vector(%a : vector<4xi64>, %b : vector<4xi64>) -> vector<4xi64> {
167+
%x = arith.subi %a, %b : vector<4xi64>
168+
return %x : vector<4xi64>
169+
}
170+
133171
// CHECK-LABEL: func.func @cmpi_eq_scalar
134172
// CHECK-SAME: ([[LHS:%.+]]: vector<2xi32>, [[RHS:%.+]]: vector<2xi32>)
135173
// CHECK-NEXT: [[LHSLOW:%.+]] = vector.extract [[LHS]][0] : i32 from vector<2xi32>
@@ -967,11 +1005,12 @@ func.func @uitofp_i64_f16(%a : i64) -> f16 {
9671005

9681006
// CHECK-LABEL: func @sitofp_i64_f64
9691007
// CHECK-SAME: ([[ARG:%.+]]: vector<2xi32>) -> f64
970-
// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<2xi32>
971-
// CHECK: [[ONES1:%.+]] = vector.extract [[VONES]][0] : i32 from vector<2xi32>
972-
// CHECK-NEXT: [[ONES2:%.+]] = vector.extract [[VONES]][1] : i32 from vector<2xi32>
973-
// CHECK: arith.xori {{%.+}}, [[ONES1]] : i32
974-
// CHECK-NEXT: arith.xori {{%.+}}, [[ONES2]] : i32
1008+
// CHECK: [[VZERO:%.+]] = arith.constant dense<0> : vector<2xi32>
1009+
// CHECK: vector.extract [[VZERO]][0] : i32 from vector<2xi32>
1010+
// CHECK: [[ZERO1:%.+]] = vector.extract [[VZERO]][0] : i32 from vector<2xi32>
1011+
// CHECK-NEXT: [[ZERO2:%.+]] = vector.extract [[VZERO]][1] : i32 from vector<2xi32>
1012+
// CHECK: arith.subi [[ZERO1]], {{%.+}} : i32
1013+
// CHECK: arith.subi [[ZERO2]], {{%.+}} : i32
9751014
// CHECK: [[CST0:%.+]] = arith.constant 0 : i32
9761015
// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0]] : i32
9771016
// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : i32 to f64
@@ -990,9 +1029,9 @@ func.func @sitofp_i64_f64(%a : i64) -> f64 {
9901029

9911030
// CHECK-LABEL: func @sitofp_i64_f64_vector
9921031
// CHECK-SAME: ([[ARG:%.+]]: vector<3x2xi32>) -> vector<3xf64>
993-
// CHECK: [[VONES:%.+]] = arith.constant dense<-1> : vector<3x2xi32>
994-
// CHECK: arith.xori
995-
// CHECK-NEXT: arith.xori
1032+
// CHECK: [[VZERO:%.+]] = arith.constant dense<0> : vector<3x2xi32>
1033+
// CHECK: arith.subi
1034+
// CHECK: arith.subi
9961035
// CHECK: [[HIEQ0:%.+]] = arith.cmpi eq, [[HI:%.+]], [[CST0:%.+]] : vector<3xi32>
9971036
// CHECK-NEXT: [[LOWFP:%.+]] = arith.uitofp [[LOW:%.+]] : vector<3xi32> to vector<3xf64>
9981037
// CHECK-NEXT: [[HIFP:%.+]] = arith.uitofp [[HI:%.+]] : vector<3xi32> to vector<3xf64>
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Ops in this function will be emulated using i16 types.
2+
3+
// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
4+
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
5+
// RUN: mlir-runner -e entry -entry-point-result=void \
6+
// RUN: --shared-libs=%mlir_c_runner_utils | \
7+
// RUN: FileCheck %s --match-full-lines
8+
9+
// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=16" \
10+
// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
11+
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
12+
// RUN: mlir-runner -e entry -entry-point-result=void \
13+
// RUN: --shared-libs=%mlir_c_runner_utils | \
14+
// RUN: FileCheck %s --match-full-lines
15+
16+
func.func @emulate_subi(%arg: i32, %arg0: i32) -> i32 {
17+
%res = arith.subi %arg, %arg0 : i32
18+
return %res : i32
19+
}
20+
21+
func.func @check_subi(%arg : i32, %arg0 : i32) -> () {
22+
%res = func.call @emulate_subi(%arg, %arg0) : (i32, i32) -> (i32)
23+
vector.print %res : i32
24+
return
25+
}
26+
27+
func.func @entry() {
28+
%lhs1 = arith.constant 1 : i32
29+
%rhs1 = arith.constant 2 : i32
30+
31+
// CHECK: -1
32+
func.call @check_subi(%lhs1, %rhs1) : (i32, i32) -> ()
33+
// CHECK-NEXT: 1
34+
func.call @check_subi(%rhs1, %lhs1) : (i32, i32) -> ()
35+
36+
%lhs2 = arith.constant 1 : i32
37+
%rhs2 = arith.constant -2 : i32
38+
39+
// CHECK-NEXT: 3
40+
func.call @check_subi(%lhs2, %rhs2) : (i32, i32) -> ()
41+
// CHECK-NEXT: -3
42+
func.call @check_subi(%rhs2, %lhs2) : (i32, i32) -> ()
43+
44+
%lhs3 = arith.constant -1 : i32
45+
%rhs3 = arith.constant -2 : i32
46+
47+
// CHECK-NEXT: 1
48+
func.call @check_subi(%lhs3, %rhs3) : (i32, i32) -> ()
49+
// CHECK-NEXT: -1
50+
func.call @check_subi(%rhs3, %lhs3) : (i32, i32) -> ()
51+
52+
// Overflow from the upper/lower part.
53+
%lhs4 = arith.constant 131074 : i32
54+
%rhs4 = arith.constant 3 : i32
55+
56+
// CHECK-NEXT: 131071
57+
func.call @check_subi(%lhs4, %rhs4) : (i32, i32) -> ()
58+
// CHECK-NEXT: -131071
59+
func.call @check_subi(%rhs4, %lhs4) : (i32, i32) -> ()
60+
61+
// Overflow in both parts.
62+
%lhs5 = arith.constant 16385027 : i32
63+
%rhs5 = arith.constant 16450564 : i32
64+
65+
// CHECK-NEXT: -65537
66+
func.call @check_subi(%lhs5, %rhs5) : (i32, i32) -> ()
67+
// CHECK-NEXT: 65537
68+
func.call @check_subi(%rhs5, %lhs5) : (i32, i32) -> ()
69+
70+
%lhs6 = arith.constant 65536 : i32
71+
%rhs6 = arith.constant 1 : i32
72+
73+
// CHECK-NEXT: 65535
74+
func.call @check_subi(%lhs6, %rhs6) : (i32, i32) -> ()
75+
// CHECK-NEXT: -65535
76+
func.call @check_subi(%rhs6, %lhs6) : (i32, i32) -> ()
77+
78+
// Max/Min (un)signed integers.
79+
%sintmax = arith.constant 2147483647 : i32
80+
%sintmin = arith.constant -2147483648 : i32
81+
%uintmax = arith.constant -1 : i32
82+
%uintmin = arith.constant 0 : i32
83+
%cst1 = arith.constant 1 : i32
84+
85+
// CHECK-NEXT: -1
86+
func.call @check_subi(%sintmax, %sintmin) : (i32, i32) -> ()
87+
// CHECK-NEXT: 1
88+
func.call @check_subi(%sintmin, %sintmax) : (i32, i32) -> ()
89+
// CHECK-NEXT: 2147483647
90+
func.call @check_subi(%sintmin, %cst1) : (i32, i32) -> ()
91+
// CHECK-NEXT: -2147483648
92+
func.call @check_subi(%sintmax, %uintmax) : (i32, i32) -> ()
93+
// CHECK-NEXT: -2
94+
func.call @check_subi(%uintmax, %cst1) : (i32, i32) -> ()
95+
// CHECK-NEXT: 0
96+
func.call @check_subi(%uintmax, %uintmax) : (i32, i32) -> ()
97+
// CHECK-NEXT: -1
98+
func.call @check_subi(%uintmin, %cst1) : (i32, i32) -> ()
99+
// CHECK-NEXT: 1
100+
func.call @check_subi(%uintmin, %uintmax) : (i32, i32) -> ()
101+
102+
103+
return
104+
}

0 commit comments

Comments
 (0)