Skip to content

Commit 5f6e7a5

Browse files
marbremgehre-amd
authored andcommitted
[mlir][EmitC] Add Arith to EmitC conversions (llvm#84151)
This adds patterns and a pass to convert the Arith dialect to EmitC. For now, this covers arithemtic binary ops operating on floating point types. It is not checked within the patterns whether the types, such as the Tensor type, are supported in the respective EmitC operations. If unsupported types should be converted, the conversion will fail anyway because no legal EmitC operation can be created. This can clearly be improved in a follow up, also resulting in better error messages. Functions for such checks should not solely be used in the conversions and should also be (re)used in the verifier.
1 parent add43c2 commit 5f6e7a5

File tree

9 files changed

+163
-113
lines changed

9 files changed

+163
-113
lines changed
Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
1-
//===- ArithToEmitC.h - Convert Arith to EmitC ----------------------------===//
1+
//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
89
#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
910
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
1011

11-
#include "mlir/Pass/Pass.h"
12-
1312
namespace mlir {
1413
class RewritePatternSet;
14+
class TypeConverter;
1515

16-
#define GEN_PASS_DECL_ARITHTOEMITCCONVERSIONPASS
17-
#include "mlir/Conversion/Passes.h.inc"
18-
19-
void populateArithToEmitCConversionPatterns(RewritePatternSet &patterns);
16+
void populateArithToEmitCPatterns(TypeConverter &typeConverter,
17+
RewritePatternSet &patterns);
2018
} // namespace mlir
2119

2220
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
10+
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_CONVERTARITHTOEMITC
18+
#include "mlir/Conversion/Passes.h.inc"
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h"
1313
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
1414
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
15-
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
15+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
1616
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
1717
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
1818
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,17 +125,20 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
125125
}];
126126

127127
let dependentDialects = ["amdgpu::AMDGPUDialect", "vector::VectorDialect"];
128+
129+
let options = [
130+
Option<"saturateFP8Truncf", "saturate-fp8-truncf", "bool",
131+
/*default=*/"false",
132+
"Use saturating truncation for 8-bit float types">,
133+
];
128134
}
129135

130136
//===----------------------------------------------------------------------===//
131137
// ArithToEmitC
132138
//===----------------------------------------------------------------------===//
133139

134-
def ArithToEmitCConversionPass : Pass<"convert-arith-to-emitc"> {
135-
let summary = "Convert Arith ops to EmitC ops";
136-
let description = [{
137-
Convert `arith` operations to operations in the `emitc` dialect.
138-
}];
140+
def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
141+
let summary = "Convert Arith dialect to EmitC dialect";
139142
let dependentDialects = ["emitc::EmitCDialect"];
140143
}
141144

Lines changed: 31 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,60 @@
1-
//===- ArithToEmitC.cpp - Arith to EmitC conversion -----------------------===//
1+
//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file implements a pass to convert arith ops into emitc ops.
9+
// This file implements patterns to convert the Arith dialect to the EmitC
10+
// dialect.
1011
//
1112
//===----------------------------------------------------------------------===//
1213

1314
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
1415

1516
#include "mlir/Dialect/Arith/IR/Arith.h"
1617
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17-
#include "mlir/IR/BuiltinTypes.h"
18-
#include "mlir/Support/LogicalResult.h"
1918
#include "mlir/Transforms/DialectConversion.h"
2019

21-
namespace mlir {
22-
#define GEN_PASS_DEF_ARITHTOEMITCCONVERSIONPASS
23-
#include "mlir/Conversion/Passes.h.inc"
24-
} // namespace mlir
25-
2620
using namespace mlir;
2721

