Skip to content

Commit e452afa

Browse files
authored
Lower arith.andi, arith.ori, arith.shli, arith.shrsi, arith.shrui, arith.xori to EmitC
1 parent 25a21e0 commit e452afa

File tree

3 files changed

+304
-0
lines changed

3 files changed

+304
-0
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/Region.h"
2122
#include "mlir/Support/LogicalResult.h"
2223
#include "mlir/Transforms/DialectConversion.h"
2324

@@ -443,6 +444,131 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
443444
}
444445
};
445446

447+
template <typename ArithOp, typename EmitCOp>
448+
class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
449+
public:
450+
using OpConversionPattern<ArithOp>::OpConversionPattern;
451+
452+
LogicalResult
453+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
454+
ConversionPatternRewriter &rewriter) const override {
455+
456+
Type type = this->getTypeConverter()->convertType(op.getType());
457+
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
458+
type)) {
459+
return rewriter.notifyMatchFailure(
460+
op, "expected integer or size_t/ssize_t type, vector/tensor support "
461+
"not yet implemented");
462+
}
463+
464+
// Bitwise ops can be performed directly on booleans
465+
if (type.isInteger(1)) {
466+
rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
467+
adaptor.getRhs());
468+
return success();
469+
}
470+
471+
// Bitwise ops are defined by the C standard on unsigned operands.
472+
Type arithmeticType =
473+
adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
474+
475+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
476+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
477+
478+
Value arithmeticResult = rewriter.template create<EmitCOp>(
479+
op.getLoc(), arithmeticType, lhs, rhs);
480+
481+
Value result = adaptValueType(arithmeticResult, rewriter, type);
482+
483+
rewriter.replaceOp(op, result);
484+
return success();
485+
}
486+
};
487+
488+
template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
489+
class ShiftOpConversion : public OpConversionPattern<ArithOp> {
490+
public:
491+
using OpConversionPattern<ArithOp>::OpConversionPattern;
492+
493+
LogicalResult
494+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
495+
ConversionPatternRewriter &rewriter) const override {
496+
497+
Type type = this->getTypeConverter()->convertType(op.getType());
498+
if (!isa_and_nonnull<IntegerType, emitc::SignedSizeTType, emitc::SizeTType>(
499+
type)) {
500+
return rewriter.notifyMatchFailure(
501+
op, "expected integer or size_t/ssize_t type");
502+
}
503+
504+
if (type.isInteger(1)) {
505+
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
506+
}
507+
508+
Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
509+
510+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
511+
// Shift amount interpreted as unsigned per Arith dialect spec.
512+
Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
513+
/*needsUnsigned=*/true);
514+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
515+
516+
// Add a runtime check for overflow
517+
Value width;
518+
if (isa<emitc::SignedSizeTType, emitc::SizeTType>(type)) {
519+
Value eight = rewriter.create<emitc::ConstantOp>(
520+
op.getLoc(), rhsType, rewriter.getIndexAttr(8));
521+
emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
522+
op.getLoc(), rhsType, "sizeof", SmallVector<Value, 1>({eight}));
523+
width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
524+
sizeOfCall.getResult(0));
525+
} else {
526+
width = rewriter.create<emitc::ConstantOp>(
527+
op.getLoc(), rhsType,
528+
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
529+
}
530+
531+
Value excessCheck = rewriter.create<emitc::CmpOp>(
532+
op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
533+
534+
// Any concrete value is a valid refinement of poison.
535+
Value poison = rewriter.create<emitc::ConstantOp>(
536+
op.getLoc(), arithmeticType,
537+
(isa<IntegerType>(arithmeticType)
538+
? rewriter.getIntegerAttr(arithmeticType, 0)
539+
: rewriter.getIndexAttr(0)));
540+
541+
emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
542+
op.getLoc(), arithmeticType, /*do_not_inline=*/false);
543+
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
544+
auto currentPoint = rewriter.getInsertionPoint();
545+
rewriter.setInsertionPointToStart(&bodyBlock);
546+
Value arithmeticResult =
547+
rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
548+
Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
549+
op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
550+
rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
551+
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
552+
553+
Value result = adaptValueType(ternary, rewriter, type);
554+
555+
rewriter.replaceOp(op, result);
556+
return success();
557+
}
558+
};
559+
560+
template <typename ArithOp, typename EmitCOp>
561+
class SignedShiftOpConversion final
562+
: public ShiftOpConversion<ArithOp, EmitCOp, false> {
563+
using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
564+
};
565+
566+
template <typename ArithOp, typename EmitCOp>
567+
class UnsignedShiftOpConversion final
568+
: public ShiftOpConversion<ArithOp, EmitCOp, true> {
569+
using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
570+
};
571+
446572
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
447573
public:
448574
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -581,6 +707,12 @@ void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns,
581707
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
582708
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
583709
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
710+
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
711+
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
712+
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
713+
UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
714+
SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
715+
UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
584716
CmpFOpConversion,
585717
CmpIOpConversion,
586718
SelectOpConversion,

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,27 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
7878
%t = arith.fptoui %arg0 : f32 to i1
7979
return %t: i1
8080
}
81+
82+
// -----
83+
84+
func.func @arith_shli_i1(%arg0: i1, %arg1: i1) {
85+
// expected-error @+1 {{failed to legalize operation 'arith.shli'}}
86+
%shli = arith.shli %arg0, %arg1 : i1
87+
return
88+
}
89+
90+
// -----
91+
92+
func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) {
93+
// expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
94+
%shrsi = arith.shrsi %arg0, %arg1 : i1
95+
return
96+
}
97+
98+
// -----
99+
100+
func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) {
101+
// expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
102+
%shrui = arith.shrui %arg0, %arg1 : i1
103+
return
104+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,154 @@ func.func @arith_index(%arg0: i32, %arg1: i32) {
110110

111111
// -----
112112

113+
// CHECK-LABEL: arith_bitwise
114+
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
115+
func.func @arith_bitwise(%arg0: i32, %arg1: i32) {
116+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
117+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
118+
// CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
119+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[AND]] : ui32 to i32
120+
%5 = arith.andi %arg0, %arg1 : i32
121+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
122+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
123+
// CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
124+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[OR]] : ui32 to i32
125+
%6 = arith.ori %arg0, %arg1 : i32
126+
// CHECK: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
127+
// CHECK: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
128+
// CHECK: %[[XOR:[^ ]*]] = emitc.bitwise_xor %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
129+
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[XOR]] : ui32 to i32
130+
%7 = arith.xori %arg0, %arg1 : i32
131+
132+
return
133+
}
134+
135+
// -----
136+
137+
// CHECK-LABEL: arith_bitwise_bool
138+
func.func @arith_bitwise_bool(%arg0: i1, %arg1: i1) {
139+
// CHECK: %[[AND:[^ ]*]] = emitc.bitwise_and %arg0, %arg1 : (i1, i1) -> i1
140+
%5 = arith.andi %arg0, %arg1 : i1
141+
// CHECK: %[[OR:[^ ]*]] = emitc.bitwise_or %arg0, %arg1 : (i1, i1) -> i1
142+
%6 = arith.ori %arg0, %arg1 : i1
143+
// CHECK: %[[xor:[^ ]*]] = emitc.bitwise_xor %arg0, %arg1 : (i1, i1) -> i1
144+
%7 = arith.xori %arg0, %arg1 : i1
145+
146+
return
147+
}
148+
149+
// -----
150+
151+
// CHECK-LABEL: arith_shift_left
152+
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
153+
func.func @arith_shift_left(%arg0: i32, %arg1: i32) {
154+
// CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
155+
// CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
156+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32
157+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
158+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
159+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
160+
// CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
161+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32
162+
// CHECK: emitc.yield %[[Ternary]] : ui32
163+
// CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
164+
%1 = arith.shli %arg0, %arg1 : i32
165+
return
166+
}
167+
168+
// -----
169+
170+
// CHECK-LABEL: arith_shift_right
171+
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
172+
func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
173+
// CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
174+
// CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
175+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
176+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
177+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
178+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
179+
// CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
180+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32
181+
// CHECK: emitc.yield %[[Ternary]] : ui32
182+
// CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
183+
%2 = arith.shrui %arg0, %arg1 : i32
184+
185+
// CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
186+
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
187+
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1
188+
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32
189+
// CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32
190+
// CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32
191+
// CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32
192+
// CHECK: emitc.yield %[[STernary]] : i32
193+
%3 = arith.shrsi %arg0, %arg1 : i32
194+
195+
return
196+
}
197+
198+
// -----
199+
200+
// CHECK-LABEL: arith_shift_left_index
201+
// CHECK-SAME: %[[AMOUNT:.*]]: i32
202+
func.func @arith_shift_left_index(%amount: i32) {
203+
%cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
204+
%cast1 = arith.index_cast %amount : i32 to index
205+
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
206+
// CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
207+
// CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
208+
// CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
209+
// CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
210+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
211+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
212+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
213+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
214+
// CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
215+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t
216+
// CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
217+
%1 = arith.shli %cst0, %cast1 : index
218+
return
219+
}
220+
221+
// -----
222+
223+
// CHECK-LABEL: arith_shift_right_index
224+
// CHECK-SAME: %[[AMOUNT:.*]]: i32
225+
func.func @arith_shift_right_index(%amount: i32) {
226+
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
227+
// CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
228+
// CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
229+
%arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
230+
%arg1 = arith.index_cast %amount : i32 to index
231+
232+
// CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
233+
// CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
234+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
235+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
236+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
237+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
238+
// CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
239+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t
240+
// CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
241+
%2 = arith.shrui %arg0, %arg1 : index
242+
243+
// CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t
244+
// CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
245+
// CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
246+
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
247+
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
248+
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t
249+
// CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t
250+
// CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t
251+
// CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t
252+
// CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t
253+
// CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t
254+
%3 = arith.shrsi %arg0, %arg1 : index
255+
256+
return
257+
}
258+
259+
// -----
260+
113261
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
114262
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
115263
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>

0 commit comments

Comments
 (0)