Skip to content

NFC: resolve TODO in LLVM dialect conversions #91497

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 6, 2024

Conversation

christopherbate
Copy link
Contributor

Relaxes restriction that certain public utility functions only apply
to the builtin ModuleOp.

@llvmbot
Copy link
Member

llvmbot commented May 8, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Christopher Bate (christopherbate)

Changes

Relaxes restriction that certain public utility functions only apply
to the builtin ModuleOp.


Full diff: https://github.com/llvm/llvm-project/pull/91497.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h (+19-19)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp (+7-9)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp (+25-21)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 123ce36cb0a79..852490cf7428f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -33,36 +33,36 @@ class LLVMFuncOp;
 /// external C function calls. The list of functions provided here must be
 /// implemented separately (e.g. as part of a support runtime library or as part
 /// of the libc).
-LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
 LLVM::LLVMFuncOp
-lookupOrCreatePrintStringFn(ModuleOp moduleOp,
+lookupOrCreatePrintStringFn(Operation *moduleOp,
                             std::optional<StringRef> runtimeFunctionName = {});
-LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                               Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                               Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                      Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
                                             Type unrankedDescriptorType);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
-LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
                                   ArrayRef<Type> paramTypes = {},
                                   Type resultType = {}, bool isVarArg = false);
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index b29abc94ce400..e48ca5180b706 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -10,18 +10,14 @@
 #include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/SymbolTable.h"
 
 using namespace mlir;
 
 namespace {
-// TODO: Fix the LLVM utilities for looking up functions to take Operation*
-// with SymbolTable trait instead of ModuleOp and make similar change here. This
-// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
-// of getParentOfType<ModuleOp> to pass down the operation.
 LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                      ModuleOp module, Type indexType) {
+                                      Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
   if (useGenericFn)
     return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
 
@@ -29,7 +25,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
 }
 
 LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                   ModuleOp module, Type indexType) {
+                                   Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
@@ -79,7 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
   // Allocate the underlying buffer.
   Type elementPtrType = this->getElementPtrType(memRefType);
   LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
-      getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+      getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
 
   Value allocatedPtr =
@@ -144,7 +141,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
 
   Type elementPtrType = this->getElementPtrType(memRefType);
   LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
-      getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+      getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(
       loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 0004c2e3403e5..88421a16ccf9f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,49 +45,53 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 
 /// Generic print function lookupOrCreate helper.
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
+                                              StringRef name,
                                               ArrayRef<Type> paramTypes,
                                               Type resultType, bool isVarArg) {
-  auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
+  assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+         "expected SymbolTable operation");
+  auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(moduleOp, name));
   if (func)
     return func;
-  OpBuilder b(moduleOp.getBodyRegion());
+  OpBuilder b(moduleOp->getRegion(0));
   return b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
       LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintI64,
                           IntegerType::get(moduleOp->getContext(), 64),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintU64,
                           IntegerType::get(moduleOp->getContext(), 64),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintBF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF32,
                           Float32Type::get(moduleOp->getContext()),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF64,
                           Float64Type::get(moduleOp->getContext()),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
@@ -103,72 +107,72 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
-    ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
+    Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
   return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
                           getCharPtr(moduleOp->getContext()),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintClose, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintComma, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
                                                     Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
                                 getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
                                 getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
                                 getVoidPtr(moduleOp->getContext()));
 }
 
 LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
                                 {indexType, indexType},
                                 getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
                                        Type unrankedDescriptorType) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kMemRefCopy,

@llvmbot
Copy link
Member

llvmbot commented May 8, 2024

@llvm/pr-subscribers-mlir

Author: Christopher Bate (christopherbate)

Changes

Relaxes restriction that certain public utility functions only apply
to the builtin ModuleOp.


Full diff: https://github.com/llvm/llvm-project/pull/91497.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h (+19-19)
  • (modified) mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp (+7-9)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp (+25-21)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
index 123ce36cb0a79..852490cf7428f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -33,36 +33,36 @@ class LLVMFuncOp;
 /// external C function calls. The list of functions provided here must be
 /// implemented separately (e.g. as part of a support runtime library or as part
 /// of the libc).
-LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
 /// Declares a function to print a C-string.
 /// If a custom runtime function is defined via `runtimeFunctionName`, it must
 /// have the signature void(char const*). The default function is `printString`.
 LLVM::LLVMFuncOp
-lookupOrCreatePrintStringFn(ModuleOp moduleOp,
+lookupOrCreatePrintStringFn(Operation *moduleOp,
                             std::optional<StringRef> runtimeFunctionName = {});
-LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                               Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                               Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                      Type indexType);
-LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
-LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
                                             Type unrankedDescriptorType);
 
 /// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
-LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
                                   ArrayRef<Type> paramTypes = {},
                                   Type resultType = {}, bool isVarArg = false);
 
diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
index b29abc94ce400..e48ca5180b706 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
@@ -10,18 +10,14 @@
 #include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/SymbolTable.h"
 
 using namespace mlir;
 
 namespace {
-// TODO: Fix the LLVM utilities for looking up functions to take Operation*
-// with SymbolTable trait instead of ModuleOp and make similar change here. This
-// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
-// of getParentOfType<ModuleOp> to pass down the operation.
 LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                      ModuleOp module, Type indexType) {
+                                      Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
-
   if (useGenericFn)
     return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
 
@@ -29,7 +25,7 @@ LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
 }
 
 LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
-                                   ModuleOp module, Type indexType) {
+                                   Operation *module, Type indexType) {
   bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
 
   if (useGenericFn)
@@ -79,7 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
   // Allocate the underlying buffer.
   Type elementPtrType = this->getElementPtrType(memRefType);
   LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
-      getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+      getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
 
   Value allocatedPtr =
@@ -144,7 +141,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
 
   Type elementPtrType = this->getElementPtrType(memRefType);
   LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
-      getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
+      getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
+      getIndexType());
   auto results = rewriter.create<LLVM::CallOp>(
       loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
index 0004c2e3403e5..88421a16ccf9f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -45,49 +45,53 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 
 /// Generic print function lookupOrCreate helper.
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
+                                              StringRef name,
                                               ArrayRef<Type> paramTypes,
                                               Type resultType, bool isVarArg) {
-  auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
+  assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
+         "expected SymbolTable operation");
+  auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(moduleOp, name));
   if (func)
     return func;
-  OpBuilder b(moduleOp.getBodyRegion());
+  OpBuilder b(moduleOp->getRegion(0));
   return b.create<LLVM::LLVMFuncOp>(
       moduleOp->getLoc(), name,
       LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintI64,
                           IntegerType::get(moduleOp->getContext(), 64),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintU64,
                           IntegerType::get(moduleOp->getContext(), 64),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintBF16,
                           IntegerType::get(moduleOp->getContext(), 16), // bits!
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF32,
                           Float32Type::get(moduleOp->getContext()),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintF64,
                           Float64Type::get(moduleOp->getContext()),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
@@ -103,72 +107,72 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
 }
 
 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
-    ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
+    Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
   return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
                           getCharPtr(moduleOp->getContext()),
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintClose, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintComma, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
                                                     Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
                                 getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
                                 getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
                                                           Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
                                 getVoidPtr(moduleOp->getContext()));
 }
 
 LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
                                                 Type indexType) {
   return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
                                 {indexType, indexType},
                                 getVoidPtr(moduleOp->getContext()));
 }
 
-LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
       LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
 LLVM::LLVMFuncOp
-mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
+mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
                                        Type unrankedDescriptorType) {
   return LLVM::lookupOrCreateFn(
       moduleOp, kMemRefCopy,

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for taking care of this 🙂

Relaxes restriction that certain public utility functions only apply
to the builtin ModuleOp.
@christopherbate christopherbate force-pushed the mlir-fix-llvm-dialect-todo branch from a020bb9 to d7b0113 Compare May 30, 2024 18:47
@Dinistro
Copy link
Contributor

Dinistro commented Jun 6, 2024

Can we land this?

@ftynse ftynse merged commit f543dfd into llvm:main Jun 6, 2024
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants