
Commit 0c1c49f

fabianmcg and kuhar authored
[mlir][AMDGPU] Fix raw buffer ptr ops lowering (#122293)
[mlir][AMDGPU] Fix raw buffer ptr ops lowering (#122293)

This patch fixes several bugs in the lowering of AMDGPU raw buffer operations:

- Incorrect handling of the memref offset, causing errors when using subviews.
- Use of MaximumOp (a float-specific op) to calculate the number of records.
- A wrong number of records in the static-shape case.
- Broken lowering when the index bitwidth is i64.

It also switches to MLIR's data layout to get the type size.

Co-authored-by: Jakub Kuderski <[email protected]>
1 parent ec3525f commit 0c1c49f

File tree

2 files changed: +103 -66 lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 72 additions & 51 deletions
@@ -30,10 +30,23 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+/// Convert an unsigned number `val` to i32.
+static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value val) {
+  IntegerType i32 = rewriter.getI32Type();
+  // Force check that `val` is of int type.
+  auto valTy = cast<IntegerType>(val.getType());
+  if (i32 == valTy)
+    return val;
+  return valTy.getWidth() > 32
+             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
+             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+}
+
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
-  Type llvmI32 = rewriter.getI32Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
+  Type i32 = rewriter.getI32Type();
+  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
 }
 
 static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
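The new helper replaces the pattern's old toI32 lambda, which truncated unconditionally; it now zero-extends values narrower than 32 bits instead. A scalar sketch of that width dispatch (a hypothetical helper for illustration, not part of the patch):

#include <cstdint>

// Hypothetical scalar analogue of convertUnsignedToI32: values wider than
// 32 bits are truncated (like LLVM::TruncOp), narrower unsigned values are
// zero-extended (like LLVM::ZExtOp), and 32-bit values pass through.
uint32_t toUnsigned32(uint64_t val, unsigned widthInBits) {
  if (widthInBits == 32)
    return static_cast<uint32_t>(val);      // Same width: no-op.
  if (widthInBits > 32)
    return static_cast<uint32_t>(val);      // Truncate the high bits away.
  uint64_t mask = (uint64_t{1} << widthInBits) - 1;
  return static_cast<uint32_t>(val & mask); // Zero-extend a narrow value.
}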
@@ -42,6 +55,27 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
   return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
 }
 
+/// Returns the linear index used to access an element in the memref.
+static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
+                               Location loc, MemRefDescriptor &memRefDescriptor,
+                               ValueRange indices, ArrayRef<int64_t> strides) {
+  IntegerType i32 = rewriter.getI32Type();
+  Value index;
+  for (auto [i, increment, stride] : llvm::enumerate(indices, strides)) {
+    if (stride != 1) { // Skip if stride is 1.
+      Value strideValue =
+          ShapedType::isDynamic(stride)
+              ? convertUnsignedToI32(rewriter, loc,
+                                     memRefDescriptor.stride(rewriter, loc, i))
+              : rewriter.create<LLVM::ConstantOp>(loc, i32, stride);
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, strideValue);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  }
+  return index ? index : createI32Constant(rewriter, loc, 0);
+}
+
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
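getLinearIndexI32 computes the linear element index as the dot product of the indices with the strides, constant-folding away unit-stride multiplies. A scalar model of the arithmetic the emitted IR performs at runtime (hypothetical function, for illustration only):

#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar model of getLinearIndexI32: linear = sum_i indices[i] * strides[i].
// A stride of 1 skips the multiply, mirroring the constant folding in the
// lowering; an empty index list (rank-0 memref) yields 0.
uint32_t linearIndex(const std::vector<uint32_t> &indices,
                     const std::vector<uint32_t> &strides) {
  uint32_t linear = 0;
  for (size_t i = 0; i < indices.size(); ++i)
    linear += strides[i] == 1 ? indices[i] : indices[i] * strides[i];
  return linear;
}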
@@ -88,17 +122,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
     Type i32 = rewriter.getI32Type();
-    Type llvmI32 = this->typeConverter->convertType(i32);
-    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
+    Type i16 = rewriter.getI16Type();
 
-    auto toI32 = [&](Value val) -> Value {
-      if (val.getType() == llvmI32)
-        return val;
-
-      return rewriter.create<LLVM::TruncOp>(loc, llvmI32, val);
-    };
-
-    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
+    // Get the type size in bytes.
+    DataLayout dataLayout = DataLayout::closest(gpuOp);
+    int64_t elementByteWidth =
+        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
     Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
 
     // If we want to load a vector<NxT> with total size <= 32
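Querying the data layout instead of getElementTypeBitWidth() makes the byte width follow the module's data layout rather than the type's nominal bit width. As a rough model of why that can differ (an illustrative assumption about typical layouts, not MLIR's DataLayout implementation):

#include <cstdint>

// Rough model of a data-layout size query: in-memory storage is rounded up
// to whole bytes, so a type's storage size can exceed nominalBits / 8.
// Illustrative assumption only, not the patch's code.
int64_t storageSizeInBytes(int64_t nominalBits) {
  return (nominalBits + 7) / 8; // e.g. a 1-bit or 4-bit int occupies one byte
}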
@@ -114,7 +143,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     }
     if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t vecLen = dataVector.getNumElements();
-      uint32_t elemBits = dataVector.getElementTypeBitWidth();
+      uint32_t elemBits =
+          dataLayout.getTypeSizeInBits(dataVector.getElementType());
       uint32_t totalBits = elemBits * vecLen;
       bool usePackedFp16 =
           isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
@@ -167,28 +197,36 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     MemRefDescriptor memrefDescriptor(memref);
 
-    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
+    Value ptr = memrefDescriptor.bufferPtr(
+        rewriter, loc, *this->getTypeConverter(), memrefType);
     // The stride value is always 0 for raw buffers. This also disables
     // swizling.
     Value stride = rewriter.create<LLVM::ConstantOp>(
-        loc, llvmI16, rewriter.getI16IntegerAttr(0));
+        loc, i16, rewriter.getI16IntegerAttr(0));
+    // Get the number of elements.
     Value numRecords;
-    if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
-      numRecords = createI32Constant(
-          rewriter, loc,
-          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
+    if (memrefType.hasStaticShape() &&
+        !llvm::any_of(strides, ShapedType::isDynamic)) {
+      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+      ArrayRef<int64_t> shape = memrefType.getShape();
+      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+        size = std::max(shape[i] * strides[i], size);
+      size = size * elementByteWidth;
+      assert(size < std::numeric_limits<uint32_t>::max() &&
+             "the memref buffer is too large");
+      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
     } else {
       Value maxIndex;
       for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
-        Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
-        Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
-        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
+        Value size = memrefDescriptor.size(rewriter, loc, i);
+        Value stride = memrefDescriptor.stride(rewriter, loc, i);
         Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
-        maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
-                                                               maxThisDim)
-                            : maxThisDim;
+        maxIndex =
+            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                     : maxThisDim;
       }
-      numRecords = maxIndex;
+      numRecords = rewriter.create<LLVM::MulOp>(
+          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
     }
 
     // Flag word:
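In both branches, numRecords is now the maximum over dimensions of size[i] * stride[i], scaled by the element byte width: the smallest byte extent covering every addressable element of a strided layout. A scalar sketch of the formula (hypothetical function mirroring the static-shape branch):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar sketch of the numRecords computation: the buffer extent in bytes is
// max_i(shape[i] * strides[i]) * elementByteWidth; a rank-0 memref holds
// exactly one element.
int64_t numRecordsInBytes(const std::vector<int64_t> &shape,
                          const std::vector<int64_t> &strides,
                          int64_t elementByteWidth) {
  int64_t size = shape.empty() ? 1 : 0;
  for (size_t i = 0; i < shape.size(); ++i)
    size = std::max(shape[i] * strides[i], size);
  return size * elementByteWidth;
}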
@@ -218,40 +256,23 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     args.push_back(resource);
 
     // Indexing (voffset)
-    Value voffset = createI32Constant(rewriter, loc, 0);
-    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
-      size_t i = pair.index();
-      Value index = pair.value();
-      Value strideOp;
-      if (ShapedType::isDynamic(strides[i])) {
-        strideOp = rewriter.create<LLVM::MulOp>(
-            loc, toI32(memrefDescriptor.stride(rewriter, loc, i)),
-            byteWidthConst);
-      } else {
-        strideOp =
-            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
-      }
-      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
-      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
-    }
-    if (adaptor.getIndexOffset()) {
-      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
-      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
+    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
+                                      adaptor.getIndices(), strides);
+    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
+        indexOffset && *indexOffset > 0) {
+      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
       voffset =
           voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                   : extraOffsetConst;
     }
+    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
     args.push_back(voffset);
 
+    // SGPR offset.
     Value sgprOffset = adaptor.getSgprOffset();
     if (!sgprOffset)
       sgprOffset = createI32Constant(rewriter, loc, 0);
-    if (ShapedType::isDynamic(offset))
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset);
-    else if (offset > 0)
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
+    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
     args.push_back(sgprOffset);
 
     // bit 0: GLC = 0 (atomics drop value, less coherency)
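After this change, both offsets are computed in elements first and scaled to bytes exactly once at the end, which is why the old per-dimension byteWidthConst multiplies disappear. A scalar sketch of the final offset arithmetic (hypothetical function names):

#include <cstdint>

// Scalar sketch of the voffset computation: the linear element index plus the
// optional static index offset, scaled to bytes in a single multiply.
uint32_t voffsetBytes(uint32_t linearIndex, uint32_t indexOffset,
                      uint32_t elementByteWidth) {
  return (linearIndex + indexOffset) * elementByteWidth;
}

// The SGPR offset is likewise treated as an element count and scaled once.
uint32_t sgprOffsetBytes(uint32_t sgprOffset, uint32_t elementByteWidth) {
  return sgprOffset * elementByteWidth;
}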

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 31 additions & 15 deletions
@@ -31,21 +31,37 @@ func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
 }
 
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_strided
-func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<64xi32, strided<[?], offset: ?>>, %idx: i32) -> i32 {
-  // CHECK-DAG: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
-  // CHECK-DAG: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
-  // CHECK: %[[size:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[size32:.*]] = llvm.trunc %[[size]] : i64 to i32
-  // CHECK: %[[stride:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[stride32:.*]] = llvm.trunc %[[stride]] : i64 to i32
-  // CHECK: %[[tmp:.*]] = llvm.mul %[[stride32]], %[[elem_size]] : i32
-  // CHECK: %[[numRecords:.*]] = llvm.mul %[[size32]], %[[tmp]] : i32
-  // GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
-  // RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
-  // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[rstride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
-  // CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
-  // CHECK: return %[[ret]]
-  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32, strided<[?], offset: ?>>, i32 -> i32
+func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[?, ?], offset: ?>>, %i: i32, %j: i32) -> i32 {
+  // CHECK: %[[descriptor:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<16x16xi32, strided<[?, ?], offset: ?>> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %[[descriptor]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[offset:.*]] = llvm.extractvalue %[[descriptor]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+  // CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
+  // CHECK: %[[sz_i:.*]] = llvm.extractvalue %[[descriptor]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[stride_i:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[ext_i:.*]] = llvm.mul %[[sz_i]], %[[stride_i]] : i64
+  // CHECK: %[[sz_j:.*]] = llvm.extractvalue %[[descriptor]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[stride_j:.*]] = llvm.extractvalue %[[descriptor]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[ext_j:.*]] = llvm.mul %[[sz_j]], %[[stride_j]] : i64
+  // CHECK: %[[num_records:.*]] = llvm.intr.umax(%[[ext_i]], %[[ext_j]]) : (i64, i64) -> i64
+  // CHECK: %[[num_rec_i32:.*]] = llvm.trunc %[[num_records]] : i64 to i32
+  // CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size]] : i32
+  // CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %[[stride]], %[[num_rec_bytes_i32]], %{{.*}} : !llvm.ptr to <8>
+  // CHECK: %[[stride_i_1:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[stride_i_i32:.*]] = llvm.trunc %[[stride_i_1]] : i64 to i32
+  // CHECK: %[[t_0:.*]] = llvm.mul %{{.*}}, %[[stride_i_i32]] : i32
+  // CHECK: %[[stride_j_1:.*]] = llvm.extractvalue %[[descriptor]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[stride_j_i32:.*]] = llvm.trunc %[[stride_j_1]] : i64 to i32
+  // CHECK: %[[t_1:.*]] = llvm.mul %{{.*}}, %[[stride_j_i32]] : i32
+  // CHECK: %[[index:.*]] = llvm.add %[[t_0]], %[[t_1]] : i32
+  // CHECK: %[[vgpr_off:.*]] = llvm.mul %[[index]], %[[elem_size]] : i32
+  // CHECK: %[[zero_0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[sgpr_off:.*]] = llvm.mul %[[zero_0]], %[[elem_size]] : i32
+  // CHECK: %[[zero_1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[v:.*]] = rocdl.raw.ptr.buffer.load %[[rsrc]], %[[vgpr_off]], %[[sgpr_off]], %[[zero_1]] : i32
+  // CHECK: return %[[v]] : i32
+  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%i, %j] : memref<16x16xi32, strided<[?, ?], offset: ?>>, i32, i32 -> i32
   func.return %0 : i32
 }