Skip to content

Commit 9951b0f

Browse files
committed
[mlir][EmitC] Add Arith to EmitC conversions
This adds patterns and a pass to convert the Arith dialect to EmitC. For now, this covers arithemtic binary ops operating on floating point types. It is not checked within the patterns whether the types, such as the Tensor type, are supported in the respective EmitC operations. If unsupported types should be converted, the conversion will fail anyway because no legal EmitC operation can be created. This can clearly be improved in a follow up, also resulting in better error messages. Functions for such checks should not solely be used in the conversions and should also be (re)used in the verifier.
1 parent eaf0d82 commit 9951b0f

File tree

9 files changed

+199
-0
lines changed

9 files changed

+199
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
10+
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
11+
12+
namespace mlir {
13+
class RewritePatternSet;
14+
15+
void populateArithToEmitCPatterns(RewritePatternSet &patterns);
16+
} // namespace mlir
17+
18+
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
10+
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_CONVERTARITHTOEMITC
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1414
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
1515
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
16+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
1617
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1718
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
1819
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
133133
];
134134
}
135135

136+
//===----------------------------------------------------------------------===//
137+
// ArithToEmitC
138+
//===----------------------------------------------------------------------===//
139+
140+
def ConvertArithToEmitC : Pass<"convert-arith-to-emitc", "ModuleOp"> {
141+
let summary = "Convert Arith dialect to EmitC dialect";
142+
let dependentDialects = ["emitc::EmitCDialect"];
143+
}
144+
136145
//===----------------------------------------------------------------------===//
137146
// ArithToLLVM
138147
//===----------------------------------------------------------------------===//
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns to convert the Arith dialect to the EmitC
10+
// dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
15+
16+
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
20+
using namespace mlir;
21+
22+
//===----------------------------------------------------------------------===//
23+
// Conversion Patterns
24+
//===----------------------------------------------------------------------===//
25+
26+
namespace {
27+
template <typename ArithOp, typename EmitCOp>
28+
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
29+
public:
30+
using OpConversionPattern<ArithOp>::OpConversionPattern;
31+
32+
LogicalResult
33+
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
34+
ConversionPatternRewriter &rewriter) const override {
35+
36+
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
37+
adaptor.getOperands());
38+
39+
return success();
40+
}
41+
};
42+
} // namespace
43+
44+
//===----------------------------------------------------------------------===//
45+
// Pattern population
46+
//===----------------------------------------------------------------------===//
47+
48+
void mlir::populateArithToEmitCPatterns(RewritePatternSet &patterns) {
49+
MLIRContext *ctx = patterns.getContext();
50+
51+
// clang-format off
52+
patterns.add<
53+
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
54+
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
55+
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
56+
ArithOpConversion<arith::SubFOp, emitc::SubOp>
57+
>(ctx);
58+
// clang-format on
59+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
//===- ArithToEmitCPass.cpp - Func to EmitC Pass ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert the Arith dialect to the EmitC
10+
// dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
15+
16+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
17+
#include "mlir/Dialect/Arith/IR/Arith.h"
18+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
19+
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Transforms/DialectConversion.h"
21+
22+
namespace mlir {
23+
#define GEN_PASS_DEF_CONVERTARITHTOEMITC
24+
#include "mlir/Conversion/Passes.h.inc"
25+
} // namespace mlir
26+
27+
using namespace mlir;
28+
29+
namespace {
30+
struct ConvertArithToEmitC
31+
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
32+
void runOnOperation() override;
33+
};
34+
} // namespace
35+
36+
void ConvertArithToEmitC::runOnOperation() {
37+
ConversionTarget target(getContext());
38+
39+
target.addLegalDialect<emitc::EmitCDialect>();
40+
41+
RewritePatternSet patterns(&getContext());
42+
populateArithToEmitCPatterns(patterns);
43+
44+
if (failed(
45+
applyPartialConversion(getOperation(), target, std::move(patterns))))
46+
signalPassFailure();
47+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_mlir_conversion_library(MLIRArithToEmitC
2+
ArithToEmitC.cpp
3+
ArithToEmitCPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
7+
8+
DEPENDS
9+
MLIRConversionPassIncGen
10+
11+
LINK_LIBS PUBLIC
12+
MLIRArithDialect
13+
MLIREmitCDialect
14+
MLIRPass
15+
MLIRTransformUtils
16+
)

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
33
add_subdirectory(ArithCommon)
44
add_subdirectory(ArithToAMDGPU)
55
add_subdirectory(ArithToArmSME)
6+
add_subdirectory(ArithToEmitC)
67
add_subdirectory(ArithToLLVM)
78
add_subdirectory(ArithToSPIRV)
89
add_subdirectory(ArmNeon2dToIntr)

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4011,6 +4011,7 @@ cc_library(
40114011
":AffineToStandard",
40124012
":ArithToAMDGPU",
40134013
":ArithToArmSME",
4014+
":ArithToEmitC",
40144015
":ArithToLLVM",
40154016
":ArithToSPIRV",
40164017
":ArmNeon2dToIntr",
@@ -8156,6 +8157,32 @@ cc_library(
81568157
],
81578158
)
81588159

8160+
cc_library(
8161+
name = "ArithToEmitC",
8162+
srcs = glob([
8163+
"lib/Conversion/ArithToEmitC/*.cpp",
8164+
"lib/Conversion/ArithToEmitC/*.h",
8165+
]),
8166+
hdrs = glob([
8167+
"include/mlir/Conversion/ArithToEmitC/*.h",
8168+
]),
8169+
includes = [
8170+
"include",
8171+
"lib/Conversion/ArithToEmitC",
8172+
],
8173+
deps = [
8174+
":ArithDialect",
8175+
":ConversionPassIncGen",
8176+
":EmitCDialect",
8177+
":IR",
8178+
":Pass",
8179+
":Support",
8180+
":TransformUtils",
8181+
":Transforms",
8182+
"//llvm:Support",
8183+
],
8184+
)
8185+
81598186
cc_library(
81608187
name = "ArithToLLVM",
81618188
srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),

0 commit comments

Comments
 (0)