Skip to content

Commit f543dfd

Browse files
NFC: resolve TODO in LLVM dialect conversions (#91497)
Relaxes restriction that certain public utility functions only apply to the builtin ModuleOp.
1 parent 878deae commit f543dfd

File tree

3 files changed

+51
-49
lines changed

3 files changed

+51
-49
lines changed

mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,36 +33,36 @@ class LLVMFuncOp;
3333
/// external C function calls. The list of functions provided here must be
3434
/// implemented separately (e.g. as part of a support runtime library or as part
3535
/// of the libc).
36-
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(ModuleOp moduleOp);
37-
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(ModuleOp moduleOp);
38-
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(ModuleOp moduleOp);
39-
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(ModuleOp moduleOp);
40-
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(ModuleOp moduleOp);
41-
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(ModuleOp moduleOp);
36+
LLVM::LLVMFuncOp lookupOrCreatePrintI64Fn(Operation *moduleOp);
37+
LLVM::LLVMFuncOp lookupOrCreatePrintU64Fn(Operation *moduleOp);
38+
LLVM::LLVMFuncOp lookupOrCreatePrintF16Fn(Operation *moduleOp);
39+
LLVM::LLVMFuncOp lookupOrCreatePrintBF16Fn(Operation *moduleOp);
40+
LLVM::LLVMFuncOp lookupOrCreatePrintF32Fn(Operation *moduleOp);
41+
LLVM::LLVMFuncOp lookupOrCreatePrintF64Fn(Operation *moduleOp);
4242
/// Declares a function to print a C-string.
4343
/// If a custom runtime function is defined via `runtimeFunctionName`, it must
4444
/// have the signature void(char const*). The default function is `printString`.
4545
LLVM::LLVMFuncOp
46-
lookupOrCreatePrintStringFn(ModuleOp moduleOp,
46+
lookupOrCreatePrintStringFn(Operation *moduleOp,
4747
std::optional<StringRef> runtimeFunctionName = {});
48-
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(ModuleOp moduleOp);
49-
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(ModuleOp moduleOp);
50-
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(ModuleOp moduleOp);
51-
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(ModuleOp moduleOp);
52-
LLVM::LLVMFuncOp lookupOrCreateMallocFn(ModuleOp moduleOp, Type indexType);
53-
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
48+
LLVM::LLVMFuncOp lookupOrCreatePrintOpenFn(Operation *moduleOp);
49+
LLVM::LLVMFuncOp lookupOrCreatePrintCloseFn(Operation *moduleOp);
50+
LLVM::LLVMFuncOp lookupOrCreatePrintCommaFn(Operation *moduleOp);
51+
LLVM::LLVMFuncOp lookupOrCreatePrintNewlineFn(Operation *moduleOp);
52+
LLVM::LLVMFuncOp lookupOrCreateMallocFn(Operation *moduleOp, Type indexType);
53+
LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(Operation *moduleOp,
5454
Type indexType);
55-
LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
56-
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
55+
LLVM::LLVMFuncOp lookupOrCreateFreeFn(Operation *moduleOp);
56+
LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(Operation *moduleOp,
5757
Type indexType);
58-
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
58+
LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
5959
Type indexType);
60-
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
61-
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
60+
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(Operation *moduleOp);
61+
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
6262
Type unrankedDescriptorType);
6363

6464
/// Create a FuncOp with signature `resultType`(`paramTypes`)` and name `name`.
65-
LLVM::LLVMFuncOp lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
65+
LLVM::LLVMFuncOp lookupOrCreateFn(Operation *moduleOp, StringRef name,
6666
ArrayRef<Type> paramTypes = {},
6767
Type resultType = {}, bool isVarArg = false);
6868

mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,22 @@
1010
#include "mlir/Analysis/DataLayoutAnalysis.h"
1111
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1212
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
13+
#include "mlir/IR/SymbolTable.h"
1314

1415
using namespace mlir;
1516

