Skip to content

Commit 620e2bb

Browse files
[mlir][LLVM] NFC - Remove createIndexConstant method
This revision removes the createIndexConstant method, which implicitly creates constants of the getIndexType type and updates all uses to the more explicit createIndexAttrConstant which requires an explicit Type parameter. This is an NFC step towards entangling index type conversion in LLVM lowering. The selection of which index type to use requires finer granularity than the existing implementations which all rely on pass level flags and end up in mismatches, especially on GPUs with multiple address spaces of different capacities. This revision also includes an NFC fix to MemRefToLLVM.cpp that prevents a crash in cases where an integer memory space cannot be derived for a MemRef. Differential Revision: https://reviews.llvm.org/D156854
1 parent 821687a commit 620e2bb

File tree

7 files changed

+82
-65
lines changed

7 files changed

+82
-65
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,6 @@ class ConvertToLLVMPattern : public ConversionPattern {
6565
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
6666
Type resultType, int64_t value);
6767

68-
/// Create an LLVM dialect operation defining the given index constant.
69-
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
70-
uint64_t value) const;
71-
7268
// This is a strided getElementPtr variant that linearizes subscripts as:
7369
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
7470
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
@@ -155,9 +151,9 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
155151
ConversionPatternRewriter &rewriter) const final {
156152
if constexpr (SourceOp::hasProperties())
157153
return rewrite(cast<SourceOp>(op),
158-
OpAdaptor(operands, op->getDiscardableAttrDictionary(),
159-
cast<SourceOp>(op).getProperties()),
160-
rewriter);
154+
OpAdaptor(operands, op->getDiscardableAttrDictionary(),
155+
cast<SourceOp>(op).getProperties()),
156+
rewriter);
161157
rewrite(cast<SourceOp>(op),
162158
OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
163159
}

mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace mlir {
1515

1616
/// Lowering for memory allocation ops.
1717
struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
18-
using ConvertToLLVMPattern::createIndexConstant;
18+
using ConvertToLLVMPattern::createIndexAttrConstant;
1919
using ConvertToLLVMPattern::getIndexType;
2020
using ConvertToLLVMPattern::getVoidPtrType;
2121

@@ -43,7 +43,9 @@ struct AllocationOpLLVMLowering : public ConvertToLLVMPattern {
4343
MemRefType memRefType = op.getType();
4444
Value alignment;
4545
if (auto alignmentAttr = op.getAlignment()) {
46-
alignment = createIndexConstant(rewriter, loc, *alignmentAttr);
46+
Type indexType = getIndexType();
47+
alignment =
48+
createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr);
4749
} else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) {
4850
// In the case where no alignment is specified, we may want to override
4951
// `malloc's` behavior. `malloc` typically aligns at the size of the

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
168168
Value lowHalf = rewriter.create<LLVM::TruncOp>(loc, llvmI32, ptrAsInt);
169169
resource = rewriter.create<LLVM::InsertElementOp>(
170170
loc, llvm4xI32, resource, lowHalf,
171-
this->createIndexConstant(rewriter, loc, 0));
171+
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 0));
172172

173173
// Bits 48-63 are used both for the stride of the buffer and (on gfx10) for
174174
// enabling swizzling. Prevent the high bits of pointers from accidentally
@@ -180,7 +180,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
180180
createI32Constant(rewriter, loc, 0x0000ffff));
181181
resource = rewriter.create<LLVM::InsertElementOp>(
182182
loc, llvm4xI32, resource, highHalfTruncated,
183-
this->createIndexConstant(rewriter, loc, 1));
183+
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 1));
184184

185185
Value numRecords;
186186
if (memrefType.hasStaticShape()) {
@@ -202,7 +202,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
202202
}
203203
resource = rewriter.create<LLVM::InsertElementOp>(
204204
loc, llvm4xI32, resource, numRecords,
205-
this->createIndexConstant(rewriter, loc, 2));
205+
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 2));
206206

207207
// Final word:
208208
// bits 0-11: dst sel, ignored by these intrinsics
@@ -227,7 +227,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
227227
Value word3Const = createI32Constant(rewriter, loc, word3);
228228
resource = rewriter.create<LLVM::InsertElementOp>(
229229
loc, llvm4xI32, resource, word3Const,
230-
this->createIndexConstant(rewriter, loc, 3));
230+
this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 3));
231231
args.push_back(resource);
232232

233233
// Indexing (voffset)

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,10 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
6767
protected:
6868
Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
6969
MemRefType type, MemRefDescriptor desc) const {
70+
Type indexType = ConvertToLLVMPattern::getIndexType();
7071
return type.hasStaticShape()
71-
? ConvertToLLVMPattern::createIndexConstant(
72-
rewriter, loc, type.getNumElements())
72+
? ConvertToLLVMPattern::createIndexAttrConstant(
73+
rewriter, loc, indexType, type.getNumElements())
7374
// For identity maps (verified by caller), the number of
7475
// elements is stride[0] * size[0].
7576
: rewriter.create<LLVM::MulOp>(loc,

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,6 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
6060
builder.getIndexAttr(value));
6161
}
6262

