@@ -39,6 +39,79 @@ class ArithConstantOpConversionPattern
39
39
}
40
40
};
41
41
42
+ class CmpIOpConversion : public OpConversionPattern <arith::CmpIOp> {
43
+ public:
44
+ using OpConversionPattern::OpConversionPattern;
45
+
46
+ bool needsUnsignedCmp (arith::CmpIPredicate pred) const {
47
+ switch (pred) {
48
+ case arith::CmpIPredicate::eq:
49
+ case arith::CmpIPredicate::ne:
50
+ case arith::CmpIPredicate::slt:
51
+ case arith::CmpIPredicate::sle:
52
+ case arith::CmpIPredicate::sgt:
53
+ case arith::CmpIPredicate::sge:
54
+ return false ;
55
+ case arith::CmpIPredicate::ult:
56
+ case arith::CmpIPredicate::ule:
57
+ case arith::CmpIPredicate::ugt:
58
+ case arith::CmpIPredicate::uge:
59
+ return true ;
60
+ }
61
+ llvm_unreachable (" unknown cmpi predicate kind" );
62
+ }
63
+
64
+ emitc::CmpPredicate toEmitCPred (arith::CmpIPredicate pred) const {
65
+ switch (pred) {
66
+ case arith::CmpIPredicate::eq:
67
+ return emitc::CmpPredicate::eq;
68
+ case arith::CmpIPredicate::ne:
69
+ return emitc::CmpPredicate::ne;
70
+ case arith::CmpIPredicate::slt:
71
+ case arith::CmpIPredicate::ult:
72
+ return emitc::CmpPredicate::lt;
73
+ case arith::CmpIPredicate::sle:
74
+ case arith::CmpIPredicate::ule:
75
+ return emitc::CmpPredicate::le;
76
+ case arith::CmpIPredicate::sgt:
77
+ case arith::CmpIPredicate::ugt:
78
+ return emitc::CmpPredicate::gt;
79
+ case arith::CmpIPredicate::sge:
80
+ case arith::CmpIPredicate::uge:
81
+ return emitc::CmpPredicate::ge;
82
+ }
83
+ llvm_unreachable (" unknown cmpi predicate kind" );
84
+ }
85
+
86
+ LogicalResult
87
+ matchAndRewrite (arith::CmpIOp op, OpAdaptor adaptor,
88
+ ConversionPatternRewriter &rewriter) const override {
89
+
90
+ Type type = adaptor.getLhs ().getType ();
91
+ if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
92
+ return rewriter.notifyMatchFailure (op, " expected integer or index type" );
93
+ }
94
+
95
+ bool needsUnsigned = needsUnsignedCmp (op.getPredicate ());
96
+ emitc::CmpPredicate pred = toEmitCPred (op.getPredicate ());
97
+ Type arithmeticType = type;
98
+ if (type.isUnsignedInteger () != needsUnsigned) {
99
+ arithmeticType = rewriter.getIntegerType (type.getIntOrFloatBitWidth (),
100
+ /* isSigned=*/ !needsUnsigned);
101
+ }
102
+ Value lhs = adaptor.getLhs ();
103
+ Value rhs = adaptor.getRhs ();
104
+ if (arithmeticType != type) {
105
+ lhs = rewriter.template create <emitc::CastOp>(op.getLoc (), arithmeticType,
106
+ lhs);
107
+ rhs = rewriter.template create <emitc::CastOp>(op.getLoc (), arithmeticType,
108
+ rhs);
109
+ }
110
+ rewriter.replaceOpWithNewOp <emitc::CmpOp>(op, op.getType (), pred, lhs, rhs);
111
+ return success ();
112
+ }
113
+ };
114
+
42
115
template <typename ArithOp, typename EmitCOp>
43
116
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
44
117
public:
@@ -148,6 +221,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
148
221
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
149
222
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
150
223
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
224
+ CmpIOpConversion,
151
225
SelectOpConversion
152
226
>(typeConverter, ctx);
153
227
// clang-format on
0 commit comments