1617
namespace {
17-
// TODO: Fix the LLVM utilities for looking up functions to take Operation*
18-
// with SymbolTable trait instead of ModuleOp and make similar change here. This
19-
// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead
20-
// of getParentOfType<ModuleOp> to pass down the operation.
2118
LLVM::LLVMFuncOp getNotalignedAllocFn(const LLVMTypeConverter *typeConverter,
22-
ModuleOp module, Type indexType) {
19+
Operation *module, Type indexType) {
2320
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
24-
2521
if (useGenericFn)
2622
return LLVM::lookupOrCreateGenericAllocFn(module, indexType);
2723

2824
return LLVM::lookupOrCreateMallocFn(module, indexType);
2925
}
3026

3127
LLVM::LLVMFuncOp getAlignedAllocFn(const LLVMTypeConverter *typeConverter,
32-
ModuleOp module, Type indexType) {
28+
Operation *module, Type indexType) {
3329
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;
3430

3531
if (useGenericFn)
@@ -79,7 +75,8 @@ std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign(
7975
// Allocate the underlying buffer.
8076
Type elementPtrType = this->getElementPtrType(memRefType);
8177
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn(
82-
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
78+
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
79+
getIndexType());
8380
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes);
8481

8582
Value allocatedPtr =
@@ -144,7 +141,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
144141

145142
Type elementPtrType = this->getElementPtrType(memRefType);
146143
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn(
147-
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType());
144+
getTypeConverter(), op->getParentWithTrait<OpTrait::SymbolTable>(),
145+
getIndexType());
148146
auto results = rewriter.create<LLVM::CallOp>(
149147
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes}));
150148

mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,49 +45,53 @@ static constexpr llvm::StringRef kGenericFree = "_mlir_memref_to_llvm_free";
4545
static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
4646

4747
/// Generic print function lookupOrCreate helper.
48-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
48+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(Operation *moduleOp,
49+
StringRef name,
4950
ArrayRef<Type> paramTypes,
5051
Type resultType, bool isVarArg) {
51-
auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
52+
assert(moduleOp->hasTrait<OpTrait::SymbolTable>() &&
53+
"expected SymbolTable operation");
54+
auto func = llvm::dyn_cast_or_null<LLVM::LLVMFuncOp>(
55+
SymbolTable::lookupSymbolIn(moduleOp, name));
5256
if (func)
5357
return func;
54-
OpBuilder b(moduleOp.getBodyRegion());
58+
OpBuilder b(moduleOp->getRegion(0));
5559
return b.create<LLVM::LLVMFuncOp>(
5660
moduleOp->getLoc(), name,
5761
LLVM::LLVMFunctionType::get(resultType, paramTypes, isVarArg));
5862
}
5963

60-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
64+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(Operation *moduleOp) {
6165
return lookupOrCreateFn(moduleOp, kPrintI64,
6266
IntegerType::get(moduleOp->getContext(), 64),
6367
LLVM::LLVMVoidType::get(moduleOp->getContext()));
6468
}
6569

66-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
70+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(Operation *moduleOp) {
6771
return lookupOrCreateFn(moduleOp, kPrintU64,
6872
IntegerType::get(moduleOp->getContext(), 64),
6973
LLVM::LLVMVoidType::get(moduleOp->getContext()));
7074
}
7175

