Skip to content

Commit 953b07f

Browse files
authored
[mlir] AMDGPUToROCDL: RawBufferOpLowering fixes (llvm#120642)
1. We can use `getNumElements()` only for memrefs with trivial layout. 2. Buffer ops expecting sizes in i32 but descriptor values can be either i32 or i64, add appropriate casts. This implementation is not ideal as it can overflow, but it's still better than generating broken IR.
1 parent 5845298 commit 953b07f

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
9191
Type llvmI32 = this->typeConverter->convertType(i32);
9292
Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
9393

94+
auto toI32 = [&](Value val) -> Value {
95+
if (val.getType() == llvmI32)
96+
return val;
97+
98+
return rewriter.create<LLVM::TruncOp>(loc, llvmI32, val);
99+
};
100+
94101
int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
95102
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
96103

@@ -166,22 +173,22 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
166173
Value stride = rewriter.create<LLVM::ConstantOp>(
167174
loc, llvmI16, rewriter.getI16IntegerAttr(0));
168175
Value numRecords;
169-
if (memrefType.hasStaticShape()) {
176+
if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
170177
numRecords = createI32Constant(
171178
rewriter, loc,
172179
static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
173180
} else {
174181
Value maxIndex;
175182
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
176-
Value size = memrefDescriptor.size(rewriter, loc, i);
177-
Value stride = memrefDescriptor.stride(rewriter, loc, i);
183+
Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
184+
Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
178185
stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
179186
Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
180187
maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
181188
maxThisDim)
182189
: maxThisDim;
183190
}
184-
numRecords = rewriter.create<LLVM::TruncOp>(loc, llvmI32, maxIndex);
191+
numRecords = maxIndex;
185192
}
186193

187194
// Flag word:
@@ -218,7 +225,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
218225
Value strideOp;
219226
if (ShapedType::isDynamic(strides[i])) {
220227
strideOp = rewriter.create<LLVM::MulOp>(
221-
loc, memrefDescriptor.stride(rewriter, loc, i), byteWidthConst);
228+
loc, toI32(memrefDescriptor.stride(rewriter, loc, i)),
229+
byteWidthConst);
222230
} else {
223231
strideOp =
224232
createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
@@ -240,7 +248,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
240248
sgprOffset = createI32Constant(rewriter, loc, 0);
241249
if (ShapedType::isDynamic(offset))
242250
sgprOffset = rewriter.create<LLVM::AddOp>(
243-
loc, memrefDescriptor.offset(rewriter, loc), sgprOffset);
251+
loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset);
244252
else if (offset > 0)
245253
sgprOffset = rewriter.create<LLVM::AddOp>(
246254
loc, sgprOffset, createI32Constant(rewriter, loc, offset));

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,25 @@ func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
3030
func.return %0 : i32
3131
}
3232

33+
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_strided
34+
func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<64xi32, strided<[?], offset: ?>>, %idx: i32) -> i32 {
35+
// CHECK-DAG: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
36+
// CHECK-DAG: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
37+
// CHECK: %[[size:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
38+
// CHECK: %[[size32:.*]] = llvm.trunc %[[size]] : i64 to i32
39+
// CHECK: %[[stride:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
40+
// CHECK: %[[stride32:.*]] = llvm.trunc %[[stride]] : i64 to i32
41+
// CHECK: %[[tmp:.*]] = llvm.mul %[[stride32]], %[[elem_size]] : i32
42+
// CHECK: %[[numRecords:.*]] = llvm.mul %[[size32]], %[[tmp]] : i32
43+
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
44+
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
45+
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[rstride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
46+
// CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
47+
// CHECK: return %[[ret]]
48+
%0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32, strided<[?], offset: ?>>, i32 -> i32
49+
func.return %0 : i32
50+
}
51+
3352
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_oob_off
3453
func.func @gpu_gcn_raw_buffer_load_i32_oob_off(%buf: memref<64xi32>, %idx: i32) -> i32 {
3554
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)

0 commit comments

Comments
 (0)