Skip to content

Commit d3a9807

Browse files
committed
[mlir] Remove most uses of LLVMDialect::getModule
This prepares for the removal of llvm::Module and LLVMContext from the mlir::LLVMDialect. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D85371
1 parent d40c44e commit d3a9807

File tree

8 files changed

+24
-39
lines changed

8 files changed

+24
-39
lines changed

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,7 @@ class LLVMTypeConverter : public TypeConverter {
118118
unsigned getPointerBitwidth(unsigned addressSpace = 0);
119119

120120
protected:
121-
/// LLVM IR module used to parse/create types.
122-
llvm::Module *module;
121+
/// Pointer to the LLVM dialect.
123122
LLVM::LLVMDialect *llvmDialect;
124123

125124
private:
@@ -400,9 +399,6 @@ class ConvertToLLVMPattern : public ConversionPattern {
400399
/// Returns the LLVM IR context.
401400
llvm::LLVMContext &getContext() const;
402401

403-
/// Returns the LLVM IR module associated with the LLVM dialect.
404-
llvm::Module &getModule() const;
405-
406402
/// Gets the MLIR type wrapping the LLVM integer type whose bit width is
407403
/// defined by the used type converter.
408404
LLVM::LLVMType getIndexType() const;
@@ -437,8 +433,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
437433
ConversionPatternRewriter &rewriter) const;
438434

439435
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
440-
ValueRange indices, ConversionPatternRewriter &rewriter,
441-
llvm::Module &module) const;
436+
ValueRange indices,
437+
ConversionPatternRewriter &rewriter) const;
442438

443439
/// Returns the type of a pointer to an element of the memref.
444440
Type getElementPtrType(MemRefType type) const;

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def LLVM_Dialect : Dialect {
2525
llvm::LLVMContext &getLLVMContext();
2626
llvm::Module &getLLVMModule();
2727
llvm::sys::SmartMutex<true> &getLLVMContextMutex();
28+
const llvm::DataLayout &getDataLayout();
2829

2930
private:
3031
friend LLVMType;

mlir/lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,15 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
6666
private:
6767
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
6868

69-
llvm::LLVMContext &getLLVMContext() {
70-
return getLLVMDialect()->getLLVMContext();
71-
}
72-
7369
void initializeCachedTypes() {
74-
const llvm::Module &module = llvmDialect->getLLVMModule();
7570
llvmVoidType = LLVM::LLVMType::getVoidTy(llvmDialect);
7671
llvmPointerType = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
7772
llvmPointerPointerType = llvmPointerType.getPointerTo();
7873
llvmInt8Type = LLVM::LLVMType::getInt8Ty(llvmDialect);
7974
llvmInt32Type = LLVM::LLVMType::getInt32Ty(llvmDialect);
8075
llvmInt64Type = LLVM::LLVMType::getInt64Ty(llvmDialect);
8176
llvmIntPtrType = LLVM::LLVMType::getIntNTy(
82-
llvmDialect, module.getDataLayout().getPointerSizeInBits());
77+
llvmDialect, llvmDialect->getDataLayout().getPointerSizeInBits());
8378
}
8479

8580
LLVM::LLVMType getVoidType() { return llvmVoidType; }
@@ -95,9 +90,9 @@ class GpuLaunchFuncToGpuRuntimeCallsPass
9590
LLVM::LLVMType getInt64Type() { return llvmInt64Type; }
9691

9792
LLVM::LLVMType getIntPtrType() {
98-
const llvm::Module &module = getLLVMDialect()->getLLVMModule();
9993
return LLVM::LLVMType::getIntNTy(
100-
getLLVMDialect(), module.getDataLayout().getPointerSizeInBits());
94+
getLLVMDialect(),
95+
getLLVMDialect()->getDataLayout().getPointerSizeInBits());
10196
}
10297

10398
// Allocate a void pointer on the stack.

mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,6 @@ class VulkanLaunchFuncToVulkanCallsPass
5959
private:
6060
LLVM::LLVMDialect *getLLVMDialect() { return llvmDialect; }
6161

62-
llvm::LLVMContext &getLLVMContext() {
63-
return getLLVMDialect()->getLLVMContext();
64-
}
65-
6662
void initializeCachedTypes() {
6763
llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();
6864
llvmFloatType = LLVM::LLVMType::getFloatTy(llvmDialect);

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
128128
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
129129
options(options) {
130130
assert(llvmDialect && "LLVM IR dialect is not registered");
131-
module = &llvmDialect->getLLVMModule();
132131
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
133132
this->options.indexBitwidth =
134-
module->getDataLayout().getPointerSizeInBits();
133+
llvmDialect->getDataLayout().getPointerSizeInBits();
135134

136135
// Register conversions for the standard types.
137136
addConversion([&](ComplexType type) { return convertComplexType(type); });
@@ -196,15 +195,15 @@ MLIRContext &LLVMTypeConverter::getContext() {
196195

197196
/// Get the LLVM context.
198197
llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
199-
return module->getContext();
198+
return llvmDialect->getLLVMContext();
200199
}
201200

202201
LLVM::LLVMType LLVMTypeConverter::getIndexType() {
203202
return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth());
204203
}
205204

206205
unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
207-
return module->getDataLayout().getPointerSizeInBits(addressSpace);
206+
return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace);
208207
}
209208

