Skip to content

Commit 5d7f3e9

Browse files
committed
[mlir][emitc] Add func to emitc conversion
This adds patterns and a pass to convert the Func dialect to EmitC. A `func.func` op that is `private` is converted to `emitc.func` with a `"static"` specifier.
1 parent a03e89c commit 5d7f3e9

File tree

9 files changed

+292
-0
lines changed

9 files changed

+292
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===- FuncToEmitC.h - Func to EmitC Patterns -------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H
10+
#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H
11+
12+
namespace mlir {
13+
class RewritePatternSet;
14+
15+
void populateFuncToEmitCPatterns(RewritePatternSet &patterns);
16+
} // namespace mlir
17+
18+
#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITC_H
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===- FuncToEmitCPass.h - Func to EmitC Pass -------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H
10+
#define MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
17+
#define GEN_PASS_DECL_FUNCTOEMITC
18+
#include "mlir/Conversion/Passes.h.inc"
19+
20+
std::unique_ptr<Pass> createConvertFuncToEmitC();
21+
22+
} // namespace mlir
23+
24+
#endif // MLIR_CONVERSION_FUNCTOEMITC_FUNCTOEMITCPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRV.h"
2929
#include "mlir/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.h"
3030
#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
31+
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
3132
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
3233
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
3334
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,15 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> {
344344
];
345345
}
346346

