Skip to content

Commit 4309170

Browse files
committed
[mlir] Add arith.addi_carry op
The `arith.addi_carry` op implements integer addition with overflows. The carry is returned via the second result, as `i1`. Reviewed By: antiagainst, bondhugula Differential Revision: https://reviews.llvm.org/D131893
1 parent 36bdec4 commit 4309170

File tree

5 files changed

+249
-1
lines changed

5 files changed

+249
-1
lines changed

mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,41 @@ def Arith_AddIOp : Arith_IntBinaryOp<"addi", [Commutative]> {
202202
let hasCanonicalizer = 1;
203203
}
204204

205+
206+
def Arith_AddICarryOp : Arith_Op<"addi_carry", [Commutative,
207+
AllTypesMatch<["lhs", "rhs", "sum"]>]> {
208+
let summary = "integer addition operation returning both the sum and carry";
209+
let description = [{
210+
The `addi_carry` operation takes two operands and returns two results: the
211+
sum (same type as both operands), and the carry (boolean-like).
212+
213+
Example:
214+
215+
```mlir
216+
// Scalar addition.
217+
%sum, %carry = arith.addi_carry %b, %c : i64, i1
218+
219+
// Vector element-wise addition.
220+
%b:2 = arith.addi_carry %g, %h : vector<4xi32>, vector<4xi1>
221+
222+
// Tensor element-wise addition.
223+
%c:2 = arith.addi_carry %y, %z : tensor<4x?xi8>, tensor<4x?xi1>
224+
```
225+
}];
226+
227+
let arguments = (ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs);
228+
let results = (outs SignlessIntegerLike:$sum, BoolLike:$carry);
229+
let assemblyFormat = [{
230+
$lhs `,` $rhs attr-dict `:` type($sum) `,` type($carry)
231+
}];
232+
233+
let hasFolder = 1;
234+
235+
let extraClassDeclaration = [{
236+
::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
237+
}];
238+
}
239+
205240
//===----------------------------------------------------------------------===//
206241
// SubIOp
207242
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <cassert>
910
#include <utility>
1011

1112
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
@@ -15,9 +16,9 @@
1516
#include "mlir/IR/OpImplementation.h"
1617
#include "mlir/IR/PatternMatch.h"
1718
#include "mlir/IR/TypeUtilities.h"
18-
#include "llvm/ADT/SmallString.h"
1919

2020
#include "llvm/ADT/APSInt.h"
21+
#include "llvm/ADT/SmallString.h"
2122

2223
using namespace mlir;
2324
using namespace mlir::arith;
@@ -216,6 +217,81 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
216217
context);
217218
}
218219

220+
//===----------------------------------------------------------------------===//
221+
// AddICarryOp
222+
//===----------------------------------------------------------------------===//
223+
224+
Optional<SmallVector<int64_t, 4>> arith::AddICarryOp::getShapeForUnroll() {
225+
if (auto vt = getType(0).dyn_cast<VectorType>())
226+
return llvm::to_vector<4>(vt.getShape());
227+
return None;
228+
}
229+
230+
// Returns the carry bit, assuming that `sum` is the result of addition of
231+
// `operand` and another number.
232+
static APInt calculateCarry(const APInt &sum, const APInt &operand) {
233+
return sum.ult(operand) ? APInt::getAllOnes(1) : APInt::getZero(1);
234+
}
235+
236+
LogicalResult arith::AddICarryOp::fold(ArrayRef<Attribute> operands,
237+
SmallVectorImpl<OpFoldResult> &results) {
238+
auto carryTy = getCarry().getType();
239+
// addi_carry(x, 0) -> x, false
240+
if (matchPattern(getRhs(), m_Zero())) {
241+
auto carryZero = APInt::getZero(1);
242+
Builder builder(getContext());
243+
auto falseValue = builder.getZeroAttr(carryTy);
244+
245+
results.push_back(getLhs());
246+
results.push_back(falseValue);
247+
return success();
248+
}
249+
250+
// addi_carry(constant_a, constant_b) -> constant_sum, constant_carry
251+
// Let the `constFoldBinaryOp` utility attempt to fold the sum of both
252+
// operands. If that succeeds, calculate the carry boolean based on the sum
253+
// and the first (constant) operand, `lhs`. Note that we cannot simply call
254+
// `constFoldBinaryOp` again to calculate the carry (bit) because the
255+
// constructed attribute is of the same element type as both operands.
256+
if (Attribute sumAttr = constFoldBinaryOp<IntegerAttr>(
257+
operands, [](APInt a, const APInt &b) { return std::move(a) + b; })) {
258+
Attribute carryAttr;
259+
if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
260+
// Both arguments are scalars, calculate the scalar carry value.
261+
auto sum = sumAttr.cast<IntegerAttr>();
262+
carryAttr = IntegerAttr::get(
263+
carryTy, calculateCarry(sum.getValue(), lhs.getValue()));
264+
} else if (auto lhs = operands[0].dyn_cast<SplatElementsAttr>()) {
265+
// Both arguments are splats, calculate the splat carry value.
266+
auto sum = sumAttr.cast<SplatElementsAttr>();
267+
APInt carry = calculateCarry(sum.getSplatValue<APInt>(),
268+
lhs.getSplatValue<APInt>());
269+
carryAttr = SplatElementsAttr::get(carryTy, carry);
270+
} else if (auto lhs = operands[0].dyn_cast<ElementsAttr>()) {
271+
// Othwerwise calculate element-wise carry values.
272+
auto sum = sumAttr.cast<ElementsAttr>();
273+
const auto numElems = static_cast<size_t>(sum.getNumElements());
274+
SmallVector<APInt> carryValues;
275+
carryValues.reserve(numElems);
276+
277+
auto sumIt = sum.value_begin<APInt>();
278+
auto lhsIt = lhs.value_begin<APInt>();
279+
for (size_t i = 0, e = numElems; i != e; ++i, ++sumIt, ++lhsIt)
280+
carryValues.push_back(calculateCarry(*sumIt, *lhsIt));
281+
282+
carryAttr = DenseElementsAttr::get(carryTy, carryValues);
283+
} else {
284+
return failure();
285+
}
286+
287+
results.push_back(sumAttr);
288+
results.push_back(carryAttr);
289+
return success();
290+
}
291+
292+
return failure();
293+
}
294+
219295
//===----------------------------------------------------------------------===//
220296
// SubIOp
221297
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Arithmetic/canonicalize.mlir

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,87 @@ func.func @doubleAddSub2(%arg0: index, %arg1 : index) -> index {
544544
return %add : index
545545
}
546546

