Skip to content

Commit b4e0e1c

Browse files
committed
Avoid UB due to signed wrap around
1 parent d509738 commit b4e0e1c

File tree

2 files changed

+88
-6
lines changed

2 files changed

+88
-6
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,50 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
5555
}
5656
};
5757

58+
template <typename ArithOp, typename EmitCOp>
59+
class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
60+
public:
61+
using OpConversionPattern<ArithOp>::OpConversionPattern;
62+
63+
LogicalResult
64+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
65+
ConversionPatternRewriter &rewriter) const override {
66+
67+
Type type = this->getTypeConverter()->convertType(op.getType());
68+
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
69+
return rewriter.notifyMatchFailure(op, "expected integer type");
70+
}
71+
72+
Value lhs = adaptor.getLhs();
73+
Value rhs = adaptor.getRhs();
74+
Type arithmeticType = type;
75+
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
76+
!bitEnumContainsAll(op.getOverflowFlags(),
77+
arith::IntegerOverflowFlags::nsw)) {
78+
// If the C type is signed and the op doesn't guarantee "No Signed Wrap",
79+
// we compute in unsigned integers to avoid UB.
80+
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
81+
/*isSigned=*/false);
82+
}
83+
if (arithmeticType != type) {
84+
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
85+
lhs);
86+
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
87+
rhs);
88+
}
89+
90+
Value result = rewriter.template create<EmitCOp>(op.getLoc(),
91+
arithmeticType, lhs, rhs);
92+
93+
if (arithmeticType != type) {
94+
result =
95+
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
96+
}
97+
rewriter.replaceOp(op, result);
98+
return success();
99+
}
100+
};
101+
58102
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
59103
public:
60104
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -96,9 +140,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
96140
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
97141
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
98142
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
99-
ArithOpConversion<arith::AddIOp, emitc::AddOp>,
100-
ArithOpConversion<arith::MulIOp, emitc::MulOp>,
101-
ArithOpConversion<arith::SubIOp, emitc::SubOp>,
143+
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
144+
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
145+
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
102146
SelectOpConversion
103147
>(typeConverter, ctx);
104148
// clang-format on

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,57 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
3737

3838
// -----
3939

40+
// CHECK-LABEL: arith_integer_ops
4041
func.func @arith_integer_ops(%arg0: i32, %arg1: i32) {
41-
// CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
42+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui3
43+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui3
44+
// CHECK: %[[ADD:[^ ]*]] = emitc.add %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
45+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[ADD]] : ui32 to i3
4246
%0 = arith.addi %arg0, %arg1 : i32
43-
// CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
47+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui3
48+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui3
49+
// CHECK: %[[SUB:[^ ]*]] = emitc.sub %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
50+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SUB]] : ui32 to i3
4451
%1 = arith.subi %arg0, %arg1 : i32
45-
// CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
52+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui3
53+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui3
54+
// CHECK: %[[MUL:[^ ]*]] = emitc.mul %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
55+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[MUL]] : ui32 to i3
4656
%2 = arith.muli %arg0, %arg1 : i32
4757

4858
return
4959
}
5060

5161
// -----
5262

63+
// CHECK-LABEL: arith_integer_ops_signed_nsw
64+
func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
65+
// CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
66+
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i32
67+
// CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
68+
%1 = arith.subi %arg0, %arg1 overflow<nsw> : i32
69+
// CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
70+
%2 = arith.muli %arg0, %arg1 overflow<nsw> : i32
71+
72+
return
73+
}
74+
75+
// -----
76+
77+
// CHECK-LABEL: arith_index
78+
func.func @arith_index(%arg0: index, %arg1: index) {
79+
// CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
80+
%0 = arith.addi %arg0, %arg1 : index
81+
// CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
82+
%1 = arith.subi %arg0, %arg1 : index
83+
// CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
84+
%2 = arith.muli %arg0, %arg1 : index
85+
86+
return
87+
}
88+
89+
// -----
90+
5391
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
5492
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
5593
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>

0 commit comments

Comments
 (0)