-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU #133498
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ba38466
8e46990
c072e66
564ebc8
92a1ef9
a17e854
4ed2006
81bb8bc
73629f4
b483701
1a40d6c
d68db39
33fbbc3
7aad9fb
85df43b
fd97908
cd22692
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1010,6 +1010,55 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> { | |
} | ||
}; | ||
|
||
struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> { | ||
GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset) | ||
: ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {} | ||
|
||
Chipset chipset; | ||
|
||
LogicalResult | ||
matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
if (chipset < kGfx942) | ||
return op.emitOpError("chipset not supported"); | ||
|
||
Location loc = op.getLoc(); | ||
|
||
auto srcMemRefType = cast<MemRefType>(op.getSrc().getType()); | ||
auto dstMemRefType = cast<MemRefType>(op.getSrc().getType()); | ||
|
||
// TODO: instead of only transfering one element per thread, we could | ||
// augment it to transfer multiple elements per thread by issuing multiple | ||
// `global_load_lds` instructions. | ||
Type transferType = op.getTransferType(); | ||
size_t loadWidth = [&]() -> size_t { | ||
if (auto transferVectorType = dyn_cast<VectorType>(transferType)) { | ||
return transferVectorType.getNumElements() * | ||
(transferVectorType.getElementTypeBitWidth() / 8); | ||
} else { | ||
return transferType.getIntOrFloatBitWidth() / 8; | ||
} | ||
Comment on lines
+1038
to
+1040
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no else after return: https://llvm.org/docs/CodingStandards.html#don-t-use-else-after-a-return |
||
}(); | ||
|
||
// Currently only 1, 2, and 4 byte loads are supported. | ||
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4) | ||
return op.emitOpError("chipset unsupported element size"); | ||
|
||
Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(), | ||
(adaptor.getSrcIndices()), rewriter); | ||
Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(), | ||
(adaptor.getDstIndices()), rewriter); | ||
|
||
rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>( | ||
op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth), | ||
createI32Constant(rewriter, loc, 0), | ||
createI32Constant(rewriter, loc, 0), ArrayAttr{}, ArrayAttr{}, | ||
ArrayAttr{}); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
namespace { | ||
struct ExtPackedFp8OpLowering final | ||
: public ConvertOpToLLVMPattern<ExtPackedFp8Op> { | ||
|
@@ -1393,6 +1442,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, | |
ROCDL::RawPtrBufferAtomicCmpSwap>, | ||
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering, | ||
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering, | ||
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter, | ||
chipset); | ||
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, | ||
GatherToLDSOpLowering>(converter, chipset); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/GPU/IR/GPUDialect.h" | ||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" | ||
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to link MLIRMemRefUtils library in cmake to fix the buildbot "undefined reference failure" "mlir::memref::isStaticShapeAndContiguousRowMajor(mlir::MemRefType)" There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix up here: #134862 |
||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/Diagnostics.h" | ||
|
@@ -24,6 +25,7 @@ | |
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "llvm/ADT/TypeSwitch.h" | ||
#include "llvm/IR/DerivedTypes.h" | ||
|
||
#include <limits> | ||
#include <optional> | ||
|
@@ -112,21 +114,31 @@ LogicalResult FatRawBufferCastOp::verify() { | |
return success(); | ||
} | ||
|
||
static bool hasGlobalMemorySpace(Attribute memorySpace) { | ||
if (!memorySpace) | ||
return true; | ||
if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace)) | ||
return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1; | ||
if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace)) | ||
return gpuMemorySpace.getValue() == gpu::AddressSpace::Global; | ||
return false; | ||
} | ||
|
||
static bool hasWorkgroupMemorySpace(Attribute memorySpace) { | ||
if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace)) | ||
return intMemorySpace.getInt() == 3; | ||
krzysz00 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace)) | ||
return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup; | ||
return false; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// RawBuffer*Op | ||
//===----------------------------------------------------------------------===// | ||
template <typename T> | ||
static LogicalResult verifyRawBufferOp(T &op) { | ||
MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType()); | ||
Attribute memorySpace = bufferType.getMemorySpace(); | ||
bool isGlobal = false; | ||
if (!memorySpace) | ||
isGlobal = true; | ||
else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace)) | ||
isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1; | ||
else if (auto gpuMemorySpace = | ||
llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace)) | ||
isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global; | ||
bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace()); | ||
|
||
if (!isGlobal) | ||
return op.emitOpError( | ||
|
@@ -461,6 +473,40 @@ LogicalResult DPPOp::verify() { | |
return success(); | ||
} | ||
|
||
LogicalResult GatherToLDSOp::verify() { | ||
MemRefType srcType = cast<MemRefType>(getSrc().getType()); | ||
MemRefType dstType = cast<MemRefType>(getDst().getType()); | ||
|
||
if (!memref::isStaticShapeAndContiguousRowMajor(dstType)) | ||
return emitOpError( | ||
"destination types must have static shape and contiguous"); | ||
|
||
auto elemType = srcType.getElementType(); | ||
// Check $src and $dst element types are the same. | ||
if (elemType != dstType.getElementType()) | ||
return emitOpError("source and destination element types must match"); | ||
|
||
// copy type sizes should be 1, 2, or 4 bytes. | ||
auto transferType = getTransferType(); | ||
size_t transferSize; | ||
if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) { | ||
transferSize = vectorTransfer.getNumElements() * | ||
vectorTransfer.getElementTypeBitWidth(); | ||
} else { | ||
transferSize = transferType.getIntOrFloatBitWidth(); | ||
} | ||
if (transferSize != 8 && transferSize != 16 && transferSize != 32) | ||
return emitOpError("Transfering type size must be 8, 16, or 32 bits"); | ||
|
||
if (!hasGlobalMemorySpace(srcType.getMemorySpace())) | ||
return emitOpError("source memory address space must be Global"); | ||
|
||
if (!hasWorkgroupMemorySpace(dstType.getMemorySpace())) | ||
return emitOpError("destination memory address space must be Workgroup"); | ||
|
||
return success(); | ||
} | ||
|
||
#include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc" | ||
|
||
#define GET_ATTRDEF_CLASSES | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s | ||
|
||
#gpu_global_addrspace = 1 | ||
#gpu_lds_addrspace = 3 | ||
|
||
// CHECK-LABEL: func @global_load_to_rocdl_f32 | ||
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>) | ||
func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) { | ||
%c0 = arith.constant 0 : index | ||
%c12 = arith.constant 12 : index | ||
%c32 = arith.constant 32 : index | ||
%alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace> | ||
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] | ||
|
||
// CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64 | ||
// CHECK: %[[C12:.*]] = arith.constant 12 : index | ||
// CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]] | ||
// CHECK: %[[C32:.*]] = arith.constant 32 : index | ||
// CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]] | ||
|
||
// CHECK: %[[ALLOC:.*]] = memref.alloc() | ||
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast | ||
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] | ||
|
||
// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64 | ||
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64 | ||
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64 | ||
|
||
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]] | ||
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] | ||
|
||
// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64 | ||
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64 | ||
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64 | ||
|
||
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]] | ||
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32 | ||
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]] | ||
amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] | ||
: f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace> | ||
func.return | ||
} | ||
|
||
// CHECK-LABEL: func @global_load_to_rocdl_i8 | ||
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>) | ||
func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) { | ||
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] | ||
|
||
// CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64 | ||
// CHECK: %[[C12:.*]] = arith.constant 12 : index | ||
// CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]] | ||
// CHECK: %[[C32:.*]] = arith.constant 32 : index | ||
// CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]] | ||
|
||
// CHECK: %[[ALLOC:.*]] = memref.alloc() | ||
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] | ||
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] | ||
|
||
// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64 | ||
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64 | ||
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64 | ||
|
||
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]] | ||
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] | ||
|
||
// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64 | ||
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64 | ||
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64 | ||
|
||
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]] | ||
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32 | ||
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C1]] | ||
%c0 = arith.constant 0 : index | ||
%c12 = arith.constant 12 : index | ||
%c32 = arith.constant 32 : index | ||
%alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace> | ||
amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] | ||
: i8, memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace> | ||
func.return | ||
} | ||
|
||
// CHECK-LABEL: func @global_load_to_rocdl_vec | ||
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi16, 1>) | ||
func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_addrspace>) { | ||
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] | ||
|
||
// CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64 | ||
// CHECK: %[[C12:.*]] = arith.constant 12 : index | ||
// CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]] | ||
// CHECK: %[[C32:.*]] = arith.constant 32 : index | ||
// CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]] | ||
|
||
// CHECK: %[[ALLOC:.*]] = memref.alloc() | ||
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] | ||
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] | ||
|
||
// CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64 | ||
// CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64 | ||
// CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64 | ||
|
||
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]] | ||
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] | ||
|
||
// CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64 | ||
// CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64 | ||
// CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64 | ||
|
||
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]] | ||
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32 | ||
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]] | ||
%c0 = arith.constant 0 : index | ||
%c12 = arith.constant 12 : index | ||
%c32 = arith.constant 32 : index | ||
%alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace> | ||
amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0] | ||
: vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace> | ||
func.return | ||
} | ||
|
||
|
||
// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices | ||
// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index) | ||
func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) { | ||
// CHECK: %[[DSTIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[DST_IDX]] | ||
// CHECK: %[[SRCIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC_IDX]] | ||
// CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] | ||
// CHECK: %[[ALLOC:.*]] = memref.alloc() | ||
// CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] | ||
// CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] | ||
// CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]] | ||
// CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] | ||
// CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]] | ||
// CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32 | ||
// CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]] | ||
%alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace> | ||
%c0 = arith.constant 0 : index | ||
amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0] | ||
: i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace> | ||
func.return | ||
} | ||
lialan marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.