Skip to content

Commit d3c90a4

Browse files
authored
[MLIR] Add initial convert-memref-to-emitc pass (#30)
This translates memref types in func.func, func.call and func.return. Reviewers: TinaAMD Reviewed By: TinaAMD Pull Request: #113
1 parent 995b14c commit d3c90a4

File tree

8 files changed

+216
-0
lines changed

8 files changed

+216
-0
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- MemRefToEmitC.h - Convert MemRef to EmitC --------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
9+
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
10+
11+
#include "mlir/Pass/Pass.h"
12+
13+
namespace mlir {
14+
class RewritePatternSet;
15+
class TypeConverter;
16+
17+
#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
18+
#include "mlir/Conversion/Passes.h.inc"
19+
20+
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
21+
TypeConverter &typeConverter);
22+
23+
std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();
24+
25+
} // namespace mlir
26+
27+
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4343
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
4444
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
45+
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
4546
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
4647
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
4748
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,15 @@ def ConvertMathToFuncs : Pass<"convert-math-to-funcs", "ModuleOp"> {
724724
];
725725
}
726726

727+
//===----------------------------------------------------------------------===//
728+
// MemRefToEmitC
729+
//===----------------------------------------------------------------------===//
730+
731+
def ConvertMemRefToEmitC : Pass<"convert-memref-to-emitc"> {
732+
let summary = "Convert MemRef dialect to EmitC dialect";
733+
let dependentDialects = ["emitc::EmitCDialect"];
734+
}
735+
727736
//===----------------------------------------------------------------------===//
728737
// MemRefToLLVM
729738
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ add_subdirectory(MathToFuncs)
3232
add_subdirectory(MathToLibm)
3333
add_subdirectory(MathToLLVM)
3434
add_subdirectory(MathToSPIRV)
35+
add_subdirectory(MemRefToEmitC)
3536
add_subdirectory(MemRefToLLVM)
3637
add_subdirectory(MemRefToSPIRV)
3738
add_subdirectory(NVGPUToNVVM)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(MLIRMemRefToEmitC
2+
MemRefToEmitC.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIREmitCDialect
15+
MLIRMemRefDialect
16+
MLIRTransforms
17+
)
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
//===- MemRefToEmitC.cpp - MemRef to EmitC conversion ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to convert memref ops into emitc ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
14+
15+
#include "mlir/Dialect/EmitC/IR/EmitC.h"
16+
#include "mlir/Dialect/Func/IR/FuncOps.h"
17+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
18+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
19+
#include "mlir/IR/Builders.h"
20+
#include "mlir/IR/BuiltinOps.h"
21+
#include "mlir/IR/IRMapping.h"
22+
#include "mlir/IR/MLIRContext.h"
23+
#include "mlir/IR/PatternMatch.h"
24+
#include "mlir/Interfaces/FunctionInterfaces.h"
25+
#include "mlir/Transforms/DialectConversion.h"
26+
#include "mlir/Transforms/Passes.h"
27+
28+
namespace mlir {
29+
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
30+
#include "mlir/Conversion/Passes.h.inc"
31+
} // namespace mlir
32+
33+
using namespace mlir;
34+
35+
namespace {
36+
37+
/// Disallow all memrefs even though we only have conversions
38+
/// for memrefs with static shape right now to have good diagnostics.
39+
bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
40+
41+
template <typename RangeT>
42+
std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
43+
!std::is_convertible<RangeT, Operation *>::value,
44+
bool>
45+
isLegal(RangeT &&range) {
46+
return llvm::all_of(range, [](Type type) { return isLegal(type); });
47+
}
48+
49+
bool isLegal(Operation *op) {
50+
return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
51+
}
52+
53+
bool isSignatureLegal(FunctionType ty) {
54+
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
55+
}
56+
57+
struct ConvertMemRefToEmitCPass
58+
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
59+
void runOnOperation() override {
60+
TypeConverter converter;
61+
// Pass through for all other types.
62+
converter.addConversion([](Type type) { return type; });
63+
64+
converter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
65+
if (memRefType.hasStaticShape()) {
66+
return emitc::ArrayType::get(memRefType.getShape(),
67+
memRefType.getElementType());
68+
}
69+
return {};
70+
});
71+
72+
converter.addConversion(
73+
[&converter](FunctionType ty) -> std::optional<Type> {
74+
SmallVector<Type> inputs;
75+
if (failed(converter.convertTypes(ty.getInputs(), inputs)))
76+
return std::nullopt;
77+
78+
SmallVector<Type> results;
79+
if (failed(converter.convertTypes(ty.getResults(), results)))
80+
return std::nullopt;
81+
82+
return FunctionType::get(ty.getContext(), inputs, results);
83+
});
84+
85+
RewritePatternSet patterns(&getContext());
86+
populateMemRefToEmitCConversionPatterns(patterns, converter);
87+
88+
ConversionTarget target(getContext());
89+
target.addDynamicallyLegalOp<func::FuncOp>(
90+
[](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
91+
target.addDynamicallyLegalDialect<func::FuncDialect>(
92+
[](Operation *op) { return isLegal(op); });
93+
target.addIllegalDialect<memref::MemRefDialect>();
94+
95+
if (failed(applyPartialConversion(getOperation(), target,
96+
std::move(patterns))))
97+
return signalPassFailure();
98+
}
99+
};
100+
} // namespace
101+
102+
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
103+
TypeConverter &converter) {
104+
105+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
106+
converter);
107+
populateCallOpTypeConversionPattern(patterns, converter);
108+
populateReturnOpTypeConversionPattern(patterns, converter);
109+
}
110+
111+
std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
112+
return std::make_unique<ConvertMemRefToEmitCPass>();
113+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file -verify-diagnostics
2+
3+
// Unranked memrefs are not converted
4+
// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
5+
func.func @memref_unranked(%arg0 : memref<*xf32>) {
6+
return
7+
}
8+
9+
// -----
10+
11+
// Memrefs with dynamic shapes are not converted
12+
// expected-error@+1 {{failed to legalize operation 'func.func' that was explicitly marked illegal}}
13+
func.func @memref_dynamic_shape(%arg0 : memref<2x?xf32>) {
14+
return
15+
}
16+
17+
// -----
18+
19+
// Memrefs with dynamic shapes are not converted
20+
func.func @memref_op(%arg0 : memref<2x4xf32>) {
21+
// expected-error@+1 {{failed to legalize operation 'memref.copy' that was explicitly marked illegal}}
22+
memref.copy %arg0, %arg0 : memref<2x4xf32> to memref<2x4xf32>
23+
return
24+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: mlir-opt -convert-memref-to-emitc %s -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: memref_arg
4+
// CHECK-SAME: !emitc.array<32xf32>)
5+
func.func @memref_arg(%arg0 : memref<32xf32>) {
6+
func.return
7+
}
8+
9+
// -----
10+
11+
// CHECK-LABEL: memref_return
12+
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>) -> !emitc.array<32xf32>
13+
func.func @memref_return(%arg0 : memref<32xf32>) -> memref<32xf32> {
14+
// CHECK: return %[[arg0]] : !emitc.array<32xf32>
15+
func.return %arg0 : memref<32xf32>
16+
}
17+
18+
// CHECK-LABEL: memref_call
19+
// CHECK-SAME: %[[arg0:.*]]: !emitc.array<32xf32>)
20+
func.func @memref_call(%arg0 : memref<32xf32>) {
21+
// CHECK: call @memref_return(%[[arg0]]) : (!emitc.array<32xf32>) -> !emitc.array<32xf32>
22+
func.call @memref_return(%arg0) : (memref<32xf32>) -> memref<32xf32>
23+
func.return
24+
}

0 commit comments

Comments
 (0)