Skip to content

Commit d7e2bb9

Browse files
marbremgehre-amd
authored andcommitted
[mlir][EmitC] Add an emitc.conditional operator (llvm#84883)
This adds an `emitc.conditional` operation for the ternary conditional operator. Furthermore, this adds a converion from `arith.select` to the new op.
1 parent 86a976b commit d7e2bb9

File tree

6 files changed

+110
-7
lines changed

6 files changed

+110
-7
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
908908
let hasVerifier = 1;
909909
}
910910

911+
def EmitC_ConditionalOp : EmitC_Op<"conditional",
912+
[AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
913+
let summary = "Conditional (ternary) operation";
914+
let description = [{
915+
With the `conditional` operation the ternary conditional operator can
916+
be applied.
917+
918+
Example:
919+
920+
```mlir
921+
%0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1
922+
923+
%c0 = "emitc.constant"() {value = 10 : i32} : () -> i32
924+
%c1 = "emitc.constant"() {value = 11 : i32} : () -> i32
925+
926+
%1 = emitc.conditional %0, %c0, %c1 : i32
927+
```
928+
```c++
929+
// Code emitted for the operations above.
930+
bool v3 = v1 > v2;
931+
int32_t v4 = 10;
932+
int32_t v5 = 11;
933+
int32_t v6 = v3 ? v4 : v5;
934+
```
935+
}];
936+
let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
937+
let results = (outs AnyType:$result);
938+
let assemblyFormat = "operands attr-dict `:` type($result)";
939+
}
940+
911941
def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
912942
let summary = "Unary minus operation";
913943
let description = [{

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,31 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
5454
return success();
5555
}
5656
};
57+
58+
class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
59+
public:
60+
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
61+
62+
LogicalResult
63+
matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
64+
ConversionPatternRewriter &rewriter) const override {
65+
66+
Type dstType = getTypeConverter()->convertType(selectOp.getType());
67+
if (!dstType)
68+
return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
69+
70+
if (!adaptor.getCondition().getType().isInteger(1))
71+
return rewriter.notifyMatchFailure(
72+
selectOp,
73+
"can only be converted if condition is a scalar of type i1");
74+
75+
rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
76+
adaptor.getOperands());
77+
78+
return success();
79+
}
80+
};
81+
5782
} // namespace
5883

5984
//===----------------------------------------------------------------------===//
@@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
7095
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
7196
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
7297
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
73-
ArithOpConversion<arith::SubFOp, emitc::SubOp>
98+
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
99+
SelectOpConversion
74100
>(typeConverter, ctx);
75101
// clang-format on
76102
}

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
9696
}
9797
return op->emitError("unsupported cmp predicate");
9898
})
99+
.Case<emitc::ConditionalOp>([&](auto op) { return 2; })
99100
.Case<emitc::DivOp>([&](auto op) { return 13; })
100101
.Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
101102
.Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
@@ -455,6 +456,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
455456
return printBinaryOperation(emitter, operation, binaryOperator);
456457
}
457458

459+
static LogicalResult printOperation(CppEmitter &emitter,
460+
emitc::ConditionalOp conditionalOp) {
461+
raw_ostream &os = emitter.ostream();
462+
463+
if (failed(emitter.emitAssignPrefix(*conditionalOp)))
464+
return failure();
465+
466+
if (failed(emitter.emitOperand(conditionalOp.getCondition())))
467+
return failure();
468+
469+
os << " ? ";
470+
471+
if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
472+
return failure();
473+
474+
os << " : ";
475+
476+
if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
477+
return failure();
478+
479+
return success();
480+
}
481+
458482
static LogicalResult printOperation(CppEmitter &emitter,
459483
emitc::VerbatimOp verbatimOp) {
460484
raw_ostream &os = emitter.ostream();
@@ -1406,12 +1430,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14061430
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
14071431
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
14081432
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1409-
emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
1410-
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
1411-
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1412-
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1413-
emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
1414-
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
1433+
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
1434+
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1435+
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
1436+
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
1437+
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
1438+
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1439+
emitc::VerbatimOp>(
14151440
[&](auto op) { return printOperation(*this, op); })
14161441
// Func ops.
14171442
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {
3434

3535
return
3636
}
37+
38+
// -----
39+
40+
func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
41+
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
42+
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
43+
return
44+
}

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () {
7171
return
7272
}
7373

74+
func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
75+
%0 = emitc.conditional %cond, %arg0, %arg1 : i32
76+
return
77+
}
78+
7479
func.func @div_int(%arg0: i32, %arg1: i32) {
7580
%1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32
7681
return

mlir/test/Target/Cpp/conditional.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
3+
func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
4+
%0 = emitc.conditional %cond, %arg0, %arg1 : i32
5+
return
6+
}
7+
8+
// CHECK-LABEL: void cond
9+
// CHECK-NEXT: int32_t [[V3:[^ ]*]] = [[V0:[^ ]*]] ? [[V1:[^ ]*]] : [[V2:[^ ]*]];

0 commit comments

Comments
 (0)