-
Notifications
You must be signed in to change notification settings - Fork 14.3k
NFC: resolve TODO in LLVM dialect conversions #91497
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm Author: Christopher Bate (christopherbate) ChangesRelaxes restriction that certain public utility functions only apply Full diff: https://github.com/llvm/llvm-project/pull/91497.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 123ce36cb0a79..852490cf7428f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -33,36 +33,36 @@ class LLVMFuncOp;
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
-LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
LLVM::LLVMFuncOp
-lookupOrCreatePrintStringFn(ModuleOp moduleOp,
+lookupOrCreatePrintStringFn(Operation *moduleOp,
std::optional<StringRef> runtimeFunctionName = {});
-LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
-LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
Type resultType = {}, bool isVarArg = false);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index b29abc94ce400..e48ca5180b706 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -10,18 +10,14 @@
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/SymbolTable.h"
using namespace mlir;
namespace {
-// TODO: Fix the LLVM utilities for looking up functions to take Operation*
-// with SymbolTable trait instead of ModuleOp and make similar change here. This
-// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
-// of getParentOfType<ModuleOp> to pass down the operation.
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
- ModuleOp module, Type indexType) {
+ Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
if (useGenericFn)
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
@@ -29,7 +25,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
}
LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
- ModuleOp module, Type indexType) {
+ Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
@@ -79,7 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
// Allocate the underlying buffer.
Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
- getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+ getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType());
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
Value allocatedPtr =
@@ -144,7 +141,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
- getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+ getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType());
auto results = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 0004c2e3403e5..88421a16ccf9f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,49 +45,53 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
/// Generic print function lookupOrCreate helper.
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
+ StringRef name,
ArrayRef<Type> paramTypes,
Type resultType, bool isVarArg) {
- auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
+ assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+ "expected SymbolTable operation");
+ auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
+ SymbolTable::lookupSymbolIn(moduleOp, name));
if (func)
return func;
- OpBuilder b(moduleOp.getBodyRegion());
+ OpBuilder b(moduleOp->getRegion(0));
return b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintI64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintU64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF64,
Float64Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
@@ -103,72 +107,72 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
- ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
+ Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
getVoidPtr(moduleOp->getContext()));
}
LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
{indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType) {
return LLVM::lookupOrCreateFn(
moduleOp, kMemRefCopy,
|
@llvm/pr-subscribers-mlir Author: Christopher Bate (christopherbate) ChangesRelaxes restriction that certain public utility functions only apply Full diff: https://github.com/llvm/llvm-project/pull/91497.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 123ce36cb0a79..852490cf7428f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -33,36 +33,36 @@ class LLVMFuncOp;
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as part
/// of the libc).
-LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
/// Declares a function to print a C-string.
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
/// have the signature void(char const*). The default function is `printString`.
LLVM::LLVMFuncOp
-lookupOrCreatePrintStringFn(ModuleOp moduleOp,
+lookupOrCreatePrintStringFn(Operation *moduleOp,
std::optional<StringRef> runtimeFunctionName = {});
-LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType);
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
-LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
ArrayRef<Type> paramTypes = {},
Type resultType = {}, bool isVarArg = false);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index b29abc94ce400..e48ca5180b706 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -10,18 +10,14 @@
#include "mlir/Analysis/DataLayoutAnalysis.h"
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/SymbolTable.h"
using namespace mlir;
namespace {
-// TODO: Fix the LLVM utilities for looking up functions to take Operation*
-// with SymbolTable trait instead of ModuleOp and make similar change here. This
-// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
-// of getParentOfType<ModuleOp> to pass down the operation.
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
- ModuleOp module, Type indexType) {
+ Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
if (useGenericFn)
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
@@ -29,7 +25,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
}
LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
- ModuleOp module, Type indexType) {
+ Operation *module, Type indexType) {
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
if (useGenericFn)
@@ -79,7 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
// Allocate the underlying buffer.
Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
- getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+ getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType());
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
Value allocatedPtr =
@@ -144,7 +141,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
Type elementPtrType = this->getElementPtrType(memRefType);
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
- getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+ getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+ getIndexType());
auto results = rewriter.create<LLVM::CallOp>(
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 0004c2e3403e5..88421a16ccf9f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,49 +45,53 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
/// Generic print function lookupOrCreate helper.
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
+ StringRef name,
ArrayRef<Type> paramTypes,
Type resultType, bool isVarArg) {
- auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
+ assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+ "expected SymbolTable operation");
+ auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
+ SymbolTable::lookupSymbolIn(moduleOp, name));
if (func)
return func;
- OpBuilder b(moduleOp.getBodyRegion());
+ OpBuilder b(moduleOp->getRegion(0));
return b.create<LLVM::LLVMFuncOp>(
moduleOp->getLoc(), name,
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintI64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintU64,
IntegerType::get(moduleOp->getContext(), 64),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintBF16,
IntegerType::get(moduleOp->getContext(), 16), // bits!
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF32,
Float32Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintF64,
Float64Type::get(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
@@ -103,72 +107,72 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
}
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
- ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
+ Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
getCharPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintClose, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintComma, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
getVoidPtr(moduleOp->getContext()));
}
LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
Type indexType) {
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
{indexType, indexType},
getVoidPtr(moduleOp->getContext()));
}
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
return LLVM::lookupOrCreateFn(
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
LLVM::LLVMVoidType::get(moduleOp->getContext()));
}
LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
Type unrankedDescriptorType) {
return LLVM::lookupOrCreateFn(
moduleOp, kMemRefCopy,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for taking care of this 🙂
Relaxes restriction that certain public utility functions only apply to the builtin ModuleOp.
a020bb9
to
d7b0113
Compare
Can we land this? |
Relaxes restriction that certain public utility functions only apply
to the builtin ModuleOp.