Skip to content

Commit 0ddb122

Browse files
authored
[mlir][emitc] Arith to EmitC conversion: constants (#83798)
* Add a conversion from `arith.constant` to `emitc.constant`. * Drop the translation for `arith.constant`s.
1 parent 9f5be5f commit 0ddb122

File tree

13 files changed

+70
-66
lines changed

13 files changed

+70
-66
lines changed

mlir/docs/Dialects/emitc.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,3 @@ translating the following operations:
3030
* `func.call`
3131
* `func.func`
3232
* `func.return`
33-
* 'arith' Dialect
34-
* `arith.constant`

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,21 @@ using namespace mlir;
2424
//===----------------------------------------------------------------------===//
2525

2626
namespace {
27+
class ArithConstantOpConversionPattern
28+
: public OpConversionPattern<arith::ConstantOp> {
29+
public:
30+
using OpConversionPattern::OpConversionPattern;
31+
32+
LogicalResult
33+
matchAndRewrite(arith::ConstantOp arithConst,
34+
arith::ConstantOp::Adaptor adaptor,
35+
ConversionPatternRewriter &rewriter) const override {
36+
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
37+
arithConst, arithConst.getType(), adaptor.getValue());
38+
return success();
39+
}
40+
};
41+
2742
template <typename ArithOp, typename EmitCOp>
2843
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
2944
public:
@@ -51,6 +66,7 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
5166

5267
// clang-format off
5368
patterns.add<
69+
ArithConstantOpConversionPattern,
5470
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
5571
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
5672
ArithOpConversion<arith::MulFOp, emitc::MulOp>,

mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ void ConvertArithToEmitC::runOnOperation() {
3838

3939
target.addLegalDialect<emitc::EmitCDialect>();
4040
target.addIllegalDialect<arith::ArithDialect>();
41-
target.addLegalOp<arith::ConstantOp>();
4241

4342
RewritePatternSet patterns(&getContext());
4443

mlir/lib/Target/Cpp/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ add_mlir_translation_library(MLIRTargetCpp
66
${EMITC_MAIN_INCLUDE_DIR}/emitc/Target/Cpp
77

88
LINK_LIBS PUBLIC
9-
MLIRArithDialect
109
MLIRControlFlowDialect
1110
MLIREmitCDialect
1211
MLIRFuncDialect

mlir/lib/Target/Cpp/TranslateRegistration.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/Dialect/Arith/IR/Arith.h"
109
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
1110
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1211
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -41,8 +40,7 @@ void registerToCppTranslation() {
4140
},
4241
[](DialectRegistry &registry) {
4342
// clang-format off
44-
registry.insert<arith::ArithDialect,
45-
cf::ControlFlowDialect,
43+
registry.insert<cf::ControlFlowDialect,
4644
emitc::EmitCDialect,
4745
func::FuncDialect,
4846
math::MathDialect,

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/Dialect/Arith/IR/Arith.h"
109
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
1110
#include "mlir/Dialect/EmitC/IR/EmitC.h"
1211
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -335,14 +334,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
335334
return printConstantOp(emitter, operation, value);
336335
}
337336

338-
static LogicalResult printOperation(CppEmitter &emitter,
339-
arith::ConstantOp constantOp) {
340-
Operation *operation = constantOp.getOperation();
341-
Attribute value = constantOp.getValue();
342-
343-
return printConstantOp(emitter, operation, value);
344-
}
345-
346337
static LogicalResult printOperation(CppEmitter &emitter,
347338
emitc::AssignOp assignOp) {
348339
auto variableOp = cast<emitc::VariableOp>(assignOp.getVar().getDefiningOp());
@@ -1391,9 +1382,6 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
13911382
// Func ops.
13921383
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
13931384
[&](auto op) { return printOperation(*this, op); })
1394-
// Arithmetic ops.
1395-
.Case<arith::ConstantOp>(
1396-
[&](auto op) { return printOperation(*this, op); })
13971385
.Case<emitc::LiteralOp>([&](auto op) { return success(); })
13981386
.Default([&](Operation *) {
13991387
return op.emitOpError("unable to find printer for op");

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,26 @@
1-
// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s
1+
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc %s | FileCheck %s
2+
3+
// CHECK-LABEL: arith_constants
4+
func.func @arith_constants() {
5+
// CHECK: emitc.constant
6+
// CHECK-SAME: value = 0 : index
7+
%c_index = arith.constant 0 : index
8+
// CHECK: emitc.constant
9+
// CHECK-SAME: value = 0 : i32
10+
%c_signless_int_32 = arith.constant 0 : i32
11+
// CHECK: emitc.constant
12+
// CHECK-SAME: value = 0.{{0+}}e+00 : f32
13+
%c_float_32 = arith.constant 0.0 : f32
14+
// CHECK: emitc.constant
15+
// CHECK-SAME: value = dense<0> : tensor<i32>
16+
%c_tensor_single_value = arith.constant dense<0> : tensor<i32>
17+
// CHECK: emitc.constant
18+
// CHECK-SAME: value{{.*}}[1, 2], [-3, 9], [0, 0], [2, -1]{{.*}}tensor<4x2xi64>
19+
%c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64>
20+
return
21+
}
22+
23+
// -----
224

325
func.func @arith_ops(%arg0: f32, %arg1: f32) {
426
// CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32

mlir/test/Target/Cpp/call.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func.func @emitc_call_opaque() {
1818

1919

2020
func.func @emitc_call_opaque_two_results() {
21-
%0 = arith.constant 0 : index
21+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> index
2222
%1:2 = emitc.call_opaque "two_results" () : () -> (i32, i32)
2323
return
2424
}

mlir/test/Target/Cpp/const.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@ func.func @emitc_constant() {
88
%c3 = "emitc.constant"(){value = -1 : si8} : () -> si8
99
%c4 = "emitc.constant"(){value = 255 : ui8} : () -> ui8
1010
%c5 = "emitc.constant"(){value = #emitc.opaque<"CHAR_MIN">} : () -> !emitc.opaque<"char">
11+
%c6 = "emitc.constant"(){value = 2 : index} : () -> index
12+
%c7 = "emitc.constant"(){value = 2.0 : f32} : () -> f32
13+
%c8 = "emitc.constant"(){value = dense<0> : tensor<i32>} : () -> tensor<i32>
14+
%c9 = "emitc.constant"(){value = dense<[0, 1]> : tensor<2xindex>} : () -> tensor<2xindex>
15+
%c10 = "emitc.constant"(){value = dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
1116
return
1217
}
1318
// CPP-DEFAULT: void emitc_constant() {
@@ -17,6 +22,11 @@ func.func @emitc_constant() {
1722
// CPP-DEFAULT-NEXT: int8_t [[V3:[^ ]*]] = -1;
1823
// CPP-DEFAULT-NEXT: uint8_t [[V4:[^ ]*]] = 255;
1924
// CPP-DEFAULT-NEXT: char [[V5:[^ ]*]] = CHAR_MIN;
25+
// CPP-DEFAULT-NEXT: size_t [[V6:[^ ]*]] = 2;
26+
// CPP-DEFAULT-NEXT: float [[V7:[^ ]*]] = (float)2.000000000e+00;
27+
// CPP-DEFAULT-NEXT: Tensor<int32_t> [[V8:[^ ]*]] = {0};
28+
// CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]] = {0, 1};
29+
// CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00};
2030

2131
// CPP-DECLTOP: void emitc_constant() {
2232
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
@@ -25,9 +35,19 @@ func.func @emitc_constant() {
2535
// CPP-DECLTOP-NEXT: int8_t [[V3:[^ ]*]];
2636
// CPP-DECLTOP-NEXT: uint8_t [[V4:[^ ]*]];
2737
// CPP-DECLTOP-NEXT: char [[V5:[^ ]*]];
38+
// CPP-DECLTOP-NEXT: size_t [[V6:[^ ]*]];
39+
// CPP-DECLTOP-NEXT: float [[V7:[^ ]*]];
40+
// CPP-DECLTOP-NEXT: Tensor<int32_t> [[V8:[^ ]*]];
41+
// CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V9:[^ ]*]];
42+
// CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V10:[^ ]*]];
2843
// CPP-DECLTOP-NEXT: [[V0]] = INT_MAX;
2944
// CPP-DECLTOP-NEXT: [[V1]] = 42;
3045
// CPP-DECLTOP-NEXT: [[V2]] = -1;
3146
// CPP-DECLTOP-NEXT: [[V3]] = -1;
3247
// CPP-DECLTOP-NEXT: [[V4]] = 255;
3348
// CPP-DECLTOP-NEXT: [[V5]] = CHAR_MIN;
49+
// CPP-DECLTOP-NEXT: [[V6]] = 2;
50+
// CPP-DECLTOP-NEXT: [[V7]] = (float)2.000000000e+00;
51+
// CPP-DECLTOP-NEXT: [[V8]] = {0};
52+
// CPP-DECLTOP-NEXT: [[V9]] = {0, 1};
53+
// CPP-DECLTOP-NEXT: [[V10]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00};

mlir/test/Target/Cpp/for.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ func.func @test_for(%arg0 : index, %arg1 : index, %arg2 : index) {
3333
// CPP-DECLTOP-NEXT: return;
3434

3535
func.func @test_for_yield() {
36-
%start = arith.constant 0 : index
37-
%stop = arith.constant 10 : index
38-
%step = arith.constant 1 : index
36+
%start = "emitc.constant"() <{value = 0 : index}> : () -> index
37+
%stop = "emitc.constant"() <{value = 10 : index}> : () -> index
38+
%step = "emitc.constant"() <{value = 1 : index}> : () -> index
3939

40-
%s0 = arith.constant 0 : i32
41-
%p0 = arith.constant 1.0 : f32
40+
%s0 = "emitc.constant"() <{value = 0 : i32}> : () -> i32
41+
%p0 = "emitc.constant"() <{value = 1.0 : f32}> : () -> f32
4242

4343
%0 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
4444
%1 = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32

mlir/test/Target/Cpp/if.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func.func @test_if_else(%arg0: i1, %arg1: f32) {
4949

5050

5151
func.func @test_if_yield(%arg0: i1, %arg1: f32) {
52-
%0 = arith.constant 0 : i8
52+
%0 = "emitc.constant"() <{value = 0 : i8}> : () -> i8
5353
%x = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32
5454
%y = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f64
5555
emitc.if %arg0 {

mlir/test/Target/Cpp/stdops.mlir

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,6 @@
11
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
22
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
33

4-
func.func @std_constant() {
5-
%c0 = arith.constant 0 : i32
6-
%c1 = arith.constant 2 : index
7-
%c2 = arith.constant 2.0 : f32
8-
%c3 = arith.constant dense<0> : tensor<i32>
9-
%c4 = arith.constant dense<[0, 1]> : tensor<2xindex>
10-
%c5 = arith.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
11-
return
12-
}
13-
// CPP-DEFAULT: void std_constant() {
14-
// CPP-DEFAULT-NEXT: int32_t [[V0:[^ ]*]] = 0;
15-
// CPP-DEFAULT-NEXT: size_t [[V1:[^ ]*]] = 2;
16-
// CPP-DEFAULT-NEXT: float [[V2:[^ ]*]] = (float)2.000000000e+00;
17-
// CPP-DEFAULT-NEXT: Tensor<int32_t> [[V3:[^ ]*]] = {0};
18-
// CPP-DEFAULT-NEXT: Tensor<size_t, 2> [[V4:[^ ]*]] = {0, 1};
19-
// CPP-DEFAULT-NEXT: Tensor<float, 2, 2> [[V5:[^ ]*]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00};
20-
21-
// CPP-DECLTOP: void std_constant() {
22-
// CPP-DECLTOP-NEXT: int32_t [[V0:[^ ]*]];
23-
// CPP-DECLTOP-NEXT: size_t [[V1:[^ ]*]];
24-
// CPP-DECLTOP-NEXT: float [[V2:[^ ]*]];
25-
// CPP-DECLTOP-NEXT: Tensor<int32_t> [[V3:[^ ]*]];
26-
// CPP-DECLTOP-NEXT: Tensor<size_t, 2> [[V4:[^ ]*]];
27-
// CPP-DECLTOP-NEXT: Tensor<float, 2, 2> [[V5:[^ ]*]];
28-
// CPP-DECLTOP-NEXT: [[V0]] = 0;
29-
// CPP-DECLTOP-NEXT: [[V1]] = 2;
30-
// CPP-DECLTOP-NEXT: [[V2]] = (float)2.000000000e+00;
31-
// CPP-DECLTOP-NEXT: [[V3]] = {0};
32-
// CPP-DECLTOP-NEXT: [[V4]] = {0, 1};
33-
// CPP-DECLTOP-NEXT: [[V5]] = {(float)0.0e+00, (float)1.000000000e+00, (float)2.000000000e+00, (float)3.000000000e+00};
34-
354
func.func @std_call() {
365
%0 = call @one_result () : () -> i32
376
%1 = call @one_result () : () -> i32
@@ -49,13 +18,11 @@ func.func @std_call() {
4918

5019

5120
func.func @std_call_two_results() {
52-
%c = arith.constant 0 : i8
5321
%0:2 = call @two_results () : () -> (i32, f32)
5422
%1:2 = call @two_results () : () -> (i32, f32)
5523
return
5624
}
5725
// CPP-DEFAULT: void std_call_two_results() {
58-
// CPP-DEFAULT-NEXT: int8_t [[V0:[^ ]*]] = 0;
5926
// CPP-DEFAULT-NEXT: int32_t [[V1:[^ ]*]];
6027
// CPP-DEFAULT-NEXT: float [[V2:[^ ]*]];
6128
// CPP-DEFAULT-NEXT: std::tie([[V1]], [[V2]]) = two_results();
@@ -64,18 +31,16 @@ func.func @std_call_two_results() {
6431
// CPP-DEFAULT-NEXT: std::tie([[V3]], [[V4]]) = two_results();
6532

6633
// CPP-DECLTOP: void std_call_two_results() {
67-
// CPP-DECLTOP-NEXT: int8_t [[V0:[^ ]*]];
6834
// CPP-DECLTOP-NEXT: int32_t [[V1:[^ ]*]];
6935
// CPP-DECLTOP-NEXT: float [[V2:[^ ]*]];
7036
// CPP-DECLTOP-NEXT: int32_t [[V3:[^ ]*]];
7137
// CPP-DECLTOP-NEXT: float [[V4:[^ ]*]];
72-
// CPP-DECLTOP-NEXT: [[V0]] = 0;
7338
// CPP-DECLTOP-NEXT: std::tie([[V1]], [[V2]]) = two_results();
7439
// CPP-DECLTOP-NEXT: std::tie([[V3]], [[V4]]) = two_results();
7540

7641

7742
func.func @one_result() -> i32 {
78-
%0 = arith.constant 0 : i32
43+
%0 = "emitc.constant"() <{value = 0 : i32}> : () -> i32
7944
return %0 : i32
8045
}
8146
// CPP-DEFAULT: int32_t one_result() {
@@ -89,8 +54,8 @@ func.func @one_result() -> i32 {
8954

9055

9156
func.func @two_results() -> (i32, f32) {
92-
%0 = arith.constant 0 : i32
93-
%1 = arith.constant 1.0 : f32
57+
%0 = "emitc.constant"() <{value = 0 : i32}> : () -> i32
58+
%1 = "emitc.constant"() <{value = 1.0 : f32}> : () -> f32
9459
return %0, %1 : i32, f32
9560
}
9661
// CPP-DEFAULT: std::tuple<int32_t, float> two_results() {

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1769,7 +1769,6 @@ cc_library(
17691769
]),
17701770
hdrs = glob(["include/mlir/Target/Cpp/*.h"]),
17711771
deps = [
1772-
":ArithDialect",
17731772
":ControlFlowDialect",
17741773
":EmitCDialect",
17751774
":FuncDialect",

0 commit comments

Comments
 (0)