Skip to content

Commit 88cdc80

Browse files
committed
[MLIR] Add initial convert-memref-to-emitc pass
This translates memref types in func.func, func.call and func.return to emitc.array and it translates memref.alloca, memref.load & memref.store to emitc.variable, emitc.subscipt and emitc.assign.
1 parent 01a31ce commit 88cdc80

File tree

9 files changed

+338
-0
lines changed

9 files changed

+338
-0
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
9+
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
10+
11+
#include "mlir/Pass/Pass.h"
12+
13+
namespace mlir {
14+
class RewritePatternSet;
15+
class TypeConverter;
16+
17+
#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
18+
#include "mlir/Conversion/Passes.h.inc"
19+
20+
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
21+
22+
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
23+
TypeConverter &converter);
24+
25+
std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();
26+
27+
} // namespace mlir
28+
29+
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4646
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
4747
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
48+
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
4849
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
4950
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
5051
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
753753
];
754754
}
755755

756+
//===----------------------------------------------------------------------===//
757+
// MemRefToEmitC
758+
//===----------------------------------------------------------------------===//
759+
760+
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
761+
let summary = "Convert MemRef dialect to EmitC dialect";
762+
let dependentDialects = ["emitc::EmitCDialect"];
763+
}
764+
756765
//===----------------------------------------------------------------------===//
757766
// MemRefToLLVM
758767
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
3535
add_subdirectory(MathToLibm)
3636
add_subdirectory(MathToLLVM)
3737
add_subdirectory(MathToSPIRV)
38+
add_subdirectory(MemRefToEmitC)
3839
add_subdirectory(MemRefToLLVM)
3940
add_subdirectory(MemRefToSPIRV)
4041
add_subdirectory(NVGPUToNVVM)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRMemRefToEmitC
2+
MemRefToEmitC.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIREmitCDialect
15+
MLIRFuncDialect
16+
MLIRFuncTransforms
17+
MLIRMemRefDialect
18+
MLIRTransforms
19+
)
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert memref ops into emitc ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
14+
15+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
16+
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
18+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19+
#include "mlir/IR/Builders.h"
20+
#include "mlir/IR/PatternMatch.h"
21+
#include "mlir/Transforms/DialectConversion.h"
22+
23+
namespace mlir {
24+
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
25+
#include "mlir/Conversion/Passes.h.inc"
26+
} // namespace mlir
27+
28+
using namespace mlir;
29+
30+
namespace {
31+
32+
/// Disallow all memrefs even though we only have conversions
33+
/// for memrefs with static shape right now to have good diagnostics.
34+
bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
35+
36+
template <typename RangeT>
37+
bool areLegal(RangeT &&range) {
38+
return llvm::all_of(range, [](Type type) { return isLegal(type); });
39+
}
40+
41+
bool isLegal(Operation *op) {
42+
return areLegal(op->getOperandTypes()) && areLegal(op->getResultTypes());
43+
}
44+
45+
bool isSignatureLegal(FunctionType ty) {
46+
return areLegal(ty.getInputs()) && areLegal(ty.getResults());
47+
}
48+
49+
struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
50+
using OpConversionPattern::OpConversionPattern;
51+
52+
LogicalResult
53+
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
54+
ConversionPatternRewriter &rewriter) const override {
55+
56+
if (!op.getType().hasStaticShape()) {
57+
return rewriter.notifyMatchFailure(
58+
op.getLoc(), "cannot transform alloca with dynamic shape");
59+
}
60+
61+
if (op.getAlignment().value_or(1) > 1) {
62+
// TODO: Allow alignment if it is not more than the natural alignment
63+
// of the C array.
64+
return rewriter.notifyMatchFailure(
65+
op.getLoc(), "cannot transform alloca with alignment requirement");
66+
}
67+
68+
auto resultTy = getTypeConverter()->convertType(op.getType());
69+
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
70+
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
71+
return success();
72+
}
73+
};
74+
75+
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
76+
using OpConversionPattern::OpConversionPattern;
77+
78+
LogicalResult
79+
matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
80+
ConversionPatternRewriter &rewriter) const override {
81+
82+
rewriter.replaceOpWithNewOp<emitc::SubscriptOp>(op, operands.getMemref(),
83+
operands.getIndices());
84+
return success();
85+
}
86+
};
87+
88+
struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
89+
using OpConversionPattern::OpConversionPattern;
90+
91+
LogicalResult
92+
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
93+
ConversionPatternRewriter &rewriter) const override {
94+
95+
auto subscript = rewriter.create<emitc::SubscriptOp>(
96+
op.getLoc(), operands.getMemref(), operands.getIndices());
97+
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
98+
operands.getValue());
99+
return success();
100+
}
101+
};
102+
103+
struct ConvertMemRefToEmitCPass
104+
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
105+
void runOnOperation() override {
106+
TypeConverter converter;
107+
// Fallback for other types.
108+
converter.addConversion([](Type type) { return type; });
109+
populateMemRefToEmitCTypeConversion(converter);
110+
converter.addConversion(
111+
[&converter](FunctionType ty) -> std::optional<Type> {
112+
SmallVector<Type> inputs;
113+
if (failed(converter.convertTypes(ty.getInputs(), inputs)))
114+
return std::nullopt;
115+
116+
SmallVector<Type> results;
117+
if (failed(converter.convertTypes(ty.getResults(), results)))
118+
return std::nullopt;
119+
120+
return FunctionType::get(ty.getContext(), inputs, results);
121+
});
122+
123+
RewritePatternSet patterns(&getContext());
124+
populateMemRefToEmitCConversionPatterns(patterns, converter);
125+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
126+
converter);
127+
populateCallOpTypeConversionPattern(patterns, converter);
128+
populateReturnOpTypeConversionPattern(patterns, converter);
129+
130+
ConversionTarget target(getContext());
131+
target.addDynamicallyLegalOp<func::FuncOp>(
132+
[](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
133+
target.addDynamicallyLegalDialect<func::FuncDialect>(
134+
[](Operation *op) { return isLegal(op); });
135+
target.addIllegalDialect<memref::MemRefDialect>();
136+
target.addLegalDialect<emitc::EmitCDialect>();
137+
138+
if (failed(applyPartialConversion(getOperation(), target,
139+
std::move(patterns))))
140+
return signalPassFailure();
141+
}
142+
};
143+
} // namespace
144+
145+
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
146+
typeConverter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
147+
if (memRefType.hasStaticShape()) {
148+
return emitc::ArrayType::get(memRefType.getShape(),
149+
memRefType.getElementType());
150+
}
151+
return {};
152+
});
153+
}
154+
155+
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
156+
TypeConverter &converter) {
157+
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
158+
patterns.getContext());
159+
}
160+
161+
std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
162+
return std::make_unique<ConvertMemRefToEmitCPass>();
163+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
2+
3+
// Unranked memrefs are not converted
4+
// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
5+
func.func @memref_unranked(%arg0 : memref<*xf32>) {
6+
return
7+
}
8+
9+
// -----
10+
11+
// Memrefs with dynamic shapes are not converted
12+
// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
13+
func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) {
14+
return
15+
}
16+
17+
// -----
18+
19+
func.func @memref_op(%arg0 : memref<2x4xf32>) {
20+
// expected-error@+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
21+
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
22+
return
23+
}
24+
25+
// -----
26+
27+
func.func @alloca_with_dynamic_shape() {
28+
%0 = index.constant 1
29+
// expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
30+
%1 = memref.alloca(%0) : memref<4x?xf32>
31+
return
32+
}
33+
34+
// -----
35+
36+
func.func @alloca_with_alignment() {
37+
// expected-error@+1 {{failed to legalize operation 'memref.alloca' that was explicitly marked illegal}}
38+
%1 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
39+
return
40+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: memref_arg
4+
// CHECK-SAME: !emitc.array<32xf32>)
5+
func.func @memref_arg(%arg0 : memref<32xf32>) {
6+
func.return
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: memref_return
12+
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32>
13+
func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> {
14+
// CHECK: return %[[arg0]] : !emitc.array<32xf32>
15+
func.return %arg0 : memref<32xf32>
16+
}
17+
18+
// CHECK-LABEL: memref_call
19+
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>)
20+
func.func @memref_call(%arg0 : memref<32xf32>) {
21+
// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32>
22+
func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
23+
func.return
24+
}
25+
26+
// -----
27+
28+
// CHECK-LABEL: alloca
29+
func.func @alloca() {
30+
// CHECK "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
31+
%0 = memref.alloca() : memref<4x8xf32>
32+
return
33+
}
34+
35+
// -----
36+
37+
// CHECK-LABEL: memref_load_store
38+
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<4x8xf32>, %[[arg1:.*]]: !emitc.array<3x5xf32>
39+
// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
40+
func.func @memref_load_store(%in: memref<4x8xf32>, %out: memref<3x5xf32>, %i: index, %j: index) {
41+
// CHECK: %[[load:.*]] = emitc.subscript %[[arg0]][%[[i]], %[[j]]] : <4x8xf32>
42+
%0 = memref.load %in[%i, %j] : memref<4x8xf32>
43+
// CHECK: %[[store_loc:.*]] = emitc.subscript %[[arg1]][%[[i]], %[[j]]] : <3x5xf32>
44+
// CHECK: emitc.assign %[[load]] : f32 to %[[store_loc:.*]] : f32
45+
memref.store %0, %out[%i, %j] : memref<3x5xf32>
46+
return
47+
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4168,6 +4168,7 @@ cc_library(
41684168
":MathToLLVM",
41694169
":MathToLibm",
41704170
":MathToSPIRV",
4171+
":MemRefToEmitC",
41714172
":MemRefToLLVM",
41724173
":MemRefToSPIRV",
41734174
":NVGPUToNVVM",
@@ -8180,6 +8181,34 @@ cc_library(
81808181
],
81818182
)
81828183

8184+
cc_library(
8185+
name = "MemRefToEmitC",
8186+
srcs = glob([
8187+
"lib/Conversion/MemRefToEmitC/*.cpp",
8188+
"lib/Conversion/MemRefToEmitC/*.h",
8189+
]),
8190+
hdrs = glob([
8191+
"include/mlir/Conversion/MemRefToEmitC/*.h",
8192+
]),
8193+
includes = [
8194+
"include",
8195+
"lib/Conversion/MemRefToEmitC",
8196+
],
8197+
deps = [
8198+
":ConversionPassIncGen",
8199+
":EmitCDialect",
8200+
":FuncDialect",
8201+
":FuncTransforms",
8202+
":MemRefDialect",
8203+
":IR",
8204+
":Pass",
8205+
":Support",
8206+
":TransformUtils",
8207+
":Transforms",
8208+
"//llvm:Support",
8209+
],
8210+
)
8211+
81838212
cc_library(
81848213
name = "MemRefToLLVM",
81858214
srcs = glob(["lib/Conversion/MemRefToLLVM/*.cpp"]),

0 commit comments

Comments
 (0)