[mlir][AMDGPU] Fix raw buffer ptr ops lowering #122293


Merged
merged 7 commits into llvm:main on Jan 13, 2025

Conversation

fabianmcg
Contributor

@fabianmcg fabianmcg commented Jan 9, 2025

This patch fixes several bugs in the lowering of AMDGPU raw buffer operations. These bugs include:

  • Incorrectly handling the offset of the memref, causing errors when using subviews. Furthermore, it cannot be assumed that the memref offset can be put in an SGPR, as it can be a thread-dependent value.
  • Using MaximumOp (a float-specific op) to calculate the number of records.
  • The number of records in the static-shape case.
  • The lowering when the index bitwidth is i64.

Furthermore, this patch switches to using MLIR's data layout to get the type size.
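For context, a minimal sketch of the kind of IR that exercised the offset bug, modeled on the updated strided test case (the function name is made up and the exact assembly syntax is illustrative): a raw buffer load through a memref whose layout carries a dynamic offset, such as one produced by memref.subview.

func.func @load_via_subview(%buf: memref<64xi32, strided<[?], offset: ?>>, %idx: i32) -> i32 {
  // The dynamic layout offset must be folded into the buffer base pointer;
  // it cannot be assumed to fit in the SGPR offset, since it may be a
  // thread-dependent value.
  %0 = amdgpu.raw_buffer_load {boundsCheck = true} %buf[%idx] : memref<64xi32, strided<[?], offset: ?>>, i32 -> i32
  return %0 : i32
}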

github-actions bot commented Jan 9, 2025

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@fabianmcg fabianmcg marked this pull request as ready for review January 9, 2025 15:19
@llvmbot
Member

llvmbot commented Jan 9, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-backend-amdgpu

Author: Fabian Mora (fabianmcg)

Changes

This patch fixes several bugs in the lowering of AMDGPU raw buffer operations. These bugs include:

  • Incorrectly handling the offset of the memref, causing errors when using subviews. Furthermore, it cannot be assumed that the memref offset can be put in an SGPR, as it can be a thread-dependent value.
  • Using MaximumOp (a float-specific op) to calculate the number of records.
  • The number of records in the static-shape case.
  • The lowering when the index bitwidth is i64.

Furthermore, this patch switches to using MLIR's data layout to get the type size.


Full diff: https://github.com/llvm/llvm-project/pull/122293.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+75-52)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir (+8-6)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 4100b086fad8ba..49ac4723c2fb94 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -30,10 +30,23 @@ namespace mlir {
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+/// Convert an unsigned number `val` to i32.
+static Value convertUnsignedToI32(ConversionPatternRewriter &rewriter,
+                                  Location loc, Value val) {
+  IntegerType i32 = rewriter.getI32Type();
+  // Force check that `val` is of int type.
+  auto valTy = cast<IntegerType>(val.getType());
+  if (i32 == valTy)
+    return val;
+  return valTy.getWidth() > 32
+             ? Value(rewriter.create<LLVM::TruncOp>(loc, i32, val))
+             : Value(rewriter.create<LLVM::ZExtOp>(loc, i32, val));
+}
+
 static Value createI32Constant(ConversionPatternRewriter &rewriter,
                                Location loc, int32_t value) {
-  Type llvmI32 = rewriter.getI32Type();
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmI32, value);
+  Type i32 = rewriter.getI32Type();
+  return rewriter.create<LLVM::ConstantOp>(loc, i32, value);
 }
 
 static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
@@ -42,6 +55,28 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
   return rewriter.create<LLVM::ConstantOp>(loc, llvmI1, value);
 }
 
+/// Returns the linear index used to access an element in the memref.
+static Value getLinearIndexI32(ConversionPatternRewriter &rewriter,
+                               Location loc, MemRefDescriptor &memRefDescriptor,
+                               ValueRange indices, ArrayRef<int64_t> strides) {
+  IntegerType i32 = rewriter.getI32Type();
+  Value index;
+  for (int i = 0, e = indices.size(); i < e; ++i) {
+    Value increment = indices[i];
+    if (strides[i] != 1) { // Skip if stride is 1.
+      Value stride =
+          ShapedType::isDynamic(strides[i])
+              ? convertUnsignedToI32(rewriter, loc,
+                                     memRefDescriptor.stride(rewriter, loc, i))
+              : rewriter.create<LLVM::ConstantOp>(loc, i32, strides[i]);
+      increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
+    }
+    index =
+        index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
+  }
+  return index ? index : createI32Constant(rewriter, loc, 0);
+}
+
 namespace {
 // Define commonly used chipsets versions for convenience.
 constexpr Chipset kGfx908 = Chipset(9, 0, 8);
@@ -88,17 +123,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     Type llvmWantedDataType = this->typeConverter->convertType(wantedDataType);
 
     Type i32 = rewriter.getI32Type();
-    Type llvmI32 = this->typeConverter->convertType(i32);
-    Type llvmI16 = this->typeConverter->convertType(rewriter.getI16Type());
+    Type i16 = rewriter.getI16Type();
 
-    auto toI32 = [&](Value val) -> Value {
-      if (val.getType() == llvmI32)
-        return val;
-
-      return rewriter.create<LLVM::TruncOp>(loc, llvmI32, val);
-    };
-
-    int64_t elementByteWidth = memrefType.getElementTypeBitWidth() / 8;
+    // Get the type size in bytes.
+    DataLayout dataLayout = DataLayout::closest(gpuOp);
+    int64_t elementByteWidth =
+        dataLayout.getTypeSizeInBits(memrefType.getElementType()) / 8;
     Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
 
     // If we want to load a vector<NxT> with total size <= 32
@@ -114,7 +144,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     }
     if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
       uint32_t vecLen = dataVector.getNumElements();
-      uint32_t elemBits = dataVector.getElementTypeBitWidth();
+      uint32_t elemBits =
+          dataLayout.getTypeSizeInBits(dataVector.getElementType());
       uint32_t totalBits = elemBits * vecLen;
       bool usePackedFp16 =
           isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2;
@@ -167,28 +198,37 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
 
     MemRefDescriptor memrefDescriptor(memref);
 
-    Value ptr = memrefDescriptor.alignedPtr(rewriter, loc);
+    Value ptr = memrefDescriptor.bufferPtr(
+        rewriter, loc, *this->getTypeConverter(), memrefType);
     // The stride value is always 0 for raw buffers. This also disables
     // swizling.
     Value stride = rewriter.create<LLVM::ConstantOp>(
-        loc, llvmI16, rewriter.getI16IntegerAttr(0));
+        loc, i16, rewriter.getI16IntegerAttr(0));
+    // Get the number of elements.
     Value numRecords;
-    if (memrefType.hasStaticShape() && memrefType.getLayout().isIdentity()) {
-      numRecords = createI32Constant(
-          rewriter, loc,
-          static_cast<int32_t>(memrefType.getNumElements() * elementByteWidth));
+    if (memrefType.hasStaticShape() && !llvm::any_of(strides, [](int64_t v) {
+          return ShapedType::isDynamic(v);
+        })) {
+      int64_t size = memrefType.getRank() == 0 ? 1 : 0;
+      ArrayRef<int64_t> shape = memrefType.getShape();
+      for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
+        size = std::max(shape[i] * strides[i], size);
+      size = size * elementByteWidth;
+      assert(size < std::numeric_limits<uint32_t>::max() &&
+             "the memref buffer is too large");
+      numRecords = createI32Constant(rewriter, loc, static_cast<int32_t>(size));
     } else {
       Value maxIndex;
       for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
-        Value size = toI32(memrefDescriptor.size(rewriter, loc, i));
-        Value stride = toI32(memrefDescriptor.stride(rewriter, loc, i));
-        stride = rewriter.create<LLVM::MulOp>(loc, stride, byteWidthConst);
-        Value maxThisDim = rewriter.create<LLVM::MulOp>(loc, size, stride);
-        maxIndex = maxIndex ? rewriter.create<LLVM::MaximumOp>(loc, maxIndex,
-                                                               maxThisDim)
-                            : maxThisDim;
+        Value maxThisDim = rewriter.create<LLVM::MulOp>(
+            loc, memrefDescriptor.size(rewriter, loc, i),
+            memrefDescriptor.stride(rewriter, loc, i));
+        maxIndex =
+            maxIndex ? rewriter.create<LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
+                     : maxThisDim;
       }
-      numRecords = maxIndex;
+      numRecords = rewriter.create<LLVM::MulOp>(
+          loc, convertUnsignedToI32(rewriter, loc, maxIndex), byteWidthConst);
     }
 
     // Flag word:
@@ -218,40 +258,23 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
     args.push_back(resource);
 
     // Indexing (voffset)
-    Value voffset = createI32Constant(rewriter, loc, 0);
-    for (auto pair : llvm::enumerate(adaptor.getIndices())) {
-      size_t i = pair.index();
-      Value index = pair.value();
-      Value strideOp;
-      if (ShapedType::isDynamic(strides[i])) {
-        strideOp = rewriter.create<LLVM::MulOp>(
-            loc, toI32(memrefDescriptor.stride(rewriter, loc, i)),
-            byteWidthConst);
-      } else {
-        strideOp =
-            createI32Constant(rewriter, loc, strides[i] * elementByteWidth);
-      }
-      index = rewriter.create<LLVM::MulOp>(loc, index, strideOp);
-      voffset = rewriter.create<LLVM::AddOp>(loc, voffset, index);
-    }
-    if (adaptor.getIndexOffset()) {
-      int32_t indexOffset = *gpuOp.getIndexOffset() * elementByteWidth;
-      Value extraOffsetConst = createI32Constant(rewriter, loc, indexOffset);
+    Value voffset = getLinearIndexI32(rewriter, loc, memrefDescriptor,
+                                      adaptor.getIndices(), strides);
+    if (std::optional<int32_t> indexOffset = adaptor.getIndexOffset();
+        indexOffset && *indexOffset > 0) {
+      Value extraOffsetConst = createI32Constant(rewriter, loc, *indexOffset);
       voffset =
           voffset ? rewriter.create<LLVM::AddOp>(loc, voffset, extraOffsetConst)
                   : extraOffsetConst;
     }
+    voffset = rewriter.create<LLVM::MulOp>(loc, voffset, byteWidthConst);
     args.push_back(voffset);
 
+    // SGPR offset.
     Value sgprOffset = adaptor.getSgprOffset();
     if (!sgprOffset)
       sgprOffset = createI32Constant(rewriter, loc, 0);
-    if (ShapedType::isDynamic(offset))
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, toI32(memrefDescriptor.offset(rewriter, loc)), sgprOffset);
-    else if (offset > 0)
-      sgprOffset = rewriter.create<LLVM::AddOp>(
-          loc, sgprOffset, createI32Constant(rewriter, loc, offset));
+    sgprOffset = rewriter.create<LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
     args.push_back(sgprOffset);
 
     // bit 0: GLC = 0 (atomics drop value, less coherency)
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
index 4c7515dc810516..92ecbff3e691dc 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir
@@ -32,14 +32,16 @@ func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
 
 // CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32_strided
 func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<64xi32, strided<[?], offset: ?>>, %idx: i32) -> i32 {
-  // CHECK-DAG: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
-  // CHECK-DAG: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK: %[[elem_size:.*]] = llvm.mlir.constant(4 : i32)
+  // CHECK: %[[algn_ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> 
+  // CHECK: %[[ptr:.*]] = llvm.getelementptr %[[algn_ptr]][%[[offset]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+  // CHECK: %[[rstride:.*]] = llvm.mlir.constant(0 : i16)
   // CHECK: %[[size:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[size32:.*]] = llvm.trunc %[[size]] : i64 to i32
   // CHECK: %[[stride:.*]] = llvm.extractvalue %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-  // CHECK: %[[stride32:.*]] = llvm.trunc %[[stride]] : i64 to i32
-  // CHECK: %[[tmp:.*]] = llvm.mul %[[stride32]], %[[elem_size]] : i32
-  // CHECK: %[[numRecords:.*]] = llvm.mul %[[size32]], %[[tmp]] : i32
+  // CHECK: %[[tmp:.*]] = llvm.mul %[[size]], %[[stride]] : i64
+  // CHECK: %[[num_elem:.*]] = llvm.trunc %[[tmp]] : i64 to i32
+  // CHECK: %[[numRecords:.*]] = llvm.mul %[[num_elem]], %[[elem_size]] : i32
   // GFX9:  %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
   // RDNA:  %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
   // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %[[rstride]], %[[numRecords]], %[[flags]] : !llvm.ptr to <8>

@llvmbot
Member

llvmbot commented Jan 9, 2025

@llvm/pr-subscribers-mlir-gpu

@kuhar kuhar requested a review from giuseros January 9, 2025 17:24
Member

@kuhar kuhar left a comment


Do we have tests for all these bugs?

@fabianmcg
Contributor Author

Do we have tests for all these bugs?

We only have one, but it was incorrect. I'll add another, more comprehensive one.

@krzysz00
Contributor

I'm happy with approving this if folks are happy with the tests.

@krzysz00
Contributor

I'll be happy to land this once we're happy with the tests

@fabianmcg
Contributor Author

I'll be happy to land this once we're happy with the tests

Great. I'm still trying to figure out what's happening on the Windows side: Linux passes, and the error is happening in FileCheck, but I've been unable to reproduce it.

@fabianmcg fabianmcg requested a review from krzysz00 January 10, 2025 23:06
Contributor

@krzysz00 krzysz00 left a comment


You might want a separate file with a few tests where the index bitwidth is 64, but, that aside, approved

@fabianmcg
Contributor Author

You might want a separate file with a few tests where the index bitwidth is 64, but, that aside, approved

I'll create a new PR for that; currently AMDGPUToROCDL doesn't support setting the index bitwidth. To test that, one needs to use gpu-to-rocdl.

I was also thinking of changing the AMDGPU ops to take index instead of i32 and converting to i32 during lowering. That would make the lowering less bug-prone.
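For reference, a rough sketch of how such a configuration could be exercised through gpu-to-rocdl today (the index-bitwidth option name is assumed to match that pass's existing option and should be double-checked):

// RUN: mlir-opt %s --convert-gpu-to-rocdl=index-bitwidth=64 | FileCheck %s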

@krzysz00
Contributor

Yeah, entirely fair on the pass not being set up for it - though for amdgpu-to-rocdl, you could just read the bitwidth out of the data layout instead of adding a pass option. But that's still a separate PR.
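As a sketch of that alternative (the DLTI entry syntax below is an assumption, not part of this patch), the index bitwidth could be specified on the module's data layout and queried during lowering instead of being a pass option:

module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<index, 32 : i32>>} {
  // amdgpu-to-rocdl could derive the index bitwidth (here 32) from this
  // data layout entry rather than from a dedicated pass option.
}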

So yeah, approved

@fabianmcg fabianmcg merged commit 0c1c49f into llvm:main Jan 13, 2025
8 checks passed