Skip to content

Backport EmitC-related commits from upstream (FXML-4370) #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,14 @@
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H

#include "mlir/Pass/Pass.h"

namespace mlir {
class RewritePatternSet;
class TypeConverter;

#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
#include "mlir/Conversion/Passes.h.inc"
void populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter);

void populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter);

std::unique_ptr<OperationPass<>> createConvertMemRefToEmitCPass();

TypeConverter &converter);
} // namespace mlir

#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITC_H
20 changes: 20 additions & 0 deletions mlir/include/mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- MemRefToEmitCPass.h - A Pass to convert MemRef to EmitC ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
#define MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H

#include <memory>

namespace mlir {
class Pass;

#define GEN_PASS_DECL_CONVERTMEMREFTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

#endif // MLIR_CONVERSION_MEMREFTOEMITC_MEMREFTOEMITCPASS_H
2 changes: 1 addition & 1 deletion mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ namespace mlir {
namespace emitc {
void buildTerminatedBody(OpBuilder &builder, Location loc);
/// Determines whether \p type is a valid integer type in EmitC.
bool isValidEmitCIntegerType(mlir::Type type);
bool isSupportedIntegerType(mlir::Type type);
/// Determines whether \p type is a valid floating-point type in EmitC.
bool isValidEmitCFloatType(mlir::Type type);
bool isSupportedFloatType(mlir::Type type);
} // namespace emitc
} // namespace mlir

Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class EmitC_BinaryOp<string mnemonic, list<Trait> traits = []> :
def CExpression : NativeOpTrait<"emitc::CExpression">;

// Types only used in binary arithmetic operations.
def IntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Integer_Type, Index, EmitC_OpaqueType]>;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[Valid_EmitC_Float_Type, IntegerIndexOrOpaqueType]>;
def IntegerIndexOrOpaqueType : AnyTypeOf<[EmitCIntegerType, Index, EmitC_OpaqueType]>;
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;

def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
let summary = "Addition operation";
Expand Down Expand Up @@ -1169,11 +1169,11 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript",
```mlir
%i = index.constant 1
%j = index.constant 7
%0 = emitc.subscript %arg0[%i][%j] : (!emitc.array<4x8xf32>) -> f32
%0 = emitc.subscript %arg0[%i, %j] : <4x8xf32>, index, index
```
}];
let arguments = (ins Arg<EmitC_ArrayType, "the reference to load from">:$array,
Variadic<Index>:$indices);
Variadic<IntegerIndexOrOpaqueType>:$indices);
let results = (outs AnyType:$result);

let builders = [
Expand All @@ -1183,7 +1183,7 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript",
];

let hasVerifier = 1;
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array)";
let assemblyFormat = "$array `[` $indices `]` attr-dict `:` type($array) `,` type($indices)";
}


Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// EmitC type definitions
//===----------------------------------------------------------------------===//

def Valid_EmitC_Integer_Type : Type<CPred<"emitc::isValidEmitCIntegerType($_self)">,
"EmitC integer type">;
def EmitCIntegerType : Type<CPred<"emitc::isSupportedIntegerType($_self)">,
"integer type supported by EmitC">;

def Valid_EmitC_Float_Type : Type<CPred<"emitc::isValidEmitCFloatType($_self)">,
"EmitC floating-point type">;
def EmitCFloatType : Type<CPred<"emitc::isSupportedFloatType($_self)">,
"floating-point type supported by EmitC">;