547+
// CHECK-LABEL: @addiCarryZeroRhs
548+
// CHECK-NEXT: %[[false:.+]] = arith.constant false
549+
// CHECK-NEXT: return %arg0, %[[false]]
550+
func.func @addiCarryZeroRhs(%arg0: i32) -> (i32, i1) {
551+
%zero = arith.constant 0 : i32
552+
%sum, %carry = arith.addi_carry %arg0, %zero: i32, i1
553+
return %sum, %carry : i32, i1
554+
}
555+
556+
// CHECK-LABEL: @addiCarryZeroRhsSplat
557+
// CHECK-NEXT: %[[false:.+]] = arith.constant dense<false> : vector<4xi1>
558+
// CHECK-NEXT: return %arg0, %[[false]]
559+
func.func @addiCarryZeroRhsSplat(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi1>) {
560+
%zero = arith.constant dense<0> : vector<4xi32>
561+
%sum, %carry = arith.addi_carry %arg0, %zero: vector<4xi32>, vector<4xi1>
562+
return %sum, %carry : vector<4xi32>, vector<4xi1>
563+
}
564+
565+
// CHECK-LABEL: @addiCarryZeroLhs
566+
// CHECK-NEXT: %[[false:.+]] = arith.constant false
567+
// CHECK-NEXT: return %arg0, %[[false]]
568+
func.func @addiCarryZeroLhs(%arg0: i32) -> (i32, i1) {
569+
%zero = arith.constant 0 : i32
570+
%sum, %carry = arith.addi_carry %zero, %arg0: i32, i1
571+
return %sum, %carry : i32, i1
572+
}
573+
574+
// CHECK-LABEL: @addiCarryConstants
575+
// CHECK-DAG: %[[false:.+]] = arith.constant false
576+
// CHECK-DAG: %[[c50:.+]] = arith.constant 50 : i32
577+
// CHECK-NEXT: return %[[c50]], %[[false]]
578+
func.func @addiCarryConstants() -> (i32, i1) {
579+
%c13 = arith.constant 13 : i32
580+
%c37 = arith.constant 37 : i32
581+
%sum, %carry = arith.addi_carry %c13, %c37: i32, i1
582+
return %sum, %carry : i32, i1
583+
}
584+
585+
// CHECK-LABEL: @addiCarryConstantsOverflow1
586+
// CHECK-DAG: %[[true:.+]] = arith.constant true
587+
// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : i32
588+
// CHECK-NEXT: return %[[c0]], %[[true]]
589+
func.func @addiCarryConstantsOverflow1() -> (i32, i1) {
590+
%max = arith.constant 4294967295 : i32
591+
%c1 = arith.constant 1 : i32
592+
%sum, %carry = arith.addi_carry %max, %c1: i32, i1
593+
return %sum, %carry : i32, i1
594+
}
595+
596+
// CHECK-LABEL: @addiCarryConstantsOverflow2
597+
// CHECK-DAG: %[[true:.+]] = arith.constant true
598+
// CHECK-DAG: %[[c_2:.+]] = arith.constant -2 : i32
599+
// CHECK-NEXT: return %[[c_2]], %[[true]]
600+
func.func @addiCarryConstantsOverflow2() -> (i32, i1) {
601+
%max = arith.constant 4294967295 : i32
602+
%sum, %carry = arith.addi_carry %max, %max: i32, i1
603+
return %sum, %carry : i32, i1
604+
}
605+
606+
// CHECK-LABEL: @addiCarryConstantsOverflowVector
607+
// CHECK-DAG: %[[sum:.+]] = arith.constant dense<[1, 6, 2, 14]> : vector<4xi32>
608+
// CHECK-DAG: %[[carry:.+]] = arith.constant dense<[false, false, true, false]> : vector<4xi1>
609+
// CHECK-NEXT: return %[[sum]], %[[carry]]
610+
func.func @addiCarryConstantsOverflowVector() -> (vector<4xi32>, vector<4xi1>) {
611+
%v1 = arith.constant dense<[1, 3, 3, 7]> : vector<4xi32>
612+
%v2 = arith.constant dense<[0, 3, 4294967295, 7]> : vector<4xi32>
613+
%sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
614+
return %sum, %carry : vector<4xi32>, vector<4xi1>
615+
}
616+
617+
// CHECK-LABEL: @addiCarryConstantsSplatVector
618+
// CHECK-DAG: %[[sum:.+]] = arith.constant dense<3> : vector<4xi32>
619+
// CHECK-DAG: %[[carry:.+]] = arith.constant dense<false> : vector<4xi1>
620+
// CHECK-NEXT: return %[[sum]], %[[carry]]
621+
func.func @addiCarryConstantsSplatVector() -> (vector<4xi32>, vector<4xi1>) {
622+
%v1 = arith.constant dense<1> : vector<4xi32>
623+
%v2 = arith.constant dense<2> : vector<4xi32>
624+
%sum, %carry = arith.addi_carry %v1, %v2 : vector<4xi32>, vector<4xi1>
625+
return %sum, %carry : vector<4xi32>, vector<4xi1>
626+
}
627+
547628
// CHECK-LABEL: @notCmpEQ
548629
// CHECK: %[[cres:.+]] = arith.cmpi ne, %arg0, %arg1 : i8
549630
// CHECK: return %[[cres]]