28-
namespace {
29-
30-
static bool isConvertibleToEmitC(Type type) {
31-
Type baseType = type;
32-
if (auto tensorType = dyn_cast<TensorType>(type)) {
33-
if (!tensorType.hasRank() || !tensorType.hasStaticShape()) {
34-
return false;
35-
}
36-
baseType = tensorType.getElementType();
37-
}
38-
39-
if (isa<IndexType>(baseType)) {
40-
return true;
41-
}
42-
43-
if (auto intType = dyn_cast<IntegerType>(baseType)) {
44-
switch (intType.getWidth()) {
45-
case 1:
46-
case 8:
47-
case 16:
48-
case 32:
49-
case 64:
50-
return true;
51-
}
52-
return false;
53-
}
54-
55-
if (auto floatType = dyn_cast<FloatType>(baseType)) {
56-
return floatType.isF32() || floatType.isF64();
57-
}
58-
59-
return false;
60-
}
22+
//===----------------------------------------------------------------------===//
23+
// Conversion Patterns
24+
//===----------------------------------------------------------------------===//
6125

62-
class ArithConstantOpConversionPattern
63-
: public OpRewritePattern<arith::ConstantOp> {
26+
namespace {
27+
template <typename ArithOp, typename EmitCOp>
28+
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
6429
public:
65-
using OpRewritePattern::OpRewritePattern;
30+
using OpConversionPattern<ArithOp>::OpConversionPattern;
6631

67-
LogicalResult matchAndRewrite(arith::ConstantOp arithConst,
68-
PatternRewriter &rewriter) const override {
32+
LogicalResult
33+
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
34+
ConversionPatternRewriter &rewriter) const override {
6935

70-
auto constantType = arithConst.getType();
71-
if (!isConvertibleToEmitC(constantType)) {
72-
return rewriter.notifyMatchFailure(arithConst.getLoc(),
73-
"Type cannot be converted to emitc");
74-
}
36+
rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
37+
adaptor.getOperands());
7538

76-
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, constantType,
77-
arithConst.getValue());
7839
return success();
7940
}
8041
};
81-
82-
struct ConvertArithToEmitCPass
83-
: public impl::ArithToEmitCConversionPassBase<ConvertArithToEmitCPass> {
84-
public:
85-
void runOnOperation() override {
86-
87-
ConversionTarget target(getContext());
88-
target.addIllegalDialect<arith::ArithDialect>();
89-
target.addLegalDialect<emitc::EmitCDialect>();
90-
RewritePatternSet patterns(&getContext());
91-
populateArithToEmitCConversionPatterns(patterns);
92-
93-
if (failed(applyPartialConversion(getOperation(), target,
94-
std::move(patterns)))) {
95-
signalPassFailure();
96-
}
97-
}
98-
};
99-
10042
} // namespace
10143

