[mlir][AMDGPU] Fix raw buffer ptr ops lowering #122293
Conversation
This patch fixes several bugs in the lowering of AMDGPU raw buffer operations. These bugs include:
- Incorrect handling of the memref offset, which caused errors when using subviews. Furthermore, the memref offset cannot be assumed to fit in an SGPR, since it can be a thread-dependent value.
- Use of MaximumOp (a float-specific op) to compute the number of records.
- An incorrect number of records in the static-shape case.
- Broken lowering when the index bitwidth is i64.

Furthermore, this patch also switches to using MLIR's data layout to get the type size.
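For illustration, here is a minimal sketch of the kind of IR that hits the offset bug (my own example, not taken from the patch or its tests): the subview's dynamic offset comes from a per-thread value, so it cannot live in an SGPR (SGPRs are wave-uniform) and must instead be folded into the buffer base pointer / per-lane offset.

```mlir
// Hypothetical reproducer: %tid is thread-dependent, so the subview's
// `offset: ?` is too, and must not be materialized as a scalar offset.
func.func @load_through_subview(%buf: memref<128xi32>, %tid: index,
                                %idx: i32) -> i32 {
  %view = memref.subview %buf[%tid] [64] [1]
      : memref<128xi32> to memref<64xi32, strided<[1], offset: ?>>
  %v = amdgpu.raw_buffer_load {boundsCheck = true} %view[%idx]
      : memref<64xi32, strided<[1], offset: ?>>, i32 -> i32
  return %v : i32
}
```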
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-backend-amdgpu
Author: Fabian Mora (fabianmcg)
Changes
This patch fixes several bugs in the lowering of AMDGPU raw buffer operations. These bugs include:
- Incorrect handling of the memref offset, which caused errors when using subviews. Furthermore, the memref offset cannot be assumed to fit in an SGPR, since it can be a thread-dependent value.
- Use of MaximumOp (a float-specific op) to compute the number of records.
- An incorrect number of records in the static-shape case.
- Broken lowering when the index bitwidth is i64.

Furthermore, this patch also switches to using MLIR's data layout to get the type size.
Full diff: https://github.com/llvm/llvm-project/pull/122293.diff
2 Files Affected:
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4100b086fad8ba..49ac4723c2fb94 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -30,10 +30,23 @@ namespace mlir {
using namespace mlir;
using namespace mlir::amdgpu;
+/// Convert an unsigned number `val` to i32.
+static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
+ Location loc, Value val) {
+ IntegerType i32 = rewriter.getI32Type();
+ // Force check that `val` is of int type.
+ auto valTy = cast<IntegerType>(val.getType());
+ if (i32 == valTy)
+ return val;
+ return valTy.getWidth() > 32
+ ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
+ : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+}
+
static Value createI32Constant(ConversionPatternRewriter &rewriter,
Location loc, int32_t value) {
- Type llvmI32 = rewriter.getI32Type();
- return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
+ Type i32 = rewriter.getI32Type();
+ return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
}
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
@@ -42,6 +55,28 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
}
+/// Returns the linear index used to access an element in the memref.
+static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
+ Location loc, MemRefDescriptor &memRefDescriptor,
+ ValueRange indices, ArrayRef<int64_t> strides) {
+ IntegerType i32 = rewriter.getI32Type();
+ Value index;
+ for (int i = 0, e = indices.size(); i < e; ++i) {
+ Value increment = indices[i];
+ if (strides[i] != 1) { // Skip if stride is 1.
+ Value stride =
+ ShapedType::isDynamic(strides[i])
+ ? convertUnsignedToI32(rewriter, loc,
+ memRefDescriptor.stride(rewriter, loc, i))
+ : rewriter.create<LLVM::ConstantOp>(loc, i32, strides[i]);
+ increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+ }
+ index =
+ index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+ }
+ return index ? index : createI32Constant(rewriter, loc, 0);
+}
+
namespace {
// Define commonly used chipsets versions for convenience.
constexpr Chipset kGfx908 = Chipset(9, 0, 8);
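For reference, a sketch of roughly what the new `getLinearIndexI32` helper emits for a rank-2 access `%m[%i, %j]` with static strides `[128, 1]` (my reading of the code above, not output copied from a test): stride-1 dimensions skip their multiply, and the result stays in i32 element units.

```mlir
%c128   = llvm.mlir.constant(128 : i32) : i32
%scaled = llvm.mul %i, %c128 : i32   // index * stride for dim 0
%lin    = llvm.add %scaled, %j : i32 // dim 1 has stride 1, no multiply
```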
@@ -88,17 +123,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
Type i32 = rewriter.getI32Type();
- Type llvmI32 = this->typeConverter->convertType(i32);
- Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
+ Type i16 = rewriter.getI16Type();
- auto toI32 = [&](Value val) -> Value {
- if (val.getType() == llvmI32)
- return val;
-
- return rewriter.create<LLVM::TruncOp>(loc, llvmI32, val);
- };
-
- int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
+ // Get the type size in bytes.
+ DataLayout dataLayout = DataLayout::closest(gpuOp);
+ int64_t elementByteWidth =
+ dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
// If we want to load a vector<NxT> with total size <= 32
@@ -114,7 +144,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
}
if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
uint32_t vecLen = dataVector.getNumElements();
- uint32_t elemBits = dataVector.getElementTypeBitWidth();
+ uint32_t elemBits =
+ dataLayout.getTypeSizeInBits(dataVector.getElementType());
uint32_t totalBits = elemBits * vecLen;
bool usePackedFp16 =
isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
@@ -167,28 +198,37 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
MemRefDescriptor memrefDescriptor(memref);
- Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
+ Value ptr = memrefDescriptor.bufferPtr(
+ rewriter, loc, *this->getTypeConverter(), memrefType);
// The stride value is always 0 for raw buffers. This also disables
// swizling.
Value stride = rewriter.create<LLVM::ConstantOp>(
- loc, llvmI16, rewriter.getI16IntegerAttr(0));
+ loc, i16, rewriter.getI16IntegerAttr(0));
+ // Get the number of elements.
Value numRecords;
- if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
- numRecords = createI32Constant(
- rewriter, loc,
- static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
+ if (memrefType.hasStaticShape() && !llvm::any_of(strides, [](int64_t v) {
+ return ShapedType::isDynamic(v);
+ })) {
+ int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+ ArrayRef<int64_t> shape = memrefType.getShape();
+ for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+ size = std::max(shape[i] * strides[i], size);
+ size = size * elementByteWidth;
+ assert(size < std::numeric_limits<uint32_t>::max() &&
+ "the memref buffer is too large");
+ numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
} else {
Value maxIndex;
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
- Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
- Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
- stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
- Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
- maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
- maxThisDim)
- : maxThisDim;
+ Value maxThisDim = rewriter.create<LLVM::MulOp>(
+ loc, memrefDescriptor.size(rewriter, loc, i),
+ memrefDescriptor.stride(rewriter, loc, i));
+ maxIndex =
+ maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+ : maxThisDim;
}
- numRecords = maxIndex;
+ numRecords = rewriter.create<LLVM::MulOp>(
+ loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
}
// Flag word:
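To make the new static-shape extent computation above concrete, a worked example (my arithmetic, on a hypothetical type, not from the patch): for `memref<16x4xf32, strided<[8, 1]>>` the furthest reachable element determines the extent,

$$\mathrm{numRecords} = \max_i(\mathrm{shape}_i \cdot \mathrm{stride}_i) \cdot \mathrm{elementBytes} = \max(16 \cdot 8,\ 4 \cdot 1) \cdot 4 = 512.$$

Under the old code this layout is not an identity, so it fell through to the dynamic path and its float-only `LLVM::MaximumOp`.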
@@ -218,40 +258,23 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
args.push_back(resource);
// Indexing (voffset)
- Value voffset = createI32Constant(rewriter, loc, 0);
- for (auto pair : llvm::enumerate(adaptor.getIndices())) {
- size_t i = pair.index();
- Value index = pair.value();
- Value strideOp;
- if (ShapedType::isDynamic(strides[i])) {
- strideOp = rewriter.create<LLVM::MulOp>(
- loc, toI32(memrefDescriptor.stride(rewriter, loc, i)),
- byteWidthConst);
- } else {
- strideOp =
- createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
- }
- index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
- voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
- }
- if (adaptor.getIndexOffset()) {
- int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
- Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
+ Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
+ adaptor.getIndices(), strides);
+ if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
+ indexOffset && *indexOffset > 0) {
+ Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
voffset =
voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
: extraOffsetConst;
}
+ voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
args.push_back(voffset);
+ // SGPR offset.
Value sgprOffset = adaptor.getSgprOffset();
if (!sgprOffset)
sgprOffset = createI32Constant(rewriter, loc, 0);
- if (ShapedType::isDynamic(offset))
- sgprOffset = rewriter.create<LLVM::AddOp>(
- loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset);
- else if (offset > 0)
- sgprOffset = rewriter.create<LLVM::AddOp>(
- loc, sgprOffset, createI32Constant(rewriter, loc, offset));
+ sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
args.push_back(sgprOffset);
// bit 0: GLC = 0 (atomics drop value, less coherency)
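A sketch of the resulting offset handling for a 1-D load `%buf[%idx]` of `i32` with an explicit `sgprOffset` operand `%off` (my own reconstruction from the pattern above): the linear index is kept in element units and scaled by the element size exactly once, the SGPR offset is scaled the same way, and the memref descriptor's offset no longer flows into `sgprOffset` at all; it is folded into the base pointer by `bufferPtr`.

```mlir
%c4      = llvm.mlir.constant(4 : i32) : i32 // element size in bytes
%voffset = llvm.mul %idx, %c4 : i32          // single scale at the end
%soffset = llvm.mul %off, %c4 : i32          // sgprOffset is in elements
```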
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 4c7515dc810516..92ecbff3e691dc 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -32,14 +32,16 @@ func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_strided
func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<64xi32, strided<[?], offset: ?>>, %idx: i32) -> i32 {
- // CHECK-DAG: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
- // CHECK-DAG: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
+ // CHECK: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
+ // CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+ // CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+ // CHECK: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
// CHECK: %[[size:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: %[[size32:.*]] = llvm.trunc %[[size]] : i64 to i32
// CHECK: %[[stride:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
- // CHECK: %[[stride32:.*]] = llvm.trunc %[[stride]] : i64 to i32
- // CHECK: %[[tmp:.*]] = llvm.mul %[[stride32]], %[[elem_size]] : i32
- // CHECK: %[[numRecords:.*]] = llvm.mul %[[size32]], %[[tmp]] : i32
+ // CHECK: %[[tmp:.*]] = llvm.mul %[[size]], %[[stride]] : i64
+ // CHECK: %[[num_elem:.*]] = llvm.trunc %[[tmp]] : i64 to i32
+ // CHECK: %[[numRecords:.*]] = llvm.mul %[[num_elem]], %[[elem_size]] : i32
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[rstride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>
@llvm/pr-subscribers-mlir-gpu Author: Fabian Mora (fabianmcg) (same summary and full diff as posted above)
Do we have tests for all these bugs?
We only have one, but it was incorrect; I'll add another, more comprehensive one.
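Something like the following could exercise the subview/offset path (hypothetical; the function name and CHECK lines are illustrative only, not the test that was actually added):

```mlir
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_offset
func.func @gpu_gcn_raw_buffer_load_offset(
    %buf: memref<64xi32, strided<[1], offset: ?>>, %idx: i32) -> i32 {
  // The descriptor offset must feed the base pointer through a GEP:
  // CHECK: %[[off:.*]] = llvm.extractvalue %{{.*}}[2]
  // CHECK: llvm.getelementptr %{{.*}}[%[[off]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx]
      : memref<64xi32, strided<[1], offset: ?>>, i32 -> i32
  return %0 : i32
}
```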
Co-authored-by: Jakub Kuderski <[email protected]>
I'm happy with approving this if folks are happy with the tests.
I'll be happy to land this once we're happy with the tests.
Great. I'm still trying to figure out what's happening on the Windows side: Linux passes, the error is happening in FileCheck, but I've been unable to reproduce it.
You might want a separate file with a few tests where the index bitwidth is 64, but, that aside, approved
I'll create a new PR for that. I was also thinking of changing the AMDGPU ops to take ...
Yeah, entirely fair on the pass not being set up for it, though for ... So yeah, approved.