mlir/test/Dialect/Arithmetic/invalid.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,38 @@ func.func @func_with_ops(f32) {
110110

111111
// -----
112112

113+
func.func @func_with_ops(%a: f32) {
114+
// expected-error@+1 {{'arith.addi_carry' op operand #0 must be signless-integer-like}}
115+
%r:2 = arith.addi_carry %a, %a : f32, i32
116+
return
117+
}
118+
119+
// -----
120+
121+
func.func @func_with_ops(%a: i32) {
122+
// expected-error@+1 {{'arith.addi_carry' op result #1 must be bool-like}}
123+
%r:2 = arith.addi_carry %a, %a : i32, i32
124+
return
125+
}
126+
127+
// -----
128+
129+
func.func @func_with_ops(%a: vector<8xi32>) {
130+
// expected-error@+1 {{'arith.addi_carry' op if an operand is non-scalar, then all results must be non-scalar}}
131+
%r:2 = arith.addi_carry %a, %a : vector<8xi32>, i1
132+
return
133+
}
134+
135+
// -----
136+
137+
func.func @func_with_ops(%a: vector<8xi32>) {
138+
// expected-error@+1 {{'arith.addi_carry' op all non-scalar operands/results must have the same shape and base type}}
139+
%r:2 = arith.addi_carry %a, %a : vector<8xi32>, tensor<8xi1>
140+
return
141+
}
142+
143+
// -----
144+
113145
func.func @func_with_ops(i32) {
114146
^bb0(%a : i32):
115147
%sf = arith.addf %a, %a : i32 // expected-error {{'arith.addf' op operand #0 must be floating-point-like}}

mlir/test/Dialect/Arithmetic/ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,30 @@ func.func @test_addi_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]
2525
return %0 : vector<[8]xi64>
2626
}
2727

28+
// CHECK-LABEL: test_addi_carry
29+
func.func @test_addi_carry(%arg0 : i64, %arg1 : i64) -> i64 {
30+
%sum, %carry = arith.addi_carry %arg0, %arg1 : i64, i1
31+
return %sum : i64
32+
}
33+
34+
// CHECK-LABEL: test_addi_carry_tensor
35+
func.func @test_addi_carry_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) -> tensor<8x8xi64> {
36+
%sum, %carry = arith.addi_carry %arg0, %arg1 : tensor<8x8xi64>, tensor<8x8xi1>
37+
return %sum : tensor<8x8xi64>
38+
}
39+
40+
// CHECK-LABEL: test_addi_carry_vector
41+
func.func @test_addi_carry_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
42+
%0:2 = arith.addi_carry %arg0, %arg1 : vector<8xi64>, vector<8xi1>
43+
return %0#0 : vector<8xi64>
44+
}
45+
46+
// CHECK-LABEL: test_addi_carry_scalable_vector
47+
func.func @test_addi_carry_scalable_vector(%arg0 : vector<[8]xi64>, %arg1 : vector<[8]xi64>) -> vector<[8]xi64> {
48+
%0:2 = arith.addi_carry %arg0, %arg1 : vector<[8]xi64>, vector<[8]xi1>
49+
return %0#0 : vector<[8]xi64>
50+
}
51+
2852
// CHECK-LABEL: test_subi
2953
func.func @test_subi(%arg0 : i64, %arg1 : i64) -> i64 {
3054
%0 = arith.subi %arg0, %arg1 : i64

0 commit comments

Comments
 (0)