15
15
16
16
#include " mlir/Dialect/Arith/IR/Arith.h"
17
17
#include " mlir/Dialect/EmitC/IR/EmitC.h"
18
+ #include " mlir/Dialect/EmitC/Transforms/TypeConversions.h"
18
19
#include " mlir/IR/BuiltinAttributes.h"
19
20
#include " mlir/IR/BuiltinTypes.h"
20
21
#include " mlir/Support/LogicalResult.h"
@@ -36,8 +37,11 @@ class ArithConstantOpConversionPattern
36
37
matchAndRewrite (arith::ConstantOp arithConst,
37
38
arith::ConstantOp::Adaptor adaptor,
38
39
ConversionPatternRewriter &rewriter) const override {
39
- rewriter.replaceOpWithNewOp <emitc::ConstantOp>(
40
- arithConst, arithConst.getType (), adaptor.getValue ());
40
+ Type newTy = this ->getTypeConverter ()->convertType (arithConst.getType ());
41
+ if (!newTy)
42
+ return rewriter.notifyMatchFailure (arithConst, " type conversion failed" );
43
+ rewriter.replaceOpWithNewOp <emitc::ConstantOp>(arithConst, newTy,
44
+ adaptor.getValue ());
41
45
return success ();
42
46
}
43
47
};
@@ -52,6 +56,12 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
52
56
return IntegerType::get (ty.getContext (), ty.getIntOrFloatBitWidth (),
53
57
signedness);
54
58
}
59
+ } else if (emitc::isPointerWideType (ty)) {
60
+ if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
61
+ if (needsUnsigned)
62
+ return emitc::SizeTType::get (ty.getContext ());
63
+ return emitc::PtrDiffTType::get (ty.getContext ());
64
+ }
55
65
}
56
66
return ty;
57
67
}
@@ -264,8 +274,9 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
264
274
ConversionPatternRewriter &rewriter) const override {
265
275
266
276
Type type = adaptor.getLhs ().getType ();
267
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
268
- return rewriter.notifyMatchFailure (op, " expected integer or index type" );
277
+ if (type && !(isa<IntegerType>(type) || emitc::isPointerWideType (type))) {
278
+ return rewriter.notifyMatchFailure (
279
+ op, " expected integer or size_t/ssize_t type" );
269
280
}
270
281
271
282
bool needsUnsigned = needsUnsignedCmp (op.getPredicate ());
@@ -318,17 +329,21 @@ class CastConversion : public OpConversionPattern<ArithOp> {
318
329
ConversionPatternRewriter &rewriter) const override {
319
330
320
331
Type opReturnType = this ->getTypeConverter ()->convertType (op.getType ());
321
- if (!isa_and_nonnull<IntegerType>(opReturnType))
322
- return rewriter.notifyMatchFailure (op, " expected integer result type" );
332
+ if (opReturnType && !(isa_and_nonnull<IntegerType>(opReturnType) ||
333
+ emitc::isPointerWideType (opReturnType)))
334
+ return rewriter.notifyMatchFailure (
335
+ op, " expected integer or size_t/ssize_t result type" );
323
336
324
337
if (adaptor.getOperands ().size () != 1 ) {
325
338
return rewriter.notifyMatchFailure (
326
339
op, " CastConversion only supports unary ops" );
327
340
}
328
341
329
342
Type operandType = adaptor.getIn ().getType ();
330
- if (!isa_and_nonnull<IntegerType>(operandType))
331
- return rewriter.notifyMatchFailure (op, " expected integer operand type" );
343
+ if (operandType && !(isa_and_nonnull<IntegerType>(operandType) ||
344
+ emitc::isPointerWideType (operandType)))
345
+ return rewriter.notifyMatchFailure (
346
+ op, " expected integer or size_t/ssize_t operand type" );
332
347
333
348
// Signed (sign-extending) casts from i1 are not supported.
334
349
if (operandType.isInteger (1 ) && !castToUnsigned)
@@ -339,8 +354,11 @@ class CastConversion : public OpConversionPattern<ArithOp> {
339
354
// equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
340
355
// truncation.
341
356
if (opReturnType.isInteger (1 )) {
357
+ Type attrType = (emitc::isPointerWideType (operandType))
358
+ ? rewriter.getIndexType ()
359
+ : operandType;
342
360
auto constOne = rewriter.create <emitc::ConstantOp>(
343
- op.getLoc (), operandType, rewriter.getIntegerAttr (operandType , 1 ));
361
+ op.getLoc (), operandType, rewriter.getIntegerAttr (attrType , 1 ));
344
362
auto oneAndOperand = rewriter.create <emitc::BitwiseAndOp>(
345
363
op.getLoc (), operandType, adaptor.getIn (), constOne);
346
364
rewriter.replaceOpWithNewOp <emitc::CastOp>(op, opReturnType,
@@ -393,7 +411,11 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
393
411
matchAndRewrite (ArithOp arithOp, typename ArithOp::Adaptor adaptor,
394
412
ConversionPatternRewriter &rewriter) const override {
395
413
396
- rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, arithOp.getType (),
414
+ Type newTy = this ->getTypeConverter ()->convertType (arithOp.getType ());
415
+ if (!newTy)
416
+ return rewriter.notifyMatchFailure (arithOp,
417
+ " converting result type failed" );
418
+ rewriter.template replaceOpWithNewOp <EmitCOp>(arithOp, newTy,
397
419
adaptor.getOperands ());
398
420
399
421
return success ();
@@ -410,8 +432,10 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
410
432
ConversionPatternRewriter &rewriter) const override {
411
433
412
434
Type type = this ->getTypeConverter ()->convertType (op.getType ());
413
- if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
414
- return rewriter.notifyMatchFailure (op, " expected integer type" );
435
+ if (type && !(isa_and_nonnull<IntegerType>(type) ||
436
+ emitc::isPointerWideType (type))) {
437
+ return rewriter.notifyMatchFailure (
438
+ op, " expected integer or size_t/ssize_t/ptrdiff_t type" );
415
439
}
416
440
417
441
if (type.isInteger (1 )) {
@@ -606,6 +630,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
606
630
RewritePatternSet &patterns) {
607
631
MLIRContext *ctx = patterns.getContext ();
608
632
633
+ mlir::populateEmitCSizeTTypeConversions (typeConverter);
634
+
609
635
// clang-format off
610
636
patterns.add <
611
637
ArithConstantOpConversionPattern,
@@ -629,6 +655,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
629
655
UnsignedCastConversion<arith::TruncIOp>,
630
656
SignedCastConversion<arith::ExtSIOp>,
631
657
UnsignedCastConversion<arith::ExtUIOp>,
658
+ SignedCastConversion<arith::IndexCastOp>,
659
+ UnsignedCastConversion<arith::IndexCastUIOp>,
632
660
ItoFCastOpConversion<arith::SIToFPOp>,
633
661
ItoFCastOpConversion<arith::UIToFPOp>,
634
662
FtoICastOpConversion<arith::FPToSIOp>,
0 commit comments