102-
void mlir::populateArithToEmitCConversionPatterns(RewritePatternSet &patterns) {
103-
patterns.add<ArithConstantOpConversionPattern>(patterns.getContext());
44+
//===----------------------------------------------------------------------===//
45+
// Pattern population
46+
//===----------------------------------------------------------------------===//
47+
48+
void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
49+
RewritePatternSet &patterns) {
50+
MLIRContext *ctx = patterns.getContext();
51+
52+
// clang-format off
53+
patterns.add<
54+
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
55+
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
56+
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
57+
ArithOpConversion<arith::SubFOp, emitc::SubOp>
58+
>(typeConverter, ctx);
59+
// clang-format on
10460
}
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert the Arith dialect to the EmitC
10+
// dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
15+
16+
#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
17+
#include "mlir/Dialect/Arith/IR/Arith.h"
18+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
19+
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Transforms/DialectConversion.h"
21+
22+
namespace mlir {
23+
#define GEN_PASS_DEF_CONVERTARITHTOEMITC
24+
#include "mlir/Conversion/Passes.h.inc"
25+
} // namespace mlir
26+
27+
using namespace mlir;
28+
29+
namespace {
30+
struct ConvertArithToEmitC
31+
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
32+
void runOnOperation() override;
33+
};
34+
} // namespace
35+
36+
void ConvertArithToEmitC::runOnOperation() {
37+
ConversionTarget target(getContext());
38+
39+
target.addLegalDialect<emitc::EmitCDialect>();
40+
target.addIllegalDialect<arith::ArithDialect>();
41+
target.addLegalOp<arith::ConstantOp>();
42+
43+
RewritePatternSet patterns(&getContext());
44+
45+
TypeConverter typeConverter;
46+
typeConverter.addConversion([](Type type) { return type; });
47+
48+
populateArithToEmitCPatterns(typeConverter, patterns);
49+
50+
if (failed(
51+
applyPartialConversion(getOperation(), target, std::move(patterns))))
52+
signalPassFailure();
53+
}
Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
add_mlir_conversion_library(ArithToEmitC
1+
add_mlir_conversion_library(MLIRArithToEmitC
22
ArithToEmitC.cpp
3+
ArithToEmitCPass.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC
67

78
DEPENDS
89
MLIRConversionPassIncGen
910

10-
LINK_COMPONENTS
11-
Core
12-
1311
LINK_LIBS PUBLIC
14-
MLIREmitCDialect
1512
MLIRArithDialect
16-
MLIRTransforms
17-
)
13+
MLIREmitCDialect
14+
MLIRPass
15+
MLIRTransformUtils
16+
)
Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,14 @@
1-
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc %s | FileCheck %s
1+
// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s
2+
3+
func.func @arith_ops(%arg0: f32, %arg1: f32) {
4+
// CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32
5+
%0 = arith.addf %arg0, %arg1 : f32
6+
// CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32
7+
%1 = arith.divf %arg0, %arg1 : f32
8+
// CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32
9+
%2 = arith.mulf %arg0, %arg1 : f32
10+
// CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32
11+
%3 = arith.subf %arg0, %arg1 : f32
212

3-
// CHECK-LABEL: arith_constants
4-
func.func @arith_constants() {
5-
// CHECK: emitc.constant
6-
// CHECK-SAME: value = 0 : index
7-
%c_index = arith.constant 0 : index
8-
// CHECK: emitc.constant
9-
// CHECK-SAME: value = 0 : i32
10-
%c_signless_int_32 = arith.constant 0 : i32
11-
// CHECK: emitc.constant
12-
// CHECK-SAME: value = 0.{{0+}}e+00 : f32
13-
%c_float_32 = arith.constant 0.0 : f32
14-
// CHECK: emitc.constant
15-
// CHECK-SAME: value = dense<0> : tensor<i32>
16-
%c_tensor_single_value = arith.constant dense<0> : tensor<i32>
17-
// CHECK: emitc.constant
18-
// CHECK-SAME: value{{.*}}[1, 2], [-3, 9], [0, 0], [2, -1]{{.*}}tensor<4x2xi64>
19-
%c_tensor_value = arith.constant dense<[[1, 2], [-3, 9], [0, 0], [2, -1]]> : tensor<4x2xi64>
2013
return
2114
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3841,6 +3841,7 @@ cc_library(
38413841
":AMDGPUToROCDL",
38423842
":AffineToStandard",
38433843
":ArithToAMDGPU",
3844+
":ArithToEmitC",
38443845
":ArithToLLVM",
38453846
":ArithToSPIRV",
38463847
":ArmNeon2dToIntr",
@@ -7967,6 +7968,32 @@ cc_library(
79677968
],
79687969
)
79697970

7971+
cc_library(
7972+
name = "ArithToEmitC",
7973+
srcs = glob([
7974+
"lib/Conversion/ArithToEmitC/*.cpp",
7975+
"lib/Conversion/ArithToEmitC/*.h",
7976+
]),
7977+
hdrs = glob([
7978+
"include/mlir/Conversion/ArithToEmitC/*.h",
7979+
]),
7980+
includes = [
7981+
"include",
7982+
"lib/Conversion/ArithToEmitC",
7983+
],
7984+
deps = [
7985+
":ArithDialect",
7986+
":ConversionPassIncGen",
7987+
":EmitCDialect",
7988+
":IR",
7989+
":Pass",
7990+
":Support",
7991+
":TransformUtils",
7992+
":Transforms",
7993+
"//llvm:Support",
7994+
],
7995+
)
7996+
79707997
cc_library(
79717998
name = "ArithToLLVM",
79727999
srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),

0 commit comments

Comments
 (0)