class EmitC_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<EmitC_Dialect, name, traits> {
Expand All @@ -45,7 +45,7 @@ def EmitC_ArrayType : EmitC_Type<"Array", "array", [ShapedTypeInterface]> {
// Array emitted as `int32_t[10]`
!emitc.array<10xi32>
// Array emitted as `float[10][20]`
!emitc.ptr<10x20xf32>
!emitc.array<10x20xf32>
```
}];

Expand Down
6 changes: 2 additions & 4 deletions mlir/include/mlir/Target/Cpp/CppEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
#ifndef MLIR_TARGET_CPP_CPPEMITTER_H
#define MLIR_TARGET_CPP_CPPEMITTER_H

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/Support/raw_ostream.h"
#include <stack>

namespace mlir {
struct LogicalResult;
class Operation;
namespace emitc {

/// Translates the given operation to C++ code. The operation or operations in
Expand Down
3 changes: 1 addition & 2 deletions mlir/lib/Conversion/MemRefToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
add_mlir_conversion_library(MLIRMemRefToEmitC
MemRefToEmitC.cpp
MemRefToEmitCPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MemRefToEmitC
Expand All @@ -12,8 +13,6 @@ add_mlir_conversion_library(MLIRMemRefToEmitC

LINK_LIBS PUBLIC
MLIREmitCDialect
MLIRFuncDialect
MLIRFuncTransforms
MLIRMemRefDialect
MLIRTransforms
)
166 changes: 55 additions & 111 deletions mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,53 +6,49 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert memref ops into emitc ops.
// This file implements patterns to convert memref ops into emitc ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"

#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {
struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
using OpConversionPattern::OpConversionPattern;

/// Disallow all memrefs even though we only have conversions
/// for memrefs with static shape right now to have good diagnostics.
bool isLegal(Type t) { return !isa<BaseMemRefType>(t); }
LogicalResult
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

template <typename RangeT>
std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
!std::is_convertible<RangeT, Operation *>::value,
bool>
isLegal(RangeT &&range) {
return llvm::all_of(range, [](Type type) { return isLegal(type); });
}
if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with dynamic shape");
}

bool isLegal(Operation *op) {
return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes());
}
if (op.getAlignment().value_or(1) > 1) {
// TODO: Allow alignment if it is not more than the natural alignment
// of the C array.
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with alignment requirement");
}

bool isSignatureLegal(FunctionType ty) {
return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults()));
}
auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
return success();
}
};

struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
using OpConversionPattern::OpConversionPattern;
Expand All @@ -61,8 +57,20 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
matchAndRewrite(memref::LoadOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

rewriter.replaceOpWithNewOp<emitc::SubscriptOp>(op, operands.getMemref(),
operands.getIndices());
auto resultTy = getTypeConverter()->convertType(op.getType());
if (!resultTy) {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}

auto subscript = rewriter.create<emitc::SubscriptOp>(
op.getLoc(), operands.getMemref(), operands.getIndices());

auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
rewriter.create<emitc::VariableOp>(op.getLoc(), resultTy, noInit);

rewriter.create<emitc::AssignOp>(op.getLoc(), var, subscript);
rewriter.replaceOp(op, var);
return success();
}
};
Expand All @@ -81,90 +89,26 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
return success();
}
};
} // namespace

struct ConvertAlloca final : public OpConversionPattern<memref::AllocaOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::AllocaOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {

if (!op.getType().hasStaticShape()) {
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with dynamic shape");
}

if (op.getAlignment().value_or(1) > 1) {
// TODO: Allow alignment if it is not more than the natural alignment
// of the C array.
return rewriter.notifyMatchFailure(
op.getLoc(), "cannot transform alloca with alignment requirement");
}

auto resultTy = getTypeConverter()->convertType(op.getType());
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
rewriter.replaceOpWithNewOp<emitc::VariableOp>(op, resultTy, noInit);
return success();
}
};

struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
void runOnOperation() override {
TypeConverter converter;
// Pass through for all other types.
converter.addConversion([](Type type) { return type; });

converter.addConversion([](MemRefType memRefType) -> std::optional<Type> {
if (memRefType.hasStaticShape()) {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
if (!memRefType.hasStaticShape() ||
!memRefType.getLayout().isIdentity() || memRefType.getRank() == 0) {
return {};
}
Type convertedElementType =
typeConverter.convertType(memRefType.getElementType());
if (!convertedElementType)
return {};
return emitc::ArrayType::get(memRefType.getShape(),
memRefType.getElementType());
}
return {};
});

converter.addConversion(
[&converter](FunctionType ty) -> std::optional<Type> {
SmallVector<Type> inputs;
if (failed(converter.convertTypes(ty.getInputs(), inputs)))
return std::nullopt;

SmallVector<Type> results;
if (failed(converter.convertTypes(ty.getResults(), results)))
return std::nullopt;

return FunctionType::get(ty.getContext(), inputs, results);
});

RewritePatternSet patterns(&getContext());
populateMemRefToEmitCConversionPatterns(patterns, converter);

ConversionTarget target(getContext());
target.addDynamicallyLegalOp<func::FuncOp>(
[](func::FuncOp op) { return isSignatureLegal(op.getFunctionType()); });
target.addDynamicallyLegalDialect<func::FuncDialect>(
[](Operation *op) { return isLegal(op); });
target.addIllegalDialect<memref::MemRefDialect>();
target.addLegalDialect<emitc::EmitCDialect>();

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};
} // namespace
convertedElementType);
});
}

void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &converter) {

populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
converter);
populateCallOpTypeConversionPattern(patterns, converter);
populateReturnOpTypeConversionPattern(patterns, converter);
patterns.add<ConvertLoad, ConvertStore, ConvertAlloca>(converter,
patterns.add<ConvertAlloca, ConvertLoad, ConvertStore>(converter,
patterns.getContext());
}

std::unique_ptr<OperationPass<>> mlir::createConvertMemRefToEmitCPass() {
return std::make_unique<ConvertMemRefToEmitCPass>();
}
Loading