Skip to content

Commit bedd22f

Browse files
authored
Merge pull request #146 from Xilinx/matthias.backport_emitc
Backport EmitC-related commits from upstream (FXML-4370)
2 parents e36f02b + e676113 commit bedd22f

File tree

24 files changed

+314
-263
lines changed

24 files changed

+314
-263
lines changed

mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,14 @@
88
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
99
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
1010

11-
#include "mlir/Pass/Pass.h"
12-
1311
namespace mlir {
1412
class RewritePatternSet;
1513
class TypeConverter;
1614

17-
#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
18-
#include "mlir/Conversion/Passes.h.inc"
15+
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);
1916

2017
void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
21-
TypeConverter &typeConverter);
22-
23-
std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();
24-
18+
TypeConverter &converter);
2519
} // namespace mlir
2620

2721
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- MemRefToEmitCPass.h - A Pass to convert MemRef to EmitC ------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
9+
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
10+
11+
#include <memory>
12+
13+
namespace mlir {
14+
class Pass;
15+
16+
#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
17+
#include "mlir/Conversion/Passes.h.inc"
18+
} // namespace mlir
19+
20+
#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4545
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
4646
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
47-
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
47+
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
4848
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
4949
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
5050
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ namespace mlir {
3131
namespace emitc {
3232
void buildTerminatedBody(OpBuilder &builder, Location loc);
3333
/// Determines whether \p type is a valid integer type in EmitC.
34-
bool isValidEmitCIntegerType(mlir::Type type);
34+
bool isSupportedIntegerType(mlir::Type type);
3535
/// Determines whether \p type is a valid floating-point type in EmitC.
36-
bool isValidEmitCFloatType(mlir::Type type);
36+
bool isSupportedFloatType(mlir::Type type);
3737
} // namespace emitc
3838
} // namespace mlir
3939

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
5151
def CExpression : NativeOpTrait<"emitc::CExpression">;
5252

5353
// Types only used in binary arithmetic operations.
54-
def IntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Integer_Type, Index, EmitC_OpaqueType]>;
55-
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Float_Type, IntegerIndexOrOpaqueType]>;
54+
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index, EmitC_OpaqueType]>;
55+
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;
5656

5757
def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
5858
let summary = "Addition operation";
@@ -1169,11 +1169,11 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript",
11691169
```mlir
11701170
%i = index.constant 1
11711171
%j = index.constant 7
1172-
%0 = emitc.subscript %arg0[%i][%j] : (!emitc.array<4x8xf32>) -> f32
1172+
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
11731173
```
11741174
}];
11751175
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
1176-
Variadic<Index>:$indices);
1176+
Variadic<IntegerIndexOrOpaqueType>:$indices);
11771177
let results = (outs AnyType:$result);
11781178

11791179
let builders = [
@@ -1183,7 +1183,7 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript",
11831183
];
11841184

11851185
let hasVerifier = 1;
1186-
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
1186+
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
11871187
}
11881188

11891189

mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
2222
// EmitC type definitions
2323
//===----------------------------------------------------------------------===//
2424

25-
def Valid_EmitC_Integer_Type : Type<CPred<"emitc::isValidEmitCIntegerType($_self)">,
26-
"EmitC integer type">;
25+
def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">,
26+
"integer type supported by EmitC">;
2727

28-
def Valid_EmitC_Float_Type : Type<CPred<"emitc::isValidEmitCFloatType($_self)">,
29-
"EmitC floating-point type">;
28+
def EmitCFloatType : Type<CPred<"emitc::isSupportedFloatType($_self)">,
29+
"floating-point type supported by EmitC">;
3030

