Skip to content

NFC: resolve TODO in LLVM dialect conversions #91497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,36 +33,36 @@ class LLVMFuncOp;
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
LLVM::LLVMFuncOp
lookupOrCreatePrintStringFn(ModuleOp moduleOp,
lookupOrCreatePrintStringFn(Operation *moduleOp,
std::optional<StringRef> runtimeFunctionName = {});
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType);
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);

/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
Type resultType = {}, bool isVarArg = false);

Expand Down
16 changes: 7 additions & 9 deletions mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,22 @@
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/SymbolTable.h"

using namespace mlir;

namespace {
// TODO: Fix the LLVM utilities for looking up functions to take Operation*
// with SymbolTable trait instead of ModuleOp and make similar change here. This
// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
// of getParentOfType<ModuleOp> to pass down the operation.
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
ModuleOp module, Type indexType) {
Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

if (useGenericFn)
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);

return LLVM::lookupOrCreateMallocFn(module, indexType);
}

LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
ModuleOp module, Type indexType) {
Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

if (useGenericFn)
Expand Down Expand Up @@ -79,7 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
// Allocate the underlying buffer.
Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);

Value allocatedPtr =
Expand Down Expand Up @@ -144,7 +141,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(

Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
getIndexType());
auto results = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));

Expand Down
46 changes: 25 additions & 21 deletions mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,49 +45,53 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";

/// Generic print function lookupOrCreate helper.
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
StringRef name,
ArrayRef<Type> paramTypes,
Type resultType, bool isVarArg) {
auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
"expected SymbolTable operation");
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
SymbolTable::lookupSymbolIn(moduleOp, name));
if (func)
return func;
OpBuilder b(moduleOp.getBodyRegion());
OpBuilder b(moduleOp->getRegion(0));
return b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintI64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintU64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF64,
Float64Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
Expand All @@ -103,72 +107,72 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
getVoidPtr(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
getVoidPtr(moduleOp->getContext()));
}

LLVM::LLVMFuncOp
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
{indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}

LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}

LLVM::LLVMFuncOp
mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType) {
return LLVM::lookupOrCreateFn(
moduleOp, kMemRefCopy,
Expand Down
Loading