Skip to content

Commit c77fb5a

Browse files
committed
[mlir][emitc] Arith to EmitC: Handle addi, subi and muli (llvm#86120)
Important to consider that `arith` has wrap around semantics, and in C++ signed overflow is UB. Unless the operation guarantees that no signed overflow happens, we will perform the arithmetic in an equivalent unsigned type. `bool` also doesn't wrap around in C++, and is not addressed here.
1 parent 15e1cdb commit c77fb5a

File tree

3 files changed

+108
-10
lines changed

3 files changed

+108
-10
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,55 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
290290
}
291291
};
292292

293+
template <typename ArithOp, typename EmitCOp>
294+
class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
295+
public:
296+
using OpConversionPattern<ArithOp>::OpConversionPattern;
297+
298+
LogicalResult
299+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
300+
ConversionPatternRewriter &rewriter) const override {
301+
302+
Type type = this->getTypeConverter()->convertType(op.getType());
303+
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
304+
return rewriter.notifyMatchFailure(op, "expected integer type");
305+
}
306+
307+
if (type.isInteger(1)) {
308+
// arith expects wrap-around arithmethic, which doesn't happen on `bool`.
309+
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
310+
}
311+
312+
Value lhs = adaptor.getLhs();
313+
Value rhs = adaptor.getRhs();
314+
Type arithmeticType = type;
315+
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
316+
!bitEnumContainsAll(op.getOverflowFlags(),
317+
arith::IntegerOverflowFlags::nsw)) {
318+
// If the C type is signed and the op doesn't guarantee "No Signed Wrap",
319+
// we compute in unsigned integers to avoid UB.
320+
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
321+
/*isSigned=*/false);
322+
}
323+
if (arithmeticType != type) {
324+
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
325+
lhs);
326+
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
327+
rhs);
328+
}
329+
330+
Value result = rewriter.template create<EmitCOp>(op.getLoc(),
331+
arithmeticType, lhs, rhs);
332+
333+
if (arithmeticType != type) {
334+
result =
335+
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
336+
}
337+
rewriter.replaceOp(op, result);
338+
return success();
339+
}
340+
};
341+
293342
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
294343
public:
295344
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -432,9 +481,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
432481
ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
433482
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
434483
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
435-
ArithOpConversion<arith::AddIOp, emitc::AddOp>,
436-
ArithOpConversion<arith::MulIOp, emitc::MulOp>,
437-
ArithOpConversion<arith::SubIOp, emitc::SubOp>,
484+
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
485+
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
486+
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
438487
CmpFOpConversion,
439488
CmpIOpConversion,
440489
SelectOpConversion,
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-opt -convert-arith-to-emitc %s -split-input-file -verify-diagnostics
2+
3+
func.func @bool(%arg0: i1, %arg1: i1) {
4+
// expected-error@+1 {{failed to legalize operation 'arith.addi'}}
5+
%0 = arith.addi %arg0, %arg1 : i1
6+
return
7+
}
8+
9+
// -----
10+
11+
func.func @vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) {
12+
// expected-error@+1 {{failed to legalize operation 'arith.addi'}}
13+
%0 = arith.addi %arg0, %arg1 : vector<4xi32>
14+
return
15+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,51 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
3737

3838
// -----
3939

40+
// CHECK-LABEL: arith_integer_ops
4041
func.func @arith_integer_ops(%arg0: i32, %arg1: i32) {
41-
// CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
42+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
43+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
44+
// CHECK: %[[ADD:[^ ]*]] = emitc.add %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
45+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[ADD]] : ui32 to i32
4246
%0 = arith.addi %arg0, %arg1 : i32
43-
// CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
47+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
48+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
49+
// CHECK: %[[SUB:[^ ]*]] = emitc.sub %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
50+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SUB]] : ui32 to i32
4451
%1 = arith.subi %arg0, %arg1 : i32
45-
// CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
52+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
53+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
54+
// CHECK: %[[MUL:[^ ]*]] = emitc.mul %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
55+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[MUL]] : ui32 to i32
4656
%2 = arith.muli %arg0, %arg1 : i32
47-
// CHECK: emitc.div %arg0, %arg1 : (i32, i32) -> i32
48-
%3 = arith.divsi %arg0, %arg1 : i32
49-
// CHECK: emitc.rem %arg0, %arg1 : (i32, i32) -> i32
50-
%4 = arith.remsi %arg0, %arg1 : i32
57+
58+
return
59+
}
60+
61+
// -----
62+
63+
// CHECK-LABEL: arith_integer_ops_signed_nsw
64+
func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
65+
// CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
66+
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i32
67+
// CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
68+
%1 = arith.subi %arg0, %arg1 overflow<nsw> : i32
69+
// CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
70+
%2 = arith.muli %arg0, %arg1 overflow<nsw> : i32
71+
72+
return
73+
}
74+
75+
// -----
76+
77+
// CHECK-LABEL: arith_index
78+
func.func @arith_index(%arg0: index, %arg1: index) {
79+
// CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
80+
%0 = arith.addi %arg0, %arg1 : index
81+
// CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
82+
%1 = arith.subi %arg0, %arg1 : index
83+
// CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
84+
%2 = arith.muli %arg0, %arg1 : index
5185

5286
return
5387
}

0 commit comments

Comments
 (0)