72-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(ModuleOp moduleOp) {
76+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF16Fn(Operation *moduleOp) {
7377
return lookupOrCreateFn(moduleOp, kPrintF16,
7478
IntegerType::get(moduleOp->getContext(), 16), // bits!
7579
LLVM::LLVMVoidType::get(moduleOp->getContext()));
7680
}
7781

78-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(ModuleOp moduleOp) {
82+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintBF16Fn(Operation *moduleOp) {
7983
return lookupOrCreateFn(moduleOp, kPrintBF16,
8084
IntegerType::get(moduleOp->getContext(), 16), // bits!
8185
LLVM::LLVMVoidType::get(moduleOp->getContext()));
8286
}
8387

84-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
88+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(Operation *moduleOp) {
8589
return lookupOrCreateFn(moduleOp, kPrintF32,
8690
Float32Type::get(moduleOp->getContext()),
8791
LLVM::LLVMVoidType::get(moduleOp->getContext()));
8892
}
8993

90-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
94+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(Operation *moduleOp) {
9195
return lookupOrCreateFn(moduleOp, kPrintF64,
9296
Float64Type::get(moduleOp->getContext()),
9397
LLVM::LLVMVoidType::get(moduleOp->getContext()));
@@ -103,72 +107,72 @@ static LLVM::LLVMPointerType getVoidPtr(MLIRContext *context) {
103107
}
104108

105109
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintStringFn(
106-
ModuleOp moduleOp, std::optional<StringRef> runtimeFunctionName) {
110+
Operation *moduleOp, std::optional<StringRef> runtimeFunctionName) {
107111
return lookupOrCreateFn(moduleOp, runtimeFunctionName.value_or(kPrintString),
108112
getCharPtr(moduleOp->getContext()),
109113
LLVM::LLVMVoidType::get(moduleOp->getContext()));
110114
}
111115

112-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
116+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(Operation *moduleOp) {
113117
return lookupOrCreateFn(moduleOp, kPrintOpen, {},
114118
LLVM::LLVMVoidType::get(moduleOp->getContext()));
115119
}
116120

117-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
121+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(Operation *moduleOp) {
118122
return lookupOrCreateFn(moduleOp, kPrintClose, {},
119123
LLVM::LLVMVoidType::get(moduleOp->getContext()));
120124
}
121125

122-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
126+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(Operation *moduleOp) {
123127
return lookupOrCreateFn(moduleOp, kPrintComma, {},
124128
LLVM::LLVMVoidType::get(moduleOp->getContext()));
125129
}
126130

127-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
131+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(Operation *moduleOp) {
128132
return lookupOrCreateFn(moduleOp, kPrintNewline, {},
129133
LLVM::LLVMVoidType::get(moduleOp->getContext()));
130134
}
131135

132-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
136+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(Operation *moduleOp,
133137
Type indexType) {
134138
return LLVM::lookupOrCreateFn(moduleOp, kMalloc, indexType,
135139
getVoidPtr(moduleOp->getContext()));
136140
}
137141

138-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
142+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(Operation *moduleOp,
139143
Type indexType) {
140144
return LLVM::lookupOrCreateFn(moduleOp, kAlignedAlloc, {indexType, indexType},
141145
getVoidPtr(moduleOp->getContext()));
142146
}
143147

144-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
148+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(Operation *moduleOp) {
145149
return LLVM::lookupOrCreateFn(
146150
moduleOp, kFree, getVoidPtr(moduleOp->getContext()),
147151
LLVM::LLVMVoidType::get(moduleOp->getContext()));
148152
}
149153

150-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
154+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(Operation *moduleOp,
151155
Type indexType) {
152156
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlloc, indexType,
153157
getVoidPtr(moduleOp->getContext()));
154158
}
155159

156160
LLVM::LLVMFuncOp
157-
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
161+
mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(Operation *moduleOp,
158162
Type indexType) {
159163
return LLVM::lookupOrCreateFn(moduleOp, kGenericAlignedAlloc,
160164
{indexType, indexType},
161165
getVoidPtr(moduleOp->getContext()));
162166
}
163167

164-
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
168+
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(Operation *moduleOp) {
165169
return LLVM::lookupOrCreateFn(
166170
moduleOp, kGenericFree, getVoidPtr(moduleOp->getContext()),
167171
LLVM::LLVMVoidType::get(moduleOp->getContext()));
168172
}
169173

170174
LLVM::LLVMFuncOp
171-
mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
175+
mlir::LLVM::lookupOrCreateMemRefCopyFn(Operation *moduleOp, Type indexType,
172176
Type unrankedDescriptorType) {
173177
return LLVM::lookupOrCreateFn(
174178
moduleOp, kMemRefCopy,

0 commit comments

Comments
 (0)