Skip to content

Commit c2ebdad

Browse files
committed
[mlir][emitc] Arith to EmitC conversion pass
Add a conversion pass from Arith to EmitC. Add an initial conversion from `arith.constant` to `emitc.constant`.
1 parent dfec4ef commit c2ebdad

File tree

8 files changed

+193
-0
lines changed

8 files changed

+193
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- ArithToEmitC.h - Convert Arith to EmitC ----------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
9+
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
10+
11+
#include "mlir/Pass/Pass.h"
12+
13+
namespace mlir {
14+
class RewritePatternSet;
15+
16+
#define GEN_PASS_DECL_ARITHTOEMITCCONVERSIONPASS
17+
#include "mlir/Conversion/Passes.h.inc"
18+
19+
void populateArithToEmitCConversionPatterns(RewritePatternSet &patterns);
20+
} // namespace mlir
21+
22+
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1414
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
1515
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
16+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
1617
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1718
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
1819
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,18 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
133133
];
134134
}
135135

136+
//===----------------------------------------------------------------------===//
137+
// ArithToEmitC
138+
//===----------------------------------------------------------------------===//
139+
140+
def ArithToEmitCConversionPass : Pass<"convert-arith-to-emitc"> {
141+
let summary = "Convert Arith ops to EmitC ops";
142+
let description = [{
143+
Convert `arith` operations to operations in the `emitc` dialect.
144+
}];
145+
let dependentDialects = ["emitc::EmitCDialect"];
146+
}
147+
136148
//===----------------------------------------------------------------------===//
137149
// ArithToLLVM
138150
//===----------------------------------------------------------------------===//
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
//===- ArithToEmitC.cpp - Arith to EmitC conversion -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert arith ops into emitc ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/IR/BuiltinTypes.h"
18+
#include "mlir/Support/LogicalResult.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
21+
namespace mlir {
22+
#define GEN_PASS_DEF_ARITHTOEMITCCONVERSIONPASS
23+
#include "mlir/Conversion/Passes.h.inc"
24+
} // namespace mlir
25+
26+
using namespace mlir;
27+
28+
namespace {
29+
30+
static bool isConvertibleToEmitC(Type type) {
31+
Type baseType = type;
32+
if (auto tensorType = dyn_cast<TensorType>(type)) {
33+
if (!tensorType.hasRank() || !tensorType.hasStaticShape()) {
34+
return false;
35+
}
36+
baseType = tensorType.getElementType();
37+
}
38+
39+
if (isa<IndexType>(baseType)) {
40+
return true;
41+
}
42+
43+
if (auto intType = dyn_cast<IntegerType>(baseType)) {
44+
switch (intType.getWidth()) {
45+
case 1:
46+
case 8:
47+
case 16:
48+
case 32:
49+
case 64:
50+
return true;
51+
}
52+
return false;
53+
}
54+
55+
if (auto floatType = dyn_cast<FloatType>(baseType)) {
56+
return floatType.isF32() || floatType.isF64();
57+
}
58+
59+
return false;
60+
}
61+
62+
class ArithConstantOpConversionPattern
63+
: public OpRewritePattern<arith::ConstantOp> {
64+
public:
65+
using OpRewritePattern::OpRewritePattern;
66+
67+
LogicalResult matchAndRewrite(arith::ConstantOp arithConst,
68+
PatternRewriter &rewriter) const override {
69+
70+
auto constantType = arithConst.getType();
71+
if (!isConvertibleToEmitC(constantType)) {
72+
return rewriter.notifyMatchFailure(arithConst.getLoc(),
73+
"Type cannot be converted to emitc");
74+
}
75+
76+
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, constantType,
77+
arithConst.getValue());
78+
return success();
79+
}
80+
};
81+
82+
struct ConvertArithToEmitCPass
83+
: public impl::ArithToEmitCConversionPassBase<ConvertArithToEmitCPass> {
84+
public:
85+
void runOnOperation() override {
86+
87+
ConversionTarget target(getContext());
88+
target.addIllegalDialect<arith::ArithDialect>();
89+
target.addLegalDialect<emitc::EmitCDialect>();
90+
RewritePatternSet patterns(&getContext());
91+
populateArithToEmitCConversionPatterns(patterns);
92+
93+
if (failed(applyPartialConversion(getOperation(), target,
94+
std::move(patterns)))) {
95+
signalPassFailure();
96+
}
97+
}
98+
};
99+
100+
} // namespace
101+
102+
void mlir::populateArithToEmitCConversionPatterns(RewritePatternSet &patterns) {
103+
patterns.add<ArithConstantOpConversionPattern>(patterns.getContext());
104+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(ArithToEmitC
2+
ArithToEmitC.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIREmitCDialect
15+
MLIRArithDialect
16+
MLIRTransforms
17+
)

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
33
add_subdirectory(ArithCommon)
44
add_subdirectory(ArithToAMDGPU)
55
add_subdirectory(ArithToArmSME)
6+
add_subdirectory(ArithToEmitC)
67
add_subdirectory(ArithToLLVM)
78
add_subdirectory(ArithToSPIRV)
89
add_subdirectory(ArmNeon2dToIntr)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
2+
3+
func.func @arith_constant_complex_tensor() -> (tensor<complex<i32>>) {
4+
// expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
5+
%c = arith.constant dense<(2, 2)> : tensor<complex<i32>>
6+
return %c : tensor<complex<i32>>
7+
}
8+
9+
// -----
10+
11+
func.func @arith_constant_invalid_int_type() -> (i10) {
12+
// expected-error @+1 {{failed to legalize operation 'arith.constant' that was explicitly marked illegal}}
13+
%c = arith.constant 0 : i10
14+
return %c : i10
15+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc %s | FileCheck %s
2+
3+
// CHECK-LABEL: arith_constants
4+
func.func @arith_constants() {
5+
// CHECK: emitc.constant
6+
// CHECK-SAME: value = 0 : index
7+
%c_index = arith.constant 0 : index
8+
// CHECK: emitc.constant
9+
// CHECK-SAME: value = 0 : i32
10+
%c_signless_int_32 = arith.constant 0 : i32
11+
// CHECK: emitc.constant
12+
// CHECK-SAME: value = 0.{{0+}}e+00 : f32
13+
%c_float_32 = arith.constant 0.0 : f32
14+
// CHECK: emitc.constant
15+
// CHECK-SAME: value = dense<0> : tensor<i32>
16+
%c_tensor_single_value = arith.constant dense<0> : tensor<i32>
17+
// CHECK: emitc.constant
18+
// CHECK-SAME: value{{.*}}[1, 2], [-3, 9], [0, 0], [2, -1]{{.*}}tensor<4x2xi64>
19+
%c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64>
20+
return
21+
}

0 commit comments

Comments
 (0)