Skip to content

Commit 19266ca

Browse files
authored
[mlir][EmitC] Add an emitc.conditional operator (#84883)
This adds an `emitc.conditional` operation for the ternary conditional operator. Furthermore, this adds a converion from `arith.select` to the new op.
1 parent 9997e03 commit 19266ca

File tree

6 files changed

+110
-7
lines changed

6 files changed

+110
-7
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
908908
let hasVerifier = 1;
909909
}
910910

911+
def EmitC_ConditionalOp : EmitC_Op<"conditional",
912+
[AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
913+
let summary = "Conditional (ternary) operation";
914+
let description = [{
915+
With the `conditional` operation the ternary conditional operator can
916+
be applied.
917+
918+
Example:
919+
920+
```mlir
921+
%0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1
922+
923+
%c0 = "emitc.constant"() {value = 10 : i32} : () -> i32
924+
%c1 = "emitc.constant"() {value = 11 : i32} : () -> i32
925+
926+
%1 = emitc.conditional %0, %c0, %c1 : i32
927+
```
928+
```c++
929+
// Code emitted for the operations above.
930+
bool v3 = v1 > v2;
931+
int32_t v4 = 10;
932+
int32_t v5 = 11;
933+
int32_t v6 = v3 ? v4 : v5;
934+
```
935+
}];
936+
let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
937+
let results = (outs AnyType:$result);
938+
let assemblyFormat = "operands attr-dict `:` type($result)";
939+
}
940+
911941
def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
912942
let summary = "Unary minus operation";
913943
let description = [{

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,31 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
5454
return success();
5555
}
5656
};
57+
58+
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
59+
public:
60+
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
61+
62+
LogicalResult
63+
matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
64+
ConversionPatternRewriter &rewriter) const override {
65+
66+
Type dstType = getTypeConverter()->convertType(selectOp.getType());
67+
if (!dstType)
68+
return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
69+
70+
if (!adaptor.getCondition().getType().isInteger(1))
71+
return rewriter.notifyMatchFailure(
72+
selectOp,
73+
"can only be converted if condition is a scalar of type i1");
74+
75+
rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
76+
adaptor.getOperands());
77+
78+
return success();
79+
}
80+
};
81+
5782
} // namespace
5883

5984
//===----------------------------------------------------------------------===//
@@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
7095
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
7196
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
7297
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
73-
ArithOpConversion<arith::SubFOp, emitc::SubOp>
98+
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
99+
SelectOpConversion
74100
>(typeConverter, ctx);
75101
// clang-format on
76102
}

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
9696
}
9797
return op->emitError("unsupported cmp predicate");
9898
})
99+
.Case<emitc::ConditionalOp>([&](auto op) { return 2; })
99100
.Case<emitc::DivOp>([&](auto op) { return 13; })
100101
.Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
101102
.Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
@@ -446,6 +447,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
446447
return printBinaryOperation(emitter, operation, binaryOperator);
447448
}
448449

450+
static LogicalResult printOperation(CppEmitter &emitter,
451+
emitc::ConditionalOp conditionalOp) {
452+
raw_ostream &os = emitter.ostream();
453+
454+
if (failed(emitter.emitAssignPrefix(*conditionalOp)))
455+
return failure();
456+
457+
if (failed(emitter.emitOperand(conditionalOp.getCondition())))
458+
return failure();
459+
460+
os << " ? ";
461+
462+
if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
463+
return failure();
464+
465+
os << " : ";
466+
467+
if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
468+
return failure();
469+
470+
return success();
471+
}
472+
449473
static LogicalResult printOperation(CppEmitter &emitter,
450474
emitc::VerbatimOp verbatimOp) {
451475
raw_ostream &os = emitter.ostream();
@@ -1383,12 +1407,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
13831407
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
13841408
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
13851409
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1386-
emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
1387-
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
1388-
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1389-
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1390-
emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
1391-
emitc::VariableOp, emitc::VerbatimOp>(
1410+
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1411+
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1412+
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
1413+
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
1414+
emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1415+
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1416+
emitc::VerbatimOp>(
13921417
[&](auto op) { return printOperation(*this, op); })
13931418
// Func ops.
13941419
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
3434

3535
return
3636
}
37+
38+
// -----
39+
40+
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
41+
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
42+
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
43+
return
44+
}

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () {
7171
return
7272
}
7373

74+
func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
75+
%0 = emitc.conditional %cond, %arg0, %arg1 : i32
76+
return
77+
}
78+
7479
func.func @div_int(%arg0: i32, %arg1: i32) {
7580
%1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32
7681
return

mlir/test/Target/Cpp/conditional.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
3+
func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
4+
%0 = emitc.conditional %cond, %arg0, %arg1 : i32
5+
return
6+
}
7+
8+
// CHECK-LABEL: void cond
9+
// CHECK-NEXT: int32_t [[V3:[^ ]*]] = [[V0:[^ ]*]] ? [[V1:[^ ]*]] : [[V2:[^ ]*]];

0 commit comments

Comments
 (0)