Skip to content

Commit a4a10e9

Browse files
committed
Add shift operations
1 parent 5ed5160 commit a4a10e9

File tree

3 files changed

+222
-0
lines changed

3 files changed

+222
-0
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/Region.h"
2122
#include "mlir/Support/LogicalResult.h"
2223
#include "mlir/Transforms/DialectConversion.h"
2324

@@ -506,6 +507,90 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
506507
}
507508
};
508509

510+
template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
511+
class ShiftOpConversion : public OpConversionPattern<ArithOp> {
512+
public:
513+
using OpConversionPattern<ArithOp>::OpConversionPattern;
514+
515+
LogicalResult
516+
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
517+
ConversionPatternRewriter &rewriter) const override {
518+
519+
Type type = this->getTypeConverter()->convertType(op.getType());
520+
if (type && !(isa_and_nonnull<IntegerType>(type) ||
521+
emitc::isPointerWideType(type))) {
522+
return rewriter.notifyMatchFailure(
523+
op, "expected integer or size_t/ssize_t type");
524+
}
525+
526+
if (type.isInteger(1)) {
527+
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
528+
}
529+
530+
Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
531+
532+
Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
533+
// Shift amount interpreted as unsigned per Arith dialect spec.
534+
Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
535+
/*needsUnsigned=*/true);
536+
Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
537+
538+
// Add a runtime check for overflow
539+
Value width;
540+
if (emitc::isPointerWideType(type)) {
541+
Value eight = rewriter.create<emitc::ConstantOp>(
542+
op.getLoc(), rhsType, rewriter.getIndexAttr(8));
543+
emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
544+
op.getLoc(), rhsType, "sizeof", SmallVector<Value, 1>({eight}));
545+
width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
546+
sizeOfCall.getResult(0));
547+
} else {
548+
width = rewriter.create<emitc::ConstantOp>(
549+
op.getLoc(), rhsType,
550+
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
551+
}
552+
553+
Value excessCheck = rewriter.create<emitc::CmpOp>(
554+
op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
555+
556+
// Any concrete value is a valid refinement of poison.
557+
Value poison = rewriter.create<emitc::ConstantOp>(
558+
op.getLoc(), arithmeticType,
559+
(isa<IntegerType>(arithmeticType)
560+
? rewriter.getIntegerAttr(arithmeticType, 0)
561+
: rewriter.getIndexAttr(0)));
562+
563+
emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
564+
op.getLoc(), arithmeticType, /*do_not_inline=*/false);
565+
Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
566+
auto currentPoint = rewriter.getInsertionPoint();
567+
rewriter.setInsertionPointToStart(&bodyBlock);
568+
Value arithmeticResult =
569+
rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
570+
Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
571+
op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
572+
rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
573+
rewriter.setInsertionPoint(op->getBlock(), currentPoint);
574+
575+
Value result = adaptValueType(ternary, rewriter, type);
576+
577+
rewriter.replaceOp(op, result);
578+
return success();
579+
}
580+
};
581+
582+
template <typename ArithOp, typename EmitCOp>
583+
class SignedShiftOpConversion final
584+
: public ShiftOpConversion<ArithOp, EmitCOp, false> {
585+
using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
586+
};
587+
588+
template <typename ArithOp, typename EmitCOp>
589+
class UnsignedShiftOpConversion final
590+
: public ShiftOpConversion<ArithOp, EmitCOp, true> {
591+
using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
592+
};
593+
509594
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
510595
public:
511596
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
@@ -647,6 +732,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
647732
BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
648733
BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
649734
BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
735+
UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
736+
SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
737+
UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
650738
CmpFOpConversion,
651739
CmpIOpConversion,
652740
NegFOpConversion,

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,27 @@ func.func @arith_extsi_i1_to_i32(%arg0: i1) {
110110
%idx = arith.extsi %arg0 : i1 to i32
111111
return
112112
}
113+
114+
// -----
115+
116+
func.func @arith_shli_i1(%arg0: i1, %arg1: i1) {
117+
// expected-error @+1 {{failed to legalize operation 'arith.shli'}}
118+
%shli = arith.shli %arg0, %arg1 : i1
119+
return
120+
}
121+
122+
// -----
123+
124+
func.func @arith_shrsi_i1(%arg0: i1, %arg1: i1) {
125+
// expected-error @+1 {{failed to legalize operation 'arith.shrsi'}}
126+
%shrsi = arith.shrsi %arg0, %arg1 : i1
127+
return
128+
}
129+
130+
// -----
131+
132+
func.func @arith_shrui_i1(%arg0: i1, %arg1: i1) {
133+
// expected-error @+1 {{failed to legalize operation 'arith.shrui'}}
134+
%shrui = arith.shrui %arg0, %arg1 : i1
135+
return
136+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,116 @@ func.func @arith_signed_integer_div_rem(%arg0: i32, %arg1: i32) {
144144

145145
// -----
146146

147+
// CHECK-LABEL: arith_shift_left
148+
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
149+
func.func @arith_shift_left(%arg0: i32, %arg1: i32) {
150+
// CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
151+
// CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
152+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32
153+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
154+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
155+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
156+
// CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
157+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : ui32
158+
// CHECK: emitc.yield %[[Ternary]] : ui32
159+
// CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
160+
%1 = arith.shli %arg0, %arg1 : i32
161+
return
162+
}
163+
164+
// -----
165+
166+
// CHECK-LABEL: arith_shift_right
167+
// CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32
168+
func.func @arith_shift_right(%arg0: i32, %arg1: i32) {
169+
// CHECK-DAG: %[[C1:[^ ]*]] = emitc.cast %[[ARG0]] : i32 to ui32
170+
// CHECK-DAG: %[[C2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
171+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
172+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[C2]], %[[SizeConstant]] : (ui32, ui32) -> i1
173+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}ui32
174+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : ui32
175+
// CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
176+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : ui32
177+
// CHECK: emitc.yield %[[Ternary]] : ui32
178+
// CHECK: emitc.cast %[[ShiftRes]] : ui32 to i32
179+
%2 = arith.shrui %arg0, %arg1 : i32
180+
181+
// CHECK-DAG: %[[SC2:[^ ]*]] = emitc.cast %[[ARG1]] : i32 to ui32
182+
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = "emitc.constant"{{.*}}value = 32{{.*}}ui32
183+
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[SC2]], %[[SSizeConstant]] : (ui32, ui32) -> i1
184+
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}i32
185+
// CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : i32
186+
// CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[ARG0]], %[[SC2]] : (i32, ui32) -> i32
187+
// CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : i32
188+
// CHECK: emitc.yield %[[STernary]] : i32
189+
%3 = arith.shrsi %arg0, %arg1 : i32
190+
191+
return
192+
}
193+
194+
// -----
195+
196+
// CHECK-LABEL: arith_shift_left_index
197+
// CHECK-SAME: %[[AMOUNT:.*]]: i32
198+
func.func @arith_shift_left_index(%amount: i32) {
199+
%cst0 = "arith.constant"() {value = 42 : index} : () -> (index)
200+
%cast1 = arith.index_cast %amount : i32 to index
201+
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
202+
// CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
203+
// CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
204+
// CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
205+
// CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
206+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
207+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
208+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0
209+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
210+
// CHECK: %[[SHL:[^ ]*]] = emitc.bitwise_left_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
211+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHL]], %[[Zero]] : !emitc.size_t
212+
// CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
213+
%1 = arith.shli %cst0, %cast1 : index
214+
return
215+
}
216+
217+
// -----
218+
219+
// CHECK-LABEL: arith_shift_right_index
220+
// CHECK-SAME: %[[AMOUNT:.*]]: i32
221+
func.func @arith_shift_right_index(%amount: i32) {
222+
// CHECK-DAG: %[[C1:[^ ]*]] = "emitc.constant"(){{.*}}value = 42{{.*}}!emitc.size_t
223+
// CHECK-DAG: %[[Cast1:[^ ]*]] = emitc.cast %[[AMOUNT]] : i32 to !emitc.ssize_t
224+
// CHECK-DAG: %[[AmountIdx:[^ ]*]] = emitc.cast %[[Cast1]] : !emitc.ssize_t to !emitc.size_t
225+
%arg0 = "arith.constant"() {value = 42 : index} : () -> (index)
226+
%arg1 = arith.index_cast %amount : i32 to index
227+
228+
// CHECK-DAG: %[[Byte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index
229+
// CHECK-DAG: %[[SizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[Byte]]) : (!emitc.size_t) -> !emitc.size_t
230+
// CHECK-DAG: %[[SizeConstant:[^ ]*]] = emitc.mul %[[Byte]], %[[SizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
231+
// CHECK-DAG: %[[CmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
232+
// CHECK-DAG: %[[Zero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.size_t
233+
// CHECK: %[[ShiftRes:[^ ]*]] = emitc.expression : !emitc.size_t
234+
// CHECK: %[[SHR:[^ ]*]] = emitc.bitwise_right_shift %[[C1]], %[[AmountIdx]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
235+
// CHECK: %[[Ternary:[^ ]*]] = emitc.conditional %[[CmpNoExcess]], %[[SHR]], %[[Zero]] : !emitc.size_t
236+
// CHECK: emitc.yield %[[Ternary]] : !emitc.size_t
237+
%2 = arith.shrui %arg0, %arg1 : index
238+
239+
// CHECK-DAG: %[[SC1:[^ ]*]] = emitc.cast %[[C1]] : !emitc.size_t to !emitc.ssize_t
240+
// CHECK-DAG: %[[SByte:[^ ]*]] = "emitc.constant"{{.*}}value = 8{{.*}}index{{.*}}!emitc.size_t
241+
// CHECK-DAG: %[[SSizeOf:[^ ]*]] = emitc.call_opaque "sizeof"(%[[SByte]]) : (!emitc.size_t) -> !emitc.size_t
242+
// CHECK-DAG: %[[SSizeConstant:[^ ]*]] = emitc.mul %[[SByte]], %[[SSizeOf]] : (!emitc.size_t, !emitc.size_t) -> !emitc.size_t
243+
// CHECK-DAG: %[[SCmpNoExcess:[^ ]*]] = emitc.cmp lt, %[[AmountIdx]], %[[SSizeConstant]] : (!emitc.size_t, !emitc.size_t) -> i1
244+
// CHECK-DAG: %[[SZero:[^ ]*]] = "emitc.constant"{{.*}}value = 0{{.*}}!emitc.ssize_t
245+
// CHECK: %[[SShiftRes:[^ ]*]] = emitc.expression : !emitc.ssize_t
246+
// CHECK: %[[SHRSI:[^ ]*]] = emitc.bitwise_right_shift %[[SC1]], %[[AmountIdx]] : (!emitc.ssize_t, !emitc.size_t) -> !emitc.ssize_t
247+
// CHECK: %[[STernary:[^ ]*]] = emitc.conditional %[[SCmpNoExcess]], %[[SHRSI]], %[[SZero]] : !emitc.ssize_t
248+
// CHECK: emitc.yield %[[STernary]] : !emitc.ssize_t
249+
// CHECK: emitc.cast %[[SShiftRes]] : !emitc.ssize_t to !emitc.size_t
250+
%3 = arith.shrsi %arg0, %arg1 : index
251+
252+
return
253+
}
254+
255+
// -----
256+
147257
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
148258
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
149259
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>

0 commit comments

Comments
 (0)