Skip to content

[mlir][EmitC] Add an emitc.conditional operator #84883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,36 @@ def EmitC_SubOp : EmitC_BinaryOp<"sub", [CExpression]> {
let hasVerifier = 1;
}

def EmitC_ConditionalOp : EmitC_Op<"conditional",
[AllTypesMatch<["true_value", "false_value", "result"]>, CExpression]> {
let summary = "Conditional (ternary) operation";
let description = [{
With the `conditional` operation the ternary conditional operator can
be applied.

Example:

```mlir
%0 = emitc.cmp gt, %arg0, %arg1 : (i32, i32) -> i1

%c0 = "emitc.constant"() {value = 10 : i32} : () -> i32
%c1 = "emitc.constant"() {value = 11 : i32} : () -> i32

%1 = emitc.conditional %0, %c0, %c1 : i32
```
```c++
// Code emitted for the operations above.
bool v3 = v1 > v2;
int32_t v4 = 10;
int32_t v5 = 11;
int32_t v6 = v3 ? v4 : v5;
```
}];
let arguments = (ins I1:$condition, AnyType:$true_value, AnyType:$false_value);
let results = (outs AnyType:$result);
let assemblyFormat = "operands attr-dict `:` type($result)";
}

def EmitC_UnaryMinusOp : EmitC_UnaryOp<"unary_minus", [CExpression]> {
let summary = "Unary minus operation";
let description = [{
Expand Down
28 changes: 27 additions & 1 deletion mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,31 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
return success();
}
};

class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type dstType = getTypeConverter()->convertType(selectOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(selectOp, "type conversion failed");

if (!adaptor.getCondition().getType().isInteger(1))
return rewriter.notifyMatchFailure(
selectOp,
"can only be converted if condition is a scalar of type i1");

rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
adaptor.getOperands());

return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand All @@ -70,7 +95,8 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
ArithOpConversion<arith::SubFOp, emitc::SubOp>
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
SelectOpConversion
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SelectOpConversion
ArithSelectOpConversion

Maybe rename it such that is more consistent to the rest?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intend behind the naming ArithOp is that this is a conversion for a "genric" arith op. The ArithToEmitC already indicates that this file only contains conversions from the Arith to the EmitC dialect.

>(typeConverter, ctx);
// clang-format on
}
37 changes: 31 additions & 6 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ static FailureOr<int> getOperatorPrecedence(Operation *operation) {
}
return op->emitError("unsupported cmp predicate");
})
.Case<emitc::ConditionalOp>([&](auto op) { return 2; })
.Case<emitc::DivOp>([&](auto op) { return 13; })
.Case<emitc::LogicalAndOp>([&](auto op) { return 4; })
.Case<emitc::LogicalNotOp>([&](auto op) { return 15; })
Expand Down Expand Up @@ -446,6 +447,29 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
return printBinaryOperation(emitter, operation, binaryOperator);
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::ConditionalOp conditionalOp) {
raw_ostream &os = emitter.ostream();

if (failed(emitter.emitAssignPrefix(*conditionalOp)))
return failure();

if (failed(emitter.emitOperand(conditionalOp.getCondition())))
return failure();

os << " ? ";

if (failed(emitter.emitOperand(conditionalOp.getTrueValue())))
return failure();

os << " : ";

if (failed(emitter.emitOperand(conditionalOp.getFalseValue())))
return failure();

return success();
}

static LogicalResult printOperation(CppEmitter &emitter,
emitc::VerbatimOp verbatimOp) {
raw_ostream &os = emitter.ostream();
Expand Down Expand Up @@ -1383,12 +1407,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::BitwiseNotOp, emitc::BitwiseOrOp,
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
emitc::SubOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
emitc::VariableOp, emitc::VerbatimOp>(
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp,
emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,11 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {

return
}

// -----

func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
return
}
5 changes: 5 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ func.func @bitwise(%arg0: i32, %arg1: i32) -> () {
return
}

func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
%0 = emitc.conditional %cond, %arg0, %arg1 : i32
return
}

func.func @div_int(%arg0: i32, %arg1: i32) {
%1 = "emitc.div" (%arg0, %arg1) : (i32, i32) -> i32
return
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Target/Cpp/conditional.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s

func.func @cond(%cond: i1, %arg0: i32, %arg1: i32) -> () {
%0 = emitc.conditional %cond, %arg0, %arg1 : i32
return
}

// CHECK-LABEL: void cond
// CHECK-NEXT: int32_t [[V3:[^ ]*]] = [[V0:[^ ]*]] ? [[V1:[^ ]*]] : [[V2:[^ ]*]];