Skip to content

Commit 44c876f

Browse files
authored
[mlir][emitc] Add EmitC lowering for arith.cmpi (#88700)
1 parent a1e7c83 commit 44c876f

File tree

2 files changed

+122
-0
lines changed

2 files changed

+122
-0
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,79 @@ class ArithConstantOpConversionPattern
3939
}
4040
};
4141

42+
class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
43+
public:
44+
using OpConversionPattern::OpConversionPattern;
45+
46+
bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
47+
switch (pred) {
48+
case arith::CmpIPredicate::eq:
49+
case arith::CmpIPredicate::ne:
50+
case arith::CmpIPredicate::slt:
51+
case arith::CmpIPredicate::sle:
52+
case arith::CmpIPredicate::sgt:
53+
case arith::CmpIPredicate::sge:
54+
return false;
55+
case arith::CmpIPredicate::ult:
56+
case arith::CmpIPredicate::ule:
57+
case arith::CmpIPredicate::ugt:
58+
case arith::CmpIPredicate::uge:
59+
return true;
60+
}
61+
llvm_unreachable("unknown cmpi predicate kind");
62+
}
63+
64+
emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
65+
switch (pred) {
66+
case arith::CmpIPredicate::eq:
67+
return emitc::CmpPredicate::eq;
68+
case arith::CmpIPredicate::ne:
69+
return emitc::CmpPredicate::ne;
70+
case arith::CmpIPredicate::slt:
71+
case arith::CmpIPredicate::ult:
72+
return emitc::CmpPredicate::lt;
73+
case arith::CmpIPredicate::sle:
74+
case arith::CmpIPredicate::ule:
75+
return emitc::CmpPredicate::le;
76+
case arith::CmpIPredicate::sgt:
77+
case arith::CmpIPredicate::ugt:
78+
return emitc::CmpPredicate::gt;
79+
case arith::CmpIPredicate::sge:
80+
case arith::CmpIPredicate::uge:
81+
return emitc::CmpPredicate::ge;
82+
}
83+
llvm_unreachable("unknown cmpi predicate kind");
84+
}
85+
86+
LogicalResult
87+
matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
88+
ConversionPatternRewriter &rewriter) const override {
89+
90+
Type type = adaptor.getLhs().getType();
91+
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
92+
return rewriter.notifyMatchFailure(op, "expected integer or index type");
93+
}
94+
95+
bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
96+
emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
97+
Type arithmeticType = type;
98+
if (type.isUnsignedInteger() != needsUnsigned) {
99+
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
100+
/*isSigned=*/!needsUnsigned);
101+
}
102+
Value lhs = adaptor.getLhs();
103+
Value rhs = adaptor.getRhs();
104+
if (arithmeticType != type) {
105+
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
106+
lhs);
107+
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
108+
rhs);
109+
}
110+
rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
111+
return success();
112+
}
113+
};
114+
42115
template <typename ArithOp, typename EmitCOp>
43116
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
44117
public:
@@ -148,6 +221,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
148221
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
149222
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
150223
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
224+
CmpIOpConversion,
151225
SelectOpConversion
152226
>(typeConverter, ctx);
153227
// clang-format on

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,51 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -
9393
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
9494
return
9595
}
96+
97+
// -----
98+
99+
func.func @arith_cmpi_eq(%arg0: i32, %arg1: i32) -> i1 {
100+
// CHECK-LABEL: arith_cmpi_eq
101+
// CHECK-SAME: ([[Arg0:[^ ]*]]: i32, [[Arg1:[^ ]*]]: i32)
102+
// CHECK-DAG: [[EQ:[^ ]*]] = emitc.cmp eq, [[Arg0]], [[Arg1]] : (i32, i32) -> i1
103+
%eq = arith.cmpi eq, %arg0, %arg1 : i32
104+
// CHECK: return [[EQ]]
105+
return %eq: i1
106+
}
107+
108+
func.func @arith_cmpi_ult(%arg0: i32, %arg1: i32) -> i1 {
109+
// CHECK-LABEL: arith_cmpi_ult
110+
// CHECK-SAME: ([[Arg0:[^ ]*]]: i32, [[Arg1:[^ ]*]]: i32)
111+
// CHECK-DAG: [[CastArg0:[^ ]*]] = emitc.cast [[Arg0]] : i32 to ui32
112+
// CHECK-DAG: [[CastArg1:[^ ]*]] = emitc.cast [[Arg1]] : i32 to ui32
113+
// CHECK-DAG: [[ULT:[^ ]*]] = emitc.cmp lt, [[CastArg0]], [[CastArg1]] : (ui32, ui32) -> i1
114+
%ult = arith.cmpi ult, %arg0, %arg1 : i32
115+
116+
// CHECK: return [[ULT]]
117+
return %ult: i1
118+
}
119+
120+
func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
121+
// CHECK: emitc.cmp lt, {{.*}} : (ui32, ui32) -> i1
122+
%ult = arith.cmpi ult, %arg0, %arg1 : i32
123+
// CHECK: emitc.cmp lt, {{.*}} : (i32, i32) -> i1
124+
%slt = arith.cmpi slt, %arg0, %arg1 : i32
125+
// CHECK: emitc.cmp le, {{.*}} : (ui32, ui32) -> i1
126+
%ule = arith.cmpi ule, %arg0, %arg1 : i32
127+
// CHECK: emitc.cmp le, {{.*}} : (i32, i32) -> i1
128+
%sle = arith.cmpi sle, %arg0, %arg1 : i32
129+
// CHECK: emitc.cmp gt, {{.*}} : (ui32, ui32) -> i1
130+
%ugt = arith.cmpi ugt, %arg0, %arg1 : i32
131+
// CHECK: emitc.cmp gt, {{.*}} : (i32, i32) -> i1
132+
%sgt = arith.cmpi sgt, %arg0, %arg1 : i32
133+
// CHECK: emitc.cmp ge, {{.*}} : (ui32, ui32) -> i1
134+
%uge = arith.cmpi uge, %arg0, %arg1 : i32
135+
// CHECK: emitc.cmp ge, {{.*}} : (i32, i32) -> i1
136+
%sge = arith.cmpi sge, %arg0, %arg1 : i32
137+
// CHECK: emitc.cmp eq, {{.*}} : (i32, i32) -> i1
138+
%eq = arith.cmpi eq, %arg0, %arg1 : i32
139+
// CHECK: emitc.cmp ne, {{.*}} : (i32, i32) -> i1
140+
%ne = arith.cmpi ne, %arg0, %arg1 : i32
141+
142+
return
143+
}

0 commit comments

Comments
 (0)