Skip to content

Commit 71db971

Browse files
authored
[mlir][emitc] Arith to EmitC: Handle addi, subi and muli (#86120)
Important to consider that `arith` has wrap around semantics, and in C++ signed overflow is UB. Unless the operation guarantees that no signed overflow happens, we will perform the arithmetic in an equivalent unsigned type. `bool` also doesn't wrap around in C++, and is not addressed here.
1 parent 8612fa0 commit 71db971

File tree

3 files changed

+118
-0
lines changed

3 files changed

+118
-0
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,55 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
5555
}
5656
};
5757

58+
template <typename ArithOp, typename EmitCOp>
59+
class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
60+
public:
61+
using OpConversionPattern<ArithOp>::OpConversionPattern;
62+
63+
LogicalResult
64+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
65+
ConversionPatternRewriter &rewriter) const override {
66+
67+
Type type = this->getTypeConverter()->convertType(op.getType());
68+
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
69+
return rewriter.notifyMatchFailure(op, "expected integer type");
70+
}
71+
72+
if (type.isInteger(1)) {
73+
// arith expects wrap-around arithmethic, which doesn't happen on `bool`.
74+
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
75+
}
76+
77+
Value lhs = adaptor.getLhs();
78+
Value rhs = adaptor.getRhs();
79+
Type arithmeticType = type;
80+
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
81+
!bitEnumContainsAll(op.getOverflowFlags(),
82+
arith::IntegerOverflowFlags::nsw)) {
83+
// If the C type is signed and the op doesn't guarantee "No Signed Wrap",
84+
// we compute in unsigned integers to avoid UB.
85+
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
86+
/*isSigned=*/false);
87+
}
88+
if (arithmeticType != type) {
89+
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
90+
lhs);
91+
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
92+
rhs);
93+
}
94+
95+
Value result = rewriter.template create<EmitCOp>(op.getLoc(),
96+
arithmeticType, lhs, rhs);
97+
98+
if (arithmeticType != type) {
99+
result =
100+
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
101+
}
102+
rewriter.replaceOp(op, result);
103+
return success();
104+
}
105+
};
106+
58107
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
59108
public:
60109
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -96,6 +145,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
96145
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
97146
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
98147
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
148+
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
149+
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
150+
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
99151
SelectOpConversion
100152
>(typeConverter, ctx);
101153
// clang-format on
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-opt -convert-arith-to-emitc %s -split-input-file -verify-diagnostics
2+
3+
func.func @bool(%arg0: i1, %arg1: i1) {
4+
// expected-error@+1 {{failed to legalize operation 'arith.addi'}}
5+
%0 = arith.addi %arg0, %arg1 : i1
6+
return
7+
}
8+
9+
// -----
10+
11+
func.func @vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) {
12+
// expected-error@+1 {{failed to legalize operation 'arith.addi'}}
13+
%0 = arith.addi %arg0, %arg1 : vector<4xi32>
14+
return
15+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,57 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
3737

3838
// -----
3939

40+
// CHECK-LABEL: arith_integer_ops
41+
func.func @arith_integer_ops(%arg0: i32, %arg1: i32) {
42+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
43+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
44+
// CHECK: %[[ADD:[^ ]*]] = emitc.add %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
45+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[ADD]] : ui32 to i32
46+
%0 = arith.addi %arg0, %arg1 : i32
47+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
48+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
49+
// CHECK: %[[SUB:[^ ]*]] = emitc.sub %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
50+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SUB]] : ui32 to i32
51+
%1 = arith.subi %arg0, %arg1 : i32
52+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
53+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
54+
// CHECK: %[[MUL:[^ ]*]] = emitc.mul %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
55+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[MUL]] : ui32 to i32
56+
%2 = arith.muli %arg0, %arg1 : i32
57+
58+
return
59+
}
60+
61+
// -----
62+
63+
// CHECK-LABEL: arith_integer_ops_signed_nsw
64+
func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
65+
// CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
66+
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i32
67+
// CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
68+
%1 = arith.subi %arg0, %arg1 overflow<nsw> : i32
69+
// CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
70+
%2 = arith.muli %arg0, %arg1 overflow<nsw> : i32
71+
72+
return
73+
}
74+
75+
// -----
76+
77+
// CHECK-LABEL: arith_index
78+
func.func @arith_index(%arg0: index, %arg1: index) {
79+
// CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
80+
%0 = arith.addi %arg0, %arg1 : index
81+
// CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
82+
%1 = arith.subi %arg0, %arg1 : index
83+
// CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
84+
%2 = arith.muli %arg0, %arg1 : index
85+
86+
return
87+
}
88+
89+
// -----
90+
4091
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
4192
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
4293
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>

0 commit comments

Comments
 (0)