Skip to content

merge code for llvm.emit_c_interface into convertFuncOpToLLVMFuncOp #92986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 25 additions & 39 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,61 +449,47 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
"region types conversion failed");
}

if (!shouldUseBarePtrCallConv(funcOp, &converter)) {
if (funcOp->getAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
if (newFuncOp.isVarArg())
return funcOp.emitError("C interface for variadic functions is not "
"supported yet.");

if (newFuncOp.isExternal())
wrapExternalFunction(rewriter, funcOp->getLoc(), converter, funcOp,
newFuncOp);
else
wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
newFuncOp);
}
} else {
modifyFuncOpToUseBarePtrCallingConv(
rewriter, funcOp->getLoc(), converter, newFuncOp,
llvm::cast<FunctionType>(funcOp.getFunctionType()).getInputs());
}

return newFuncOp;
}

namespace {

struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
protected:
using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;

// Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
// to this legalization pattern.
FailureOr<LLVM::LLVMFuncOp>
convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
ConversionPatternRewriter &rewriter) const {
return mlir::convertFuncOpToLLVMFuncOp(
cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
*getTypeConverter());
}
};

/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
FuncOpConversion(const LLVMTypeConverter &converter)
: FuncOpConversionBase(converter) {}
: ConvertOpToLLVMPattern(converter) {}

LogicalResult
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<LLVM::LLVMFuncOp> newFuncOp =
convertFuncOpToLLVMFuncOp(funcOp, rewriter);
FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
*getTypeConverter());
if (failed(newFuncOp))
return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");

if (!shouldUseBarePtrCallConv(funcOp, this->getTypeConverter())) {
if (funcOp->getAttrOfType<UnitAttr>(
LLVM::LLVMDialect::getEmitCWrapperAttrName())) {
if (newFuncOp->isVarArg())
return funcOp->emitError("C interface for variadic functions is not "
"supported yet.");

if (newFuncOp->isExternal())
wrapExternalFunction(rewriter, funcOp->getLoc(), *getTypeConverter(),
funcOp, *newFuncOp);
else
wrapForExternalCallers(rewriter, funcOp->getLoc(),
*getTypeConverter(), funcOp, *newFuncOp);
}
} else {
modifyFuncOpToUseBarePtrCallingConv(rewriter, funcOp->getLoc(),
*getTypeConverter(), *newFuncOp,
funcOp.getFunctionType().getInputs());
}

rewriter.eraseOp(funcOp);
return success();
}
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Transforms/test-convert-func-op.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: mlir-opt %s -test-convert-func-op | FileCheck %s

// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
%res = arith.addi %arg0, %arg1 : i32
return %res : i32
}
// CHECK-LABEL: llvm.func @_mlir_ciface_add
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]+]]: i32
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]+]]: i32
// CHECK-NEXT: [[RES:%.*]] = llvm.call @add([[ARG0]], [[ARG1]])
// CHECK-NEXT: llvm.return [[RES]]
1 change: 1 addition & 0 deletions mlir/test/lib/Conversion/FuncToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRTestFuncToLLVM
TestConvertCallOp.cpp
TestConvertFuncOp.cpp

EXCLUDE_FROM_LIBMLIR

Expand Down
93 changes: 93 additions & 0 deletions mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//===- TestConvertFuncOp.cpp - Test LLVM Conversion of Func FuncOp --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "TestDialect.h"

#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

/// Test helper Conversion Pattern to directly call `convertFuncOpToLLVMFuncOp`
/// to verify this utility function includes all functionalities of conversion
struct FuncOpConversion : public ConvertOpToLLVMPattern<func::FuncOp> {
FuncOpConversion(const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter) {}

LogicalResult
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
FailureOr<LLVM::LLVMFuncOp> newFuncOp = mlir::convertFuncOpToLLVMFuncOp(
cast<FunctionOpInterface>(funcOp.getOperation()), rewriter,
*getTypeConverter());
if (failed(newFuncOp))
return rewriter.notifyMatchFailure(funcOp, "Could not convert funcop");

rewriter.eraseOp(funcOp);
return success();
}
};

struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
ReturnOpConversion(const LLVMTypeConverter &converter)
: ConvertOpToLLVMPattern(converter) {}

LogicalResult
matchAndRewrite(func::ReturnOp returnOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(returnOp,
returnOp->getOperands());
return success();
}
};

struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)

void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<LLVM::LLVMDialect>();
}

StringRef getArgument() const final { return "test-convert-func-op"; }

StringRef getDescription() const final {
return "Tests conversion of `func.func` to `llvm.func` for different "
"attributes";
}

void runOnOperation() override {
MLIRContext *ctx = &getContext();

LowerToLLVMOptions options(ctx);
// Populate type conversions.
LLVMTypeConverter typeConverter(ctx, options);

RewritePatternSet patterns(ctx);
patterns.add<FuncOpConversion>(typeConverter);
patterns.add<ReturnOpConversion>(typeConverter);

LLVMConversionTarget target(getContext());
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};

} // namespace

namespace mlir::test {
void registerConvertFuncOpPass() { PassRegistration<TestConvertFuncOp>(); }
} // namespace mlir::test
2 changes: 2 additions & 0 deletions mlir/tools/mlir-opt/mlir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ namespace test {
void registerTestCompositePass();
void registerCommutativityUtils();
void registerConvertCallOpPass();
void registerConvertFuncOpPass();
void registerInliner();
void registerMemRefBoundCheck();
void registerPatternsTestPass();
Expand Down Expand Up @@ -199,6 +200,7 @@ void registerTestPasses() {
mlir::test::registerTestCompositePass();
mlir::test::registerCommutativityUtils();
mlir::test::registerConvertCallOpPass();
mlir::test::registerConvertFuncOpPass();
mlir::test::registerInliner();
mlir::test::registerMemRefBoundCheck();
mlir::test::registerPatternsTestPass();
Expand Down
Loading