63-
Value ConvertToLLVMPattern::createIndexConstant(
64-
ConversionPatternRewriter &builder, Location loc, uint64_t value) const {
65-
return createIndexAttrConstant(builder, loc, getIndexType(), value);
66-
}
67-
6863
Value ConvertToLLVMPattern::getStridedElementPtr(
6964
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
7065
ConversionPatternRewriter &rewriter) const {
@@ -79,13 +74,15 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
7974
Value base =
8075
memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
8176

77+
Type indexType = getIndexType();
8278
Value index;
8379
for (int i = 0, e = indices.size(); i < e; ++i) {
8480
Value increment = indices[i];
8581
if (strides[i] != 1) { // Skip if stride is 1.
86-
Value stride = ShapedType::isDynamic(strides[i])
87-
? memRefDescriptor.stride(rewriter, loc, i)
88-
: createIndexConstant(rewriter, loc, strides[i]);
82+
Value stride =
83+
ShapedType::isDynamic(strides[i])
84+
? memRefDescriptor.stride(rewriter, loc, i)
85+
: createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
8986
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
9087
}
9188
index =
@@ -130,15 +127,17 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
130127

131128
sizes.reserve(memRefType.getRank());
132129
unsigned dynamicIndex = 0;
130+
Type indexType = getIndexType();
133131
for (int64_t size : memRefType.getShape()) {
134-
sizes.push_back(size == ShapedType::kDynamic
135-
? dynamicSizes[dynamicIndex++]
136-
: createIndexConstant(rewriter, loc, size));
132+
sizes.push_back(
133+
size == ShapedType::kDynamic
134+
? dynamicSizes[dynamicIndex++]
135+
: createIndexAttrConstant(rewriter, loc, indexType, size));
137136
}
138137

139138
// Strides: iterate sizes in reverse order and multiply.
140139
int64_t stride = 1;
141-
Value runningStride = createIndexConstant(rewriter, loc, 1);
140+
Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1);
142141
strides.resize(memRefType.getRank());
143142
for (auto i = memRefType.getRank(); i-- > 0;) {
144143
strides[i] = runningStride;
@@ -158,7 +157,7 @@ void ConvertToLLVMPattern::getMemRefDescriptorSizes(
158157
runningStride =
159158
rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
160159
else
161-
runningStride = createIndexConstant(rewriter, loc, stride);
160+
runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride);
162161
}
163162
if (sizeInBytes) {
164163
// Buffer size in bytes.
@@ -195,22 +194,25 @@ Value ConvertToLLVMPattern::getNumElements(
195194
static_cast<ssize_t>(dynamicSizes.size()) &&
196195
"dynamicSizes size doesn't match dynamic sizes count in memref shape");
197196

197+
Type indexType = getIndexType();
198198
Value numElements = memRefType.getRank() == 0
199-
? createIndexConstant(rewriter, loc, 1)
199+
? createIndexAttrConstant(rewriter, loc, indexType, 1)
200200
: nullptr;
201201
unsigned dynamicIndex = 0;
202202

203203
// Compute the total number of memref elements.
204204
for (int64_t staticSize : memRefType.getShape()) {
205205
if (numElements) {
206-
Value size = staticSize == ShapedType::kDynamic
207-
? dynamicSizes[dynamicIndex++]
208-
: createIndexConstant(rewriter, loc, staticSize);
206+
Value size =
207+
staticSize == ShapedType::kDynamic
208+
? dynamicSizes[dynamicIndex++]
209+
: createIndexAttrConstant(rewriter, loc, indexType, staticSize);
209210
numElements = rewriter.create<LLVM::MulOp>(loc, numElements, size);
210211
} else {
211-
numElements = staticSize == ShapedType::kDynamic
212-
? dynamicSizes[dynamicIndex++]
213-
: createIndexConstant(rewriter, loc, staticSize);
212+
numElements =
213+
staticSize == ShapedType::kDynamic
214+
? dynamicSizes[dynamicIndex++]
215+
: createIndexAttrConstant(rewriter, loc, indexType, staticSize);
214216
}
215217
}
216218
return numElements;
@@ -231,8 +233,9 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
231233
memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);
232234

233235
// Field 3: Offset in aligned pointer.
234-
memRefDescriptor.setOffset(rewriter, loc,
235-
createIndexConstant(rewriter, loc, 0));
236+
Type indexType = getIndexType();
237+
memRefDescriptor.setOffset(
238+
rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0));
236239

237240
// Fields 4: Sizes.
238241
for (const auto &en : llvm::enumerate(sizes))

mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ bool AllocationOpLLVMLowering::isMemRefSizeMultipleOf(
138138
Value AllocationOpLLVMLowering::allocateBufferAutoAlign(
139139
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes,
140140
Operation *op, const DataLayout *defaultLayout, int64_t alignment) const {
141-
Value allocAlignment = createIndexConstant(rewriter, loc, alignment);
141+
Value allocAlignment =
142+
createIndexAttrConstant(rewriter, loc, getIndexType(), alignment);
142143

143144
MemRefType memRefType = getMemRefResultType(op);
144145
// Function aligned_alloc requires size to be a multiple of alignment; we pad

0 commit comments

Comments
 (0)