3131
class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
3232
: TypeDef<EmitC_Dialect, name, traits> {
@@ -45,7 +45,7 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
4545
// Array emitted as `int32_t[10]`
4646
!emitc.array<10xi32>
4747
// Array emitted as `float[10][20]`
48-
!emitc.ptr<10x20xf32>
48+
!emitc.array<10x20xf32>
4949
```
5050
}];
5151

mlir/include/mlir/Target/Cpp/CppEmitter.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@
1313
#ifndef MLIR_TARGET_CPP_CPPEMITTER_H
1414
#define MLIR_TARGET_CPP_CPPEMITTER_H
1515

16-
#include "mlir/IR/BuiltinTypes.h"
17-
#include "mlir/IR/Value.h"
18-
#include "llvm/ADT/ScopedHashTable.h"
1916
#include "llvm/Support/raw_ostream.h"
20-
#include <stack>
2117

2218
namespace mlir {
19+
struct LogicalResult;
20+
class Operation;
2321
namespace emitc {
2422

2523
/// Translates the given operation to C++ code. The operation or operations in

mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_conversion_library(MLIRMemRefToEmitC
22
MemRefToEmitC.cpp
3+
MemRefToEmitCPass.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
@@ -12,8 +13,6 @@ add_mlir_conversion_library(MLIRMemRefToEmitC
1213

1314
LINK_LIBS PUBLIC
1415
MLIREmitCDialect
15-
MLIRFuncDialect
16-
MLIRFuncTransforms
1716
MLIRMemRefDialect
1817
MLIRTransforms
1918
)

mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp

Lines changed: 55 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -6,53 +6,49 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file implements a pass to convert memref ops into emitc ops.
9+
// This file implements patterns to convert memref ops into emitc ops.
1010
//
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
1414

1515
#include "mlir/Dialect/EmitC/IR/EmitC.h"
16-
#include "mlir/Dialect/Func/IR/FuncOps.h"
17-
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
1816
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1917
#include "mlir/IR/Builders.h"
20-
#include "mlir/IR/BuiltinOps.h"
21-
#include "mlir/IR/IRMapping.h"
22-
#include "mlir/IR/MLIRContext.h"
2318
#include "mlir/IR/PatternMatch.h"
24-
#include "mlir/Interfaces/FunctionInterfaces.h"
2519
#include "mlir/Transforms/DialectConversion.h"
26-
#include "mlir/Transforms/Passes.h"
27-
28-
namespace mlir {
29-
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
30-
#include "mlir/Conversion/Passes.h.inc"
31-
} // namespace mlir
3220

3321
using namespace mlir;
3422

3523
namespace {
24+
struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
25+
using OpConversionPattern::OpConversionPattern;
3626

37-
/// Disallow all memrefs even though we only have conversions
38-
/// for memrefs with static shape right now to have good diagnostics.
39-
bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
27+
LogicalResult
28+
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
29+
ConversionPatternRewriter &rewriter) const override {
4030

41-
template <typename RangeT>
42-
std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
43-
!std::is_convertible<RangeT, Operation *>::value,
44-
bool>
45-
isLegal(RangeT &&range) {
46-
return llvm::all_of(range, [](Type type) { return isLegal(type); });
47-
}
31+
if (!op.getType().hasStaticShape()) {
32+
return rewriter.notifyMatchFailure(
33+
op.getLoc(), "cannot transform alloca with dynamic shape");
34+
}
4835

49-
bool isLegal(Operation *op) {
50-
return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
51-
}
36+
if (op.getAlignment().value_or(1) > 1) {
37+
// TODO: Allow alignment if it is not more than the natural alignment
38+
// of the C array.
39+
return rewriter.notifyMatchFailure(
40+
op.getLoc(), "cannot transform alloca with alignment requirement");
41+
}
5242

53-
bool isSignatureLegal(FunctionType ty) {
54-
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
55-
}
43+
auto resultTy = getTypeConverter()->convertType(op.getType());
44+
if (!resultTy) {
45+
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
46+
}
47+
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
48+
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
49+
return success();
50+
}
51+
};
5652

5753
struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
5854
using OpConversionPattern::OpConversionPattern;
@@ -61,8 +57,20 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
6157
matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
6258
ConversionPatternRewriter &rewriter) const override {
6359

64-
rewriter.replaceOpWithNewOp<emitc::SubscriptOp>(op, operands.getMemref(),
65-
operands.getIndices());
60+
auto resultTy = getTypeConverter()->convertType(op.getType());
61+
if (!resultTy) {
62+
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
63+
}
64+
65+
auto subscript = rewriter.create<emitc::SubscriptOp>(
66+
op.getLoc(), operands.getMemref(), operands.getIndices());
67+
68+
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
69+
auto var =
70+
rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);
71+
72+
rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
73+
rewriter.replaceOp(op, var);
6674
return success();
6775
}
6876
};
@@ -81,90 +89,26 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
8189
return success();
8290
}
8391
};
92+
} // namespace
8493

85-
struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
86-
using OpConversionPattern::OpConversionPattern;
87-
88-
LogicalResult
89-
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
90-
ConversionPatternRewriter &rewriter) const override {
91-
92-
if (!op.getType().hasStaticShape()) {
93-
return rewriter.notifyMatchFailure(
94-
op.getLoc(), "cannot transform alloca with dynamic shape");
95-
}
96-
97-
if (op.getAlignment().value_or(1) > 1) {
98-
// TODO: Allow alignment if it is not more than the natural alignment
99-
// of the C array.
100-
return rewriter.notifyMatchFailure(
101-
op.getLoc(), "cannot transform alloca with alignment requirement");
102-
}
103-
104-
auto resultTy = getTypeConverter()->convertType(op.getType());
105-
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
106-
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
107-
return success();
108-
}
109-
};
110-
111-
struct ConvertMemRefToEmitCPass
112-
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
113-
void runOnOperation() override {
114-
TypeConverter converter;
115-
// Pass through for all other types.
116-
converter.addConversion([](Type type) { return type; });
117-
118-
converter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
119-
if (memRefType.hasStaticShape()) {
94+
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
95+
typeConverter.addConversion(
96+
[&](MemRefType memRefType) -> std::optional<Type> {
97+
if (!memRefType.hasStaticShape() ||
98+
!memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
99+
return {};
100+
}
101+
Type convertedElementType =
102+
typeConverter.convertType(memRefType.getElementType());
103+
if (!convertedElementType)
104+
return {};
120105
return emitc::ArrayType::get(memRefType.getShape(),
121-
memRefType.getElementType());
122-
}
123-
return {};
124-
});
125-
126-
converter.addConversion(
127-
[&converter](FunctionType ty) -> std::optional<Type> {
128-
SmallVector<Type> inputs;
129-
if (failed(converter.convertTypes(ty.getInputs(), inputs)))
130-
return std::nullopt;
131-
132-
SmallVector<Type> results;
133-
if (failed(converter.convertTypes(ty.getResults(), results)))
134-
return std::nullopt;
135-
136-
return FunctionType::get(ty.getContext(), inputs, results);
137-
});
138-
139-
RewritePatternSet patterns(&getContext());
140-
populateMemRefToEmitCConversionPatterns(patterns, converter);
141-
142-
ConversionTarget target(getContext());
143-
target.addDynamicallyLegalOp<func::FuncOp>(
144-
[](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
145-
target.addDynamicallyLegalDialect<func::FuncDialect>(
146-
[](Operation *op) { return isLegal(op); });
147-
target.addIllegalDialect<memref::MemRefDialect>();
148-
target.addLegalDialect<emitc::EmitCDialect>();
149-
150-
if (failed(applyPartialConversion(getOperation(), target,
151-
std::move(patterns))))
152-
return signalPassFailure();
153-
}
154-
};
155-
} // namespace
106+
convertedElementType);
107+
});
108+
}
156109

157110
void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
158111
TypeConverter &converter) {
159-
160-
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
161-
converter);
162-
populateCallOpTypeConversionPattern(patterns, converter);
163-
populateReturnOpTypeConversionPattern(patterns, converter);
164-
patterns.add<ConvertLoad, ConvertStore, ConvertAlloca>(converter,
112+
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
165113
patterns.getContext());
166114
}
167-
168-
std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
169-
return std::make_unique<ConvertMemRefToEmitCPass>();
170-
}

0 commit comments

Comments
 (0)