210209
Type LLVMTypeConverter::convertIndexType(IndexType type) {
@@ -849,10 +848,6 @@ llvm::LLVMContext &ConvertToLLVMPattern::getContext() const {
849848
return typeConverter.getLLVMContext();
850849
}
851850

852-
llvm::Module &ConvertToLLVMPattern::getModule() const {
853-
return getDialect().getLLVMModule();
854-
}
855-
856851
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
857852
return typeConverter.getIndexType();
858853
}
@@ -910,10 +905,9 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
910905
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
911906
}
912907

913-
Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
914-
Value memRefDesc, ValueRange indices,
915-
ConversionPatternRewriter &rewriter,
916-
llvm::Module &module) const {
908+
Value ConvertToLLVMPattern::getDataPtr(
909+
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
910+
ConversionPatternRewriter &rewriter) const {
917911
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
918912
int64_t offset;
919913
SmallVector<int64_t, 4> strides;
@@ -2451,7 +2445,7 @@ struct LoadOpLowering : public LoadStoreOpLowering<LoadOp> {
24512445
auto type = loadOp.getMemRefType();
24522446

24532447
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
2454-
transformed.indices(), rewriter, getModule());
2448+
transformed.indices(), rewriter);
24552449
rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, dataPtr);
24562450
return success();
24572451
}
@@ -2469,7 +2463,7 @@ struct StoreOpLowering : public LoadStoreOpLowering<StoreOp> {
24692463
StoreOp::Adaptor transformed(operands);
24702464

24712465
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
2472-
transformed.indices(), rewriter, getModule());
2466+
transformed.indices(), rewriter);
24732467
rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, transformed.value(),
24742468
dataPtr);
24752469
return success();
@@ -2489,7 +2483,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering<PrefetchOp> {
24892483
auto type = prefetchOp.getMemRefType();
24902484

24912485
Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(),
2492-
transformed.indices(), rewriter, getModule());
2486+
transformed.indices(), rewriter);
24932487

24942488
// Replace with llvm.prefetch.
24952489
auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32));
@@ -3086,7 +3080,7 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
30863080
auto resultType = adaptor.value().getType();
30873081
auto memRefType = atomicOp.getMemRefType();
30883082
auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(),
3089-
adaptor.indices(), rewriter, getModule());
3083+
adaptor.indices(), rewriter);
30903084
rewriter.replaceOpWithNewOp<LLVM::AtomicRMWOp>(
30913085
op, resultType, *maybeKind, dataPtr, adaptor.value(),
30923086
LLVM::AtomicOrdering::acq_rel);
@@ -3152,7 +3146,7 @@ struct GenericAtomicRMWOpLowering
31523146
rewriter.setInsertionPointToEnd(initBlock);
31533147
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
31543148
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
3155-
adaptor.indices(), rewriter, getModule());
3149+
adaptor.indices(), rewriter);
31563150
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
31573151
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);
31583152

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op,
131131
LLVM::LLVMDialect *dialect = typeConverter.getDialect();
132132
align = LLVM::TypeToLLVMIRTranslator(dialect->getLLVMContext())
133133
.getPreferredAlignment(elementTy.cast<LLVM::LLVMType>(),
134-
dialect->getLLVMModule().getDataLayout());
134+
dialect->getDataLayout());
135135
return success();
136136
}
137137

@@ -1152,7 +1152,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
11521152
// address space 0.
11531153
// TODO: support alignment when possible.
11541154
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
1155-
adaptor.indices(), rewriter, getModule());
1155+
adaptor.indices(), rewriter);
11561156
auto vecTy =
11571157
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
11581158
Value vectorDataPtr;

mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class VectorTransferConversion : public ConvertToLLVMPattern {
103103
// indices, so no need to calculat offset size in bytes again in
104104
// the MUBUF instruction.
105105
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
106-
adaptor.indices(), rewriter, getModule());
106+
adaptor.indices(), rewriter);
107107

108108
// 1. Create and fill a <4 x i32> dwordConfig with:
109109
// 1st two elements holding the address of dataPtr.

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,6 +1741,9 @@ llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
17411741
llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
17421742
return impl->mutex;
17431743
}
1744+
const llvm::DataLayout &LLVMDialect::getDataLayout() {
1745+
return impl->module.getDataLayout();
1746+
}
17441747

17451748
/// Parse a type registered to this dialect.
17461749
Type LLVMDialect::parseType(DialectAsmParser &parser) const {

0 commit comments

Comments
 (0)