Skip to content

Commit c40146c

Browse files
authored
[mlir][EmitC] Add Arith to EmitC conversions (#84151)
This adds patterns and a pass to convert the Arith dialect to EmitC. For now, this covers arithemtic binary ops operating on floating point types. It is not checked within the patterns whether the types, such as the Tensor type, are supported in the respective EmitC operations. If unsupported types should be converted, the conversion will fail anyway because no legal EmitC operation can be created. This can clearly be improved in a follow up, also resulting in better error messages. Functions for such checks should not solely be used in the conversions and should also be (re)used in the verifier.
1 parent 6f54a54 commit c40146c

File tree

10 files changed

+222
-0
lines changed

10 files changed

+222
-0
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
10+
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
11+
12+
namespace mlir {
13+
class RewritePatternSet;
14+
class TypeConverter;
15+
16+
void populateArithToEmitCPatterns(TypeConverter &typeConverter,
17+
RewritePatternSet &patterns);
18+
} // namespace mlir
19+
20+
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
10+
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_CONVERTARITHTOEMITC
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1414
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
1515
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
16+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
1617
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1718
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
1819
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
133133
];
134134
}
135135

136+
//===----------------------------------------------------------------------===//
137+
// ArithToEmitC
138+
//===----------------------------------------------------------------------===//
139+
140+
def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
141+
let summary = "Convert Arith dialect to EmitC dialect";
142+
let dependentDialects = ["emitc::EmitCDialect"];
143+
}
144+
136145
//===----------------------------------------------------------------------===//
137146
// ArithToLLVM
138147
//===----------------------------------------------------------------------===//
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns to convert the Arith dialect to the EmitC
10+
// dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
15+
16+
#include "mlir/Dialect/Arith/IR/Arith.h"
17+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
20+
using namespace mlir;
21+
22+
//===----------------------------------------------------------------------===//
23+
// Conversion Patterns
24+
//===----------------------------------------------------------------------===//
25+
26+
namespace {
27+
template <typename ArithOp, typename EmitCOp>
28+
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
29+
public:
30+
using OpConversionPattern<ArithOp>::OpConversionPattern;
31+
32+
LogicalResult
33+
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
34+
ConversionPatternRewriter &rewriter) const override {
35+
36+
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
37+
adaptor.getOperands());
38+
39+
return success();
40+
}
41+
};
42+
} // namespace
43+
44+
//===----------------------------------------------------------------------===//
45+
// Pattern population
46+
//===----------------------------------------------------------------------===//
47+
48+
void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
49+
RewritePatternSet &patterns) {
50+
MLIRContext *ctx = patterns.getContext();
51+
52+
// clang-format off
53+
patterns.add<
54+
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
55+
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
56+
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
57+
ArithOpConversion<arith::SubFOp, emitc::SubOp>
58+
>(typeConverter, ctx);
59+
// clang-format on
60+
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert the Arith dialect to the EmitC
10+
// dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
15+
16+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
17+
#include "mlir/Dialect/Arith/IR/Arith.h"
18+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
19+
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Transforms/DialectConversion.h"
21+
22+
namespace mlir {
23+
#define GEN_PASS_DEF_CONVERTARITHTOEMITC
24+
#include "mlir/Conversion/Passes.h.inc"
25+
} // namespace mlir
26+
27+
using namespace mlir;
28+
29+
namespace {
30+
struct ConvertArithToEmitC
31+
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
32+
void runOnOperation() override;
33+
};
34+
} // namespace
35+
36+
void ConvertArithToEmitC::runOnOperation() {
37+
ConversionTarget target(getContext());
38+
39+
target.addLegalDialect<emitc::EmitCDialect>();
40+
target.addIllegalDialect<arith::ArithDialect>();
41+
target.addLegalOp<arith::ConstantOp>();
42+
43+
RewritePatternSet patterns(&getContext());
44+
45+
TypeConverter typeConverter;
46+
typeConverter.addConversion([](Type type) { return type; });
47+
48+
populateArithToEmitCPatterns(typeConverter, patterns);
49+
50+
if (failed(
51+
applyPartialConversion(getOperation(), target, std::move(patterns))))
52+
signalPassFailure();
53+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_mlir_conversion_library(MLIRArithToEmitC
2+
ArithToEmitC.cpp
3+
ArithToEmitCPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
7+
8+
DEPENDS
9+
MLIRConversionPassIncGen
10+
11+
LINK_LIBS PUBLIC
12+
MLIRArithDialect
13+
MLIREmitCDialect
14+
MLIRPass
15+
MLIRTransformUtils
16+
)

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
33
add_subdirectory(ArithCommon)
44
add_subdirectory(ArithToAMDGPU)
55
add_subdirectory(ArithToArmSME)
6+
add_subdirectory(ArithToEmitC)
67
add_subdirectory(ArithToLLVM)
78
add_subdirectory(ArithToSPIRV)
89
add_subdirectory(ArmNeon2dToIntr)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s
2+
3+
func.func @arith_ops(%arg0: f32, %arg1: f32) {
4+
// CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32
5+
%0 = arith.addf %arg0, %arg1 : f32
6+
// CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32
7+
%1 = arith.divf %arg0, %arg1 : f32
8+
// CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32
9+
%2 = arith.mulf %arg0, %arg1 : f32
10+
// CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32
11+
%3 = arith.subf %arg0, %arg1 : f32
12+
13+
return
14+
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4014,6 +4014,7 @@ cc_library(
40144014
":AffineToStandard",
40154015
":ArithToAMDGPU",
40164016
":ArithToArmSME",
4017+
":ArithToEmitC",
40174018
":ArithToLLVM",
40184019
":ArithToSPIRV",
40194020
":ArmNeon2dToIntr",
@@ -8162,6 +8163,32 @@ cc_library(
81628163
],
81638164
)
81648165

8166+
cc_library(
8167+
name = "ArithToEmitC",
8168+
srcs = glob([
8169+
"lib/Conversion/ArithToEmitC/*.cpp",
8170+
"lib/Conversion/ArithToEmitC/*.h",
8171+
]),
8172+
hdrs = glob([
8173+
"include/mlir/Conversion/ArithToEmitC/*.h",
8174+
]),
8175+
includes = [
8176+
"include",
8177+
"lib/Conversion/ArithToEmitC",
8178+
],
8179+
deps = [
8180+
":ArithDialect",
8181+
":ConversionPassIncGen",
8182+
":EmitCDialect",
8183+
":IR",
8184+
":Pass",
8185+
":Support",
8186+
":TransformUtils",
8187+
":Transforms",
8188+
"//llvm:Support",
8189+
],
8190+
)
8191+
81658192
cc_library(
81668193
name = "ArithToLLVM",
81678194
srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),

0 commit comments

Comments
 (0)