[MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU #133498

Merged · 17 commits · Apr 8, 2025
36 changes: 36 additions & 0 deletions mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -767,4 +767,40 @@ def AMDGPU_WMMAOp :
let hasVerifier = 1;
}

def AMDGPU_GatherToLDSOp :
AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
Arguments<(ins
Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
Variadic<Index>:$srcIndices,
Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
Variadic<Index>:$dstIndices,
TypeAttr:$transferType
)>,
Results<(outs)> {
let summary = "MLIR wrapper for CDNA global_load_lds instructions";
let description = [{
The `amdgpu.gather_to_lds` op is a wrapper around the `global_load_lds` instructions.

Operands:
* `$src`: global memory memref to read from.
* `$srcIndices`: indices into `$src` to read from for this thread.
* `$dst`: LDS memory memref to write to.
* `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
The elements gathered by the subgroup will be written contiguously, in order
of lane ID, starting at `$dst[$dstIndices]`.
* `$transferType`: type of the data to be transferred by each thread. This is used to determine
the size of the data to be transferred and the number of threads in the subgroup.
The transfer type must be a scalar type or a vector of scalars.

The `$dst`, along with its indices, points to the memory location the subgroup of this thread
will write to.

Note: only enabled for gfx942 and later.
}];
let assemblyFormat = [{
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
}];
let hasVerifier = 1;
}

#endif // AMDGPU
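
For reference, here is a minimal usage sketch of the new op based on the assembly format above; the shapes, indices, and gpu address-space attributes are illustrative, not taken from this patch:

#global = #gpu.address_space<global>
#lds = #gpu.address_space<workgroup>

func.func @gather_example(%src : memref<128x72xf32, #global>, %i : index) {
  %c0 = arith.constant 0 : index
  %dst = memref.alloc() : memref<64x64xf32, #lds>
  // Each lane reads one f32; the subgroup's elements are written
  // contiguously starting at %dst[%i, %c0].
  amdgpu.gather_to_lds %src[%i, %c0], %dst[%i, %c0]
    : f32, memref<128x72xf32, #global>, memref<64x64xf32, #lds>
  func.return
}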
53 changes: 51 additions & 2 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1010,6 +1010,55 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
}
};

struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}

Chipset chipset;

LogicalResult
matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (chipset < kGfx942)
return op.emitOpError("chipset not supported");

Location loc = op.getLoc();

auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
auto dstMemRefType = cast<MemRefType>(op.getDst().getType());

// TODO: instead of only transferring one element per thread, we could
// augment it to transfer multiple elements per thread by issuing multiple
// `global_load_lds` instructions.
Type transferType = op.getTransferType();
size_t loadWidth = [&]() -> size_t {
if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
return transferVectorType.getNumElements() *
(transferVectorType.getElementTypeBitWidth() / 8);
} else {
return transferType.getIntOrFloatBitWidth() / 8;
}
}();

// Currently only 1, 2, and 4 byte loads are supported.
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
return op.emitOpError("unsupported element size; must be 1, 2, or 4 bytes");

Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
adaptor.getSrcIndices(), rewriter);
Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
adaptor.getDstIndices(), rewriter);

// Lower to the ROCDL intrinsic: operands are the source and LDS pointers,
// the transfer width in bytes, and zero offset and aux values; the trailing
// ArrayAttrs hold optional alias-analysis metadata and are left unset.
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
createI32Constant(rewriter, loc, 0),
createI32Constant(rewriter, loc, 0), ArrayAttr{}, ArrayAttr{},
ArrayAttr{});

return success();
}
};

namespace {
struct ExtPackedFp8OpLowering final
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1393,6 +1442,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
ROCDL::RawPtrBufferAtomicCmpSwap>,
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
- PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
- chipset);
+ PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+ GatherToLDSOpLowering>(converter, chipset);
}
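
Schematically, the pattern rewrites the op into the ROCDL intrinsic along these lines (a sketch based on the test expectations below; the SSA names are illustrative, and the trailing zero operands correspond to the offset and aux constants created above):

// Input (illustrative):
amdgpu.gather_to_lds %src[%i, %j], %dst[%k, %c0]
  : f32, memref<128x72xf32, 1>, memref<64x64xf32, 3>

// Output: %srcPtr and %dstPtr are llvm.getelementptr values computed from
// the strided memref descriptors, and %c4 is the transfer width in bytes.
rocdl.global.load.lds %srcPtr, %dstPtr, %c4, %c0_i32, %c0_i32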
64 changes: 55 additions & 9 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
Contributor commented: Need to link the MLIRMemRefUtils library in CMake to fix the buildbot "undefined reference" failure for mlir::memref::isStaticShapeAndContiguousRowMajor(mlir::MemRefType).

Member (author) replied: Fix up here: #134862
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/DerivedTypes.h"

#include <limits>
#include <optional>
@@ -112,21 +114,31 @@ LogicalResult FatRawBufferCastOp::verify() {
return success();
}

// An absent memory space attribute is treated as global; integer address
// spaces 0 and 1 and gpu::AddressSpace::Global also qualify.
static bool hasGlobalMemorySpace(Attribute memorySpace) {
if (!memorySpace)
return true;
if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
return false;
}

// Workgroup (LDS) memory: integer address space 3 or
// gpu::AddressSpace::Workgroup.
static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
return intMemorySpace.getInt() == 3;
if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
return false;
}

//===----------------------------------------------------------------------===//
// RawBuffer*Op
//===----------------------------------------------------------------------===//
template <typename T>
static LogicalResult verifyRawBufferOp(T &op) {
MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
- Attribute memorySpace = bufferType.getMemorySpace();
- bool isGlobal = false;
- if (!memorySpace)
-   isGlobal = true;
- else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
-   isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
- else if (auto gpuMemorySpace =
-              llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
-   isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+ bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());

if (!isGlobal)
return op.emitOpError(
@@ -461,6 +473,40 @@ LogicalResult DPPOp::verify() {
return success();
}

LogicalResult GatherToLDSOp::verify() {
MemRefType srcType = cast<MemRefType>(getSrc().getType());
MemRefType dstType = cast<MemRefType>(getDst().getType());

if (!memref::isStaticShapeAndContiguousRowMajor(dstType))
return emitOpError(
"destination type must have a static shape and be contiguous");

auto elemType = srcType.getElementType();
// Check $src and $dst element types are the same.
if (elemType != dstType.getElementType())
return emitOpError("source and destination element types must match");

// The transfer type size must be 8, 16, or 32 bits (1, 2, or 4 bytes).
auto transferType = getTransferType();
size_t transferSize;
if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
transferSize = vectorTransfer.getNumElements() *
vectorTransfer.getElementTypeBitWidth();
} else {
transferSize = transferType.getIntOrFloatBitWidth();
}
if (transferSize != 8 && transferSize != 16 && transferSize != 32)
return emitOpError("transfer type size must be 8, 16, or 32 bits");

if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
return emitOpError("source memory address space must be Global");

if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
return emitOpError("destination memory address space must be Workgroup");

return success();
}

#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
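As a rough illustration of the checks above, IR along the following lines would pass or fail the verifier; the types and indices are made up for the example:

// Accepted: matching element types, 32-bit transfer type, global
// (address space 1) source, workgroup (address space 3) destination.
amdgpu.gather_to_lds %src[%i], %dst[%j, %c0]
  : vector<2xf16>, memref<512xf16, 1>, memref<4x64xf16, 3>

// Rejected ("transfer type size must be 8, 16, or 32 bits"):
// a vector<2xf32> transfer type is 64 bits wide.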
143 changes: 143 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -0,0 +1,143 @@
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s

#gpu_global_addrspace = 1
#gpu_lds_addrspace = 3

// CHECK-LABEL: func @global_load_to_rocdl_f32
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
// CHECK: %[[C12:.*]] = arith.constant 12 : index
// CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
// CHECK: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]

// CHECK: %[[ALLOC:.*]] = memref.alloc()
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]

// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64

// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]

// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64

// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
: f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
func.return
}

// CHECK-LABEL: func @global_load_to_rocdl_i8
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
// CHECK: %[[C12:.*]] = arith.constant 12 : index
// CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
// CHECK: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]

// CHECK: %[[ALLOC:.*]] = memref.alloc()
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]

// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64

// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]

// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64

// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C1]]
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
: i8, memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
func.return
}

// CHECK-LABEL: func @global_load_to_rocdl_vec
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi16, 1>)
func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_addrspace>) {
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]

// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
// CHECK: %[[C12:.*]] = arith.constant 12 : index
// CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
// CHECK: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]

// CHECK: %[[ALLOC:.*]] = memref.alloc()
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]

// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64

// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]

// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64

// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c32 = arith.constant 32 : index
%alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
: vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
func.return
}


// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices
// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index)
func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) {
// CHECK: %[[DSTIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[DST_IDX]]
// CHECK: %[[SRCIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC_IDX]]
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: %[[ALLOC:.*]] = memref.alloc()
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]]
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
%alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
%c0 = arith.constant 0 : index
amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0]
: i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
func.return
}