347+
//===----------------------------------------------------------------------===//
348+
// FuncToEmitC
349+
//===----------------------------------------------------------------------===//
350+
351+
def ConvertFuncToEmitC : Pass<"convert-func-to-emitc", "ModuleOp"> {
352+
let summary = "Convert Func dialect to EmitC dialect";
353+
let dependentDialects = ["emitc::EmitCDialect"];
354+
}
355+
347356
//===----------------------------------------------------------------------===//
348357
// FuncToLLVM
349358
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_subdirectory(ControlFlowToLLVM)
1717
add_subdirectory(ControlFlowToSCF)
1818
add_subdirectory(ControlFlowToSPIRV)
1919
add_subdirectory(ConvertToLLVM)
20+
add_subdirectory(FuncToEmitC)
2021
add_subdirectory(FuncToLLVM)
2122
add_subdirectory(FuncToSPIRV)
2223
add_subdirectory(GPUCommon)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
add_mlir_conversion_library(MLIRFuncToEmitC
2+
FuncToEmitC.cpp
3+
FuncToEmitCPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/FuncToEmitC
7+
8+
DEPENDS
9+
MLIRConversionPassIncGen
10+
11+
LINK_LIBS PUBLIC
12+
MLIREmitCDialect
13+
MLIRFuncDialect
14+
MLIRPass
15+
MLIRTransformUtils
16+
)
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
//===- FuncToEmitC.cpp - Func to EmitC Patterns -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns to convert the Func dialect to the EmitC
10+
// dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
15+
16+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
20+
using namespace mlir;
21+
22+
//===----------------------------------------------------------------------===//
23+
// Conversion Patterns
24+
//===----------------------------------------------------------------------===//
25+
26+
namespace {
27+
class CallOpConversion final : public OpConversionPattern<func::CallOp> {
28+
public:
29+
using OpConversionPattern<func::CallOp>::OpConversionPattern;
30+
31+
LogicalResult
32+
matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
33+
ConversionPatternRewriter &rewriter) const override {
34+
// multiple results func was not converted to spirv.func
35+
if (callOp.getNumResults() > 1)
36+
return rewriter.notifyMatchFailure(
37+
callOp, "Only functions with zero or one result can be converted");
38+
39+
rewriter.replaceOpWithNewOp<emitc::CallOp>(
40+
callOp,
41+
callOp.getNumResults() ? callOp.getResult(0).getType() : nullptr,
42+
adaptor.getOperands(), callOp->getAttrs());
43+
44+
return success();
45+
}
46+
};
47+
48+
class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
49+
public:
50+
using OpConversionPattern<func::FuncOp>::OpConversionPattern;
51+
52+
LogicalResult
53+
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
54+
ConversionPatternRewriter &rewriter) const override {
55+
56+
if (funcOp.getFunctionType().getNumResults() > 1)
57+
return rewriter.notifyMatchFailure(
58+
funcOp, "Only functions with zero or one result can be converted");
59+
60+
if (funcOp.isDeclaration())
61+
return rewriter.notifyMatchFailure(funcOp,
62+
"Declarations cannot be converted");
63+
64+
// Create the converted emitc.func op.
65+
emitc::FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
66+
funcOp.getLoc(), funcOp.getName(), funcOp.getFunctionType());
67+
68+
// Copy over all attributes other than the function name and type.
69+
for (const auto &namedAttr : funcOp->getAttrs()) {
70+
if (namedAttr.getName() != funcOp.getFunctionTypeAttrName() &&
71+
namedAttr.getName() != SymbolTable::getSymbolAttrName())
72+
newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
73+
}
74+
75+
// Create add `static` to specifiers if `func.func` is private.
76+
if (funcOp.isPrivate()) {
77+
StringAttr specifier = rewriter.getStringAttr("static");
78+
ArrayAttr specifiers = rewriter.getArrayAttr(specifier);
79+
newFuncOp.setSpecifiersAttr(specifiers);
80+
}
81+
82+
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
83+
newFuncOp.end());
84+
rewriter.eraseOp(funcOp);
85+
86+
return success();
87+
}
88+
};
89+
90+
class ReturnOpConversion final : public OpConversionPattern<func::ReturnOp> {
91+
public:
92+
using OpConversionPattern<func::ReturnOp>::OpConversionPattern;
93+
94+
LogicalResult
95+
matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
96+
ConversionPatternRewriter &rewriter) const override {
97+
if (returnOp.getNumOperands() > 1)
98+
return rewriter.notifyMatchFailure(
99+
returnOp, "Only zero or one operand is supported");
100+
101+
rewriter.replaceOpWithNewOp<emitc::ReturnOp>(
102+
returnOp,
103+
returnOp.getNumOperands() ? adaptor.getOperands()[0] : nullptr);
104+
return success();
105+
}
106+
};
107+
} // namespace
108+
109+
//===----------------------------------------------------------------------===//
110+
// Pattern population
111+
//===----------------------------------------------------------------------===//
112+
113+
void mlir::populateFuncToEmitCPatterns(RewritePatternSet &patterns) {
114+
MLIRContext *ctx = patterns.getContext();
115+
116+
patterns.add<CallOpConversion>(ctx);
117+
patterns.add<FuncOpConversion>(ctx);
118+
patterns.add<ReturnOpConversion>(ctx);
119+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===- FuncToEmitC.cpp - Func to EmitC Pass ---------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert the Func dialect to the EmitC dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/FuncToEmitC/FuncToEmitCPass.h"
14+
15+
#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
16+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/Dialect/Func/IR/FuncOps.h"
18+
#include "mlir/Pass/Pass.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
21+
namespace mlir {
22+
#define GEN_PASS_DEF_CONVERTFUNCTOEMITC
23+
#include "mlir/Conversion/Passes.h.inc"
24+
} // namespace mlir
25+
26+
using namespace mlir;
27+
28+
namespace {
29+
struct ConvertFuncToEmitC
30+
: public impl::ConvertFuncToEmitCBase<ConvertFuncToEmitC> {
31+
void runOnOperation() override;
32+
};
33+
} // namespace
34+
35+
void ConvertFuncToEmitC::runOnOperation() {
36+
ConversionTarget target(getContext());
37+
38+
target.addLegalDialect<emitc::EmitCDialect>();
39+
target.addIllegalOp<func::CallOp>();
40+
target.addIllegalOp<func::FuncOp>();
41+
target.addIllegalOp<func::ReturnOp>();
42+
43+
RewritePatternSet patterns(&getContext());
44+
populateFuncToEmitCPatterns(patterns);
45+
46+
if (failed(
47+
applyPartialConversion(getOperation(), target, std::move(patterns))))
48+
signalPassFailure();
49+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: mlir-opt -split-input-file -convert-func-to-emitc %s | FileCheck %s
2+
3+
// CHECK-LABEL: emitc.func @foo()
4+
// CHECK-NEXT: emitc.return
5+
func.func @foo() {
6+
return
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: emitc.func private @foo() attributes {specifiers = ["static"]}
12+
// CHECK-NEXT: emitc.return
13+
func.func private @foo() {
14+
return
15+
}
16+
17+
// -----
18+
19+
// CHECK-LABEL: emitc.func @foo(%arg0: i32)
20+
func.func @foo(%arg0: i32) {
21+
emitc.call_opaque "bar"(%arg0) : (i32) -> ()
22+
return
23+
}
24+
25+
// -----
26+
27+
// CHECK-LABEL: emitc.func @foo(%arg0: i32) -> i32
28+
// CHECK-NEXT: emitc.return %arg0 : i32
29+
func.func @foo(%arg0: i32) -> i32 {
30+
return %arg0 : i32
31+
}
32+
33+
// -----
34+
35+
// CHECK-LABEL: emitc.func @foo(%arg0: i32, %arg1: i32) -> i32
36+
func.func @foo(%arg0: i32, %arg1: i32) -> i32 {
37+
%0 = "emitc.add" (%arg0, %arg1) : (i32, i32) -> i32
38+
return %0 : i32
39+
}
40+
41+
// -----
42+
43+
// CHECK-LABEL: emitc.func private @return_i32(%arg0: i32) -> i32 attributes {specifiers = ["static"]}
44+
// CHECK-NEXT: emitc.return %arg0 : i32
45+
func.func private @return_i32(%arg0: i32) -> i32 {
46+
return %arg0 : i32
47+
}
48+
49+
// CHECK-LABEL: emitc.func @call(%arg0: i32) -> i32
50+
// CHECK-NEXT: %0 = emitc.call @return_i32(%arg0) : (i32) -> i32
51+
// CHECK-NEXT: emitc.return %0 : i32
52+
func.func @call(%arg0: i32) -> i32 {
53+
%0 = call @return_i32(%arg0) : (i32) -> (i32)
54+
return %0 : i32
55+
}

0 commit comments

Comments
 (0)