Skip to content

[mlir][EmitC] Add Arith to EmitC conversions #84151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- ArithToEmitC.h - Arith to EmitC Patterns -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H

namespace mlir {
class RewritePatternSet;
class TypeConverter;

void populateArithToEmitCPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns);
} // namespace mlir

#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H
21 changes: 21 additions & 0 deletions mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- ArithToEmitCPass.h - Arith to EmitC Pass -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
#define MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTARITHTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITCPASS_H
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"
Expand Down
9 changes: 9 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
];
}

//===----------------------------------------------------------------------===//
// ArithToEmitC
//===----------------------------------------------------------------------===//

def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
let summary = "Convert Arith dialect to EmitC dialect";
let dependentDialects = ["emitc::EmitCDialect"];
}

//===----------------------------------------------------------------------===//
// ArithToLLVM
//===----------------------------------------------------------------------===//
Expand Down
60 changes: 60 additions & 0 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert the Arith dialect to the EmitC
// dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Conversion Patterns
//===----------------------------------------------------------------------===//

namespace {
template <typename ArithOp, typename EmitCOp>
class ArithOpConversion final : public OpConversionPattern<ArithOp> {
public:
using OpConversionPattern<ArithOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
adaptor.getOperands());

return success();
}
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//

void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *ctx = patterns.getContext();

// clang-format off
patterns.add<
ArithOpConversion<arith::AddFOp, emitc::AddOp>,
ArithOpConversion<arith::DivFOp, emitc::DivOp>,
ArithOpConversion<arith::MulFOp, emitc::MulOp>,
ArithOpConversion<arith::SubFOp, emitc::SubOp>
>(typeConverter, ctx);
// clang-format on
}
53 changes: 53 additions & 0 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//===- ArithToEmitCPass.cpp - Arith to EmitC Pass ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the Arith dialect to the EmitC
// dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h"

#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTARITHTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
struct ConvertArithToEmitC
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
void runOnOperation() override;
};
} // namespace

void ConvertArithToEmitC::runOnOperation() {
ConversionTarget target(getContext());

target.addLegalDialect<emitc::EmitCDialect>();
target.addIllegalDialect<arith::ArithDialect>();
target.addLegalOp<arith::ConstantOp>();

RewritePatternSet patterns(&getContext());

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });

populateArithToEmitCPatterns(typeConverter, patterns);

if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
signalPassFailure();
}
16 changes: 16 additions & 0 deletions mlir/lib/Conversion/ArithToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
add_mlir_conversion_library(MLIRArithToEmitC
ArithToEmitC.cpp
ArithToEmitCPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArithToEmitC

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
MLIREmitCDialect
MLIRPass
MLIRTransformUtils
)
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_subdirectory(AMDGPUToROCDL)
add_subdirectory(ArithCommon)
add_subdirectory(ArithToAMDGPU)
add_subdirectory(ArithToArmSME)
add_subdirectory(ArithToEmitC)
add_subdirectory(ArithToLLVM)
add_subdirectory(ArithToSPIRV)
add_subdirectory(ArmNeon2dToIntr)
Expand Down
14 changes: 14 additions & 0 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-opt -convert-arith-to-emitc %s | FileCheck %s

func.func @arith_ops(%arg0: f32, %arg1: f32) {
// CHECK: [[V0:[^ ]*]] = emitc.add %arg0, %arg1 : (f32, f32) -> f32
%0 = arith.addf %arg0, %arg1 : f32
// CHECK: [[V1:[^ ]*]] = emitc.div %arg0, %arg1 : (f32, f32) -> f32
%1 = arith.divf %arg0, %arg1 : f32
// CHECK: [[V2:[^ ]*]] = emitc.mul %arg0, %arg1 : (f32, f32) -> f32
%2 = arith.mulf %arg0, %arg1 : f32
// CHECK: [[V3:[^ ]*]] = emitc.sub %arg0, %arg1 : (f32, f32) -> f32
%3 = arith.subf %arg0, %arg1 : f32

return
}
27 changes: 27 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -4011,6 +4011,7 @@ cc_library(
":AffineToStandard",
":ArithToAMDGPU",
":ArithToArmSME",
":ArithToEmitC",
":ArithToLLVM",
":ArithToSPIRV",
":ArmNeon2dToIntr",
Expand Down Expand Up @@ -8156,6 +8157,32 @@ cc_library(
],
)

cc_library(
name = "ArithToEmitC",
srcs = glob([
"lib/Conversion/ArithToEmitC/*.cpp",
"lib/Conversion/ArithToEmitC/*.h",
]),
hdrs = glob([
"include/mlir/Conversion/ArithToEmitC/*.h",
]),
includes = [
"include",
"lib/Conversion/ArithToEmitC",
],
deps = [
":ArithDialect",
":ConversionPassIncGen",
":EmitCDialect",
":IR",
":Pass",
":Support",
":TransformUtils",
":Transforms",
"//llvm:Support",
],
)

cc_library(
name = "ArithToLLVM",
srcs = glob(["lib/Conversion/ArithToLLVM/*.cpp"]),
Expand Down