Skip to content

[mlir][emitc] Arith to EmitC: Handle addi, subi and muli #86120

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,55 @@ class ArithOpConversion final : public OpConversionPattern<ArithOp> {
}
};

template <typename ArithOp, typename EmitCOp>
class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
return rewriter.notifyMatchFailure(op, "expected integer type");
}

if (type.isInteger(1)) {
// arith expects wrap-around arithmethic, which doesn't happen on `bool`.
return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
}

Value lhs = adaptor.getLhs();
Value rhs = adaptor.getRhs();
Type arithmeticType = type;
if ((type.isSignlessInteger() || type.isSignedInteger()) &&
!bitEnumContainsAll(op.getOverflowFlags(),
arith::IntegerOverflowFlags::nsw)) {
// If the C type is signed and the op doesn't guarantee "No Signed Wrap",
// we compute in unsigned integers to avoid UB.
arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
if (arithmeticType != type) {
lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
lhs);
rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
rhs);
}

Value result = rewriter.template create<EmitCOp>(op.getLoc(),
arithmeticType, lhs, rhs);

if (arithmeticType != type) {
result =
rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
}
rewriter.replaceOp(op, result);
return success();
}
};

class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
public:
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
Expand Down Expand Up @@ -96,6 +145,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
ArithOpConversion<arith::SubFOp, emitc::SubOp>,
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
SelectOpConversion
>(typeConverter, ctx);
// clang-format on
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc-failed.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: mlir-opt -convert-arith-to-emitc %s -split-input-file -verify-diagnostics

func.func @bool(%arg0: i1, %arg1: i1) {
// expected-error@+1 {{failed to legalize operation 'arith.addi'}}
%0 = arith.addi %arg0, %arg1 : i1
return
}

// -----

func.func @vector(%arg0: vector<4xi32>, %arg1: vector<4xi32>) {
// expected-error@+1 {{failed to legalize operation 'arith.addi'}}
%0 = arith.addi %arg0, %arg1 : vector<4xi32>
return
}
51 changes: 51 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,57 @@ func.func @arith_ops(%arg0: f32, %arg1: f32) {

// -----

// CHECK-LABEL: arith_integer_ops
func.func @arith_integer_ops(%arg0: i32, %arg1: i32) {
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
// CHECK: %[[ADD:[^ ]*]] = emitc.add %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[ADD]] : ui32 to i32
%0 = arith.addi %arg0, %arg1 : i32
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
// CHECK: %[[SUB:[^ ]*]] = emitc.sub %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[SUB]] : ui32 to i32
%1 = arith.subi %arg0, %arg1 : i32
// CHECK: %[[C1:[^ ]*]] = emitc.cast %arg0 : i32 to ui32
// CHECK: %[[C2:[^ ]*]] = emitc.cast %arg1 : i32 to ui32
// CHECK: %[[MUL:[^ ]*]] = emitc.mul %[[C1]], %[[C2]] : (ui32, ui32) -> ui32
// CHECK: %[[C3:[^ ]*]] = emitc.cast %[[MUL]] : ui32 to i32
%2 = arith.muli %arg0, %arg1 : i32

return
}

// -----

// CHECK-LABEL: arith_integer_ops_signed_nsw
func.func @arith_integer_ops_signed_nsw(%arg0: i32, %arg1: i32) {
// CHECK: emitc.add %arg0, %arg1 : (i32, i32) -> i32
%0 = arith.addi %arg0, %arg1 overflow<nsw> : i32
// CHECK: emitc.sub %arg0, %arg1 : (i32, i32) -> i32
%1 = arith.subi %arg0, %arg1 overflow<nsw> : i32
// CHECK: emitc.mul %arg0, %arg1 : (i32, i32) -> i32
%2 = arith.muli %arg0, %arg1 overflow<nsw> : i32

return
}

// -----

// CHECK-LABEL: arith_index
func.func @arith_index(%arg0: index, %arg1: index) {
// CHECK: emitc.add %arg0, %arg1 : (index, index) -> index
%0 = arith.addi %arg0, %arg1 : index
// CHECK: emitc.sub %arg0, %arg1 : (index, index) -> index
%1 = arith.subi %arg0, %arg1 : index
// CHECK: emitc.mul %arg0, %arg1 : (index, index) -> index
%2 = arith.muli %arg0, %arg1 : index

return
}

// -----

func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -> () {
// CHECK: [[V0:[^ ]*]] = emitc.conditional %arg0, %arg1, %arg2 : tensor<8xi32>
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
Expand Down