Skip to content

Commit 0aa6d57

Browse files
authored
[MLIR] Add initial convert-memref-to-emitc pass (#85389)
This converts `memref.alloca`, `memref.load` & `memref.store` to `emitc.variable`, `emitc.subscript` and `emitc.assign`.
1 parent 538257b commit 0aa6d57

File tree

11 files changed

+334
-0
lines changed

11 files changed

+334
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
9+
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
10+
11+
namespace mlir {
12+
class RewritePatternSet;
13+
class TypeConverter;
14+
15+
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
16+
17+
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
18+
TypeConverter &converter);
19+
} // namespace mlir
20+
21+
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- MemRefToEmitCPass.h - A Pass to convert MemRef to EmitC ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
9+
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
10+
11+
#include <memory>
12+
13+
namespace mlir {
14+
class Pass;
15+
16+
#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
17+
#include "mlir/Conversion/Passes.h.inc"
18+
} // namespace mlir
19+
20+
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4646
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
4747
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
48+
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
4849
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
4950
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
5051
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
753753
];
754754
}
755755

756+
//===----------------------------------------------------------------------===//
757+
// MemRefToEmitC
758+
//===----------------------------------------------------------------------===//
759+
760+
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
761+
let summary = "Convert MemRef dialect to EmitC dialect";
762+
let dependentDialects = ["emitc::EmitCDialect"];
763+
}
764+
756765
//===----------------------------------------------------------------------===//
757766
// MemRefToLLVM
758767
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ add_subdirectory(MathToFuncs)
3535
add_subdirectory(MathToLibm)
3636
add_subdirectory(MathToLLVM)
3737
add_subdirectory(MathToSPIRV)
38+
add_subdirectory(MemRefToEmitC)
3839
add_subdirectory(MemRefToLLVM)
3940
add_subdirectory(MemRefToSPIRV)
4041
add_subdirectory(NVGPUToNVVM)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_mlir_conversion_library(MLIRMemRefToEmitC
2+
MemRefToEmitC.cpp
3+
MemRefToEmitCPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
7+
8+
DEPENDS
9+
MLIRConversionPassIncGen
10+
11+
LINK_COMPONENTS
12+
Core
13+
14+
LINK_LIBS PUBLIC
15+
MLIREmitCDialect
16+
MLIRMemRefDialect
17+
MLIRTransforms
18+
)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements patterns to convert memref ops into emitc ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
14+
15+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
16+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
17+
#include "mlir/IR/Builders.h"
18+
#include "mlir/IR/PatternMatch.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
21+
using namespace mlir;
22+
23+
namespace {
24+
struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
25+
using OpConversionPattern::OpConversionPattern;
26+
27+
LogicalResult
28+
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
29+
ConversionPatternRewriter &rewriter) const override {
30+
31+
if (!op.getType().hasStaticShape()) {
32+
return rewriter.notifyMatchFailure(
33+
op.getLoc(), "cannot transform alloca with dynamic shape");
34+
}
35+
36+
if (op.getAlignment().value_or(1) > 1) {
37+
// TODO: Allow alignment if it is not more than the natural alignment
38+
// of the C array.
39+
return rewriter.notifyMatchFailure(
40+
op.getLoc(), "cannot transform alloca with alignment requirement");
41+
}
42+
43+
auto resultTy = getTypeConverter()->convertType(op.getType());
44+
if (!resultTy) {
45+
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
46+
}
47+
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
48+
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
49+
return success();
50+
}
51+
};
52+
53+
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
54+
using OpConversionPattern::OpConversionPattern;
55+
56+
LogicalResult
57+
matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
58+
ConversionPatternRewriter &rewriter) const override {
59+
60+
auto resultTy = getTypeConverter()->convertType(op.getType());
61+
if (!resultTy) {
62+
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
63+
}
64+
65+
auto subscript = rewriter.create<emitc::SubscriptOp>(
66+
op.getLoc(), operands.getMemref(), operands.getIndices());
67+
68+
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
69+
auto var =
70+
rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
71+
72+
rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
73+
rewriter.replaceOp(op, var);
74+
return success();
75+
}
76+
};
77+
78+
struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
79+
using OpConversionPattern::OpConversionPattern;
80+
81+
LogicalResult
82+
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
83+
ConversionPatternRewriter &rewriter) const override {
84+
85+
auto subscript = rewriter.create<emitc::SubscriptOp>(
86+
op.getLoc(), operands.getMemref(), operands.getIndices());
87+
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
88+
operands.getValue());
89+
return success();
90+
}
91+
};
92+
} // namespace
93+
94+
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
95+
typeConverter.addConversion(
96+
[&](MemRefType memRefType) -> std::optional<Type> {
97+
if (!memRefType.hasStaticShape() ||
98+
!memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
99+
return {};
100+
}
101+
Type convertedElementType =
102+
typeConverter.convertType(memRefType.getElementType());
103+
if (!convertedElementType)
104+
return {};
105+
return emitc::ArrayType::get(memRefType.getShape(),
106+
convertedElementType);
107+
});
108+
}
109+
110+
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
111+
TypeConverter &converter) {
112+
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
113+
patterns.getContext());
114+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert memref ops into emitc ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
14+
15+
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
16+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
17+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
18+
#include "mlir/Pass/Pass.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
21+
namespace mlir {
22+
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
23+
#include "mlir/Conversion/Passes.h.inc"
24+
} // namespace mlir
25+
26+
using namespace mlir;
27+
28+
namespace {
29+
struct ConvertMemRefToEmitCPass
30+
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
31+
void runOnOperation() override {
32+
TypeConverter converter;
33+
34+
// Fallback for other types.
35+
converter.addConversion([](Type type) -> std::optional<Type> {
36+
if (isa<MemRefType>(type))
37+
return {};
38+
return type;
39+
});
40+
41+
populateMemRefToEmitCTypeConversion(converter);
42+
43+
RewritePatternSet patterns(&getContext());
44+
populateMemRefToEmitCConversionPatterns(patterns, converter);
45+
46+
ConversionTarget target(getContext());
47+
target.addIllegalDialect<memref::MemRefDialect>();
48+
target.addLegalDialect<emitc::EmitCDialect>();
49+
50+
if (failed(applyPartialConversion(getOperation(), target,
51+
std::move(patterns))))
52+
return signalPassFailure();
53+
}
54+
};
55+
} // namespace
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
2+
3+
func.func @memref_op(%arg0 : memref<2x4xf32>) {
4+
// expected-error@+1 {{failed to legalize operation 'memref.copy'}}
5+
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
6+
return
7+
}
8+
9+
// -----
10+
11+
func.func @alloca_with_dynamic_shape() {
12+
%0 = index.constant 1
13+
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
14+
%1 = memref.alloca(%0) : memref<4x?xf32>
15+
return
16+
}
17+
18+
// -----
19+
20+
func.func @alloca_with_alignment() {
21+
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
22+
%0 = memref.alloca() {alignment = 64 : i64}: memref<4xf32>
23+
return
24+
}
25+
26+
// -----
27+
28+
func.func @non_identity_layout() {
29+
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
30+
%0 = memref.alloca() : memref<4x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
31+
return
32+
}
33+
34+
// -----
35+
36+
func.func @zero_rank() {
37+
// expected-error@+1 {{failed to legalize operation 'memref.alloca'}}
38+
%0 = memref.alloca() : memref<f32>
39+
return
40+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: memref_store
4+
// CHECK-SAME: %[[v:.*]]: f32, %[[i:.*]]: index, %[[j:.*]]: index
5+
func.func @memref_store(%v : f32, %i: index, %j: index) {
6+
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
7+
%0 = memref.alloca() : memref<4x8xf32>
8+
9+
// CHECK: %[[SUBSCRIPT:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
10+
// CHECK: emitc.assign %[[v]] : f32 to %[[SUBSCRIPT:.*]] : f32
11+
memref.store %v, %0[%i, %j] : memref<4x8xf32>
12+
return
13+
}
14+
// -----
15+
16+
// CHECK-LABEL: memref_load
17+
// CHECK-SAME: %[[i:.*]]: index, %[[j:.*]]: index
18+
func.func @memref_load(%i: index, %j: index) -> f32 {
19+
// CHECK: %[[ALLOCA:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.array<4x8xf32>
20+
%0 = memref.alloca() : memref<4x8xf32>
21+
22+
// CHECK: %[[LOAD:.*]] = emitc.subscript %[[ALLOCA]][%[[i]], %[[j]]] : <4x8xf32>
23+
// CHECK: %[[VAR:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32
24+
// CHECK: emitc.assign %[[LOAD]] : f32 to %[[VAR]] : f32
25+
%1 = memref.load %0[%i, %j] : memref<4x8xf32>
26+
// CHECK: return %[[VAR]] : f32
27+
return %1 : f32
28+
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4186,6 +4186,7 @@ cc_library(
41864186
":MathToLLVM",
41874187
":MathToLibm",
41884188
":MathToSPIRV",
4189+
":MemRefToEmitC",
41894190
":MemRefToLLVM",
41904191
":MemRefToSPIRV",
41914192
":NVGPUToNVVM",
@@ -8256,6 +8257,32 @@ cc_library(
82568257
],
82578258
)
82588259

8260+
cc_library(
8261+
name = "MemRefToEmitC",
8262+
srcs = glob([
8263+
"lib/Conversion/MemRefToEmitC/*.cpp",
8264+
"lib/Conversion/MemRefToEmitC/*.h",
8265+
]),
8266+
hdrs = glob([
8267+
"include/mlir/Conversion/MemRefToEmitC/*.h",
8268+
]),
8269+
includes = [
8270+
"include",
8271+
"lib/Conversion/MemRefToEmitC",
8272+
],
8273+
deps = [
8274+
":ConversionPassIncGen",
8275+
":EmitCDialect",
8276+
":MemRefDialect",
8277+
":IR",
8278+
":Pass",
8279+
":Support",
8280+
":TransformUtils",
8281+
":Transforms",
8282+
"//llvm:Support",
8283+
],
8284+
)
8285+
82598286
cc_library(
82608287
name = "MemRefToLLVM",
82618288
srcs = glob(["lib/Conversion/MemRefToLLVM/*.cpp"]),

0 commit comments

Comments
 (0)