Commit dae0ef5
[MLIR][AMDGPU] Add a wrapper for global LDS load intrinsics in AMDGPU (#133498)
Defines a new `amdgpu.gather_to_lds` op, a thin wrapper around the ROCDL `global_load_lds` intrinsic, along with its lowering to `rocdl.global.load.lds`.
1 parent 94b9d75 commit dae0ef5

4 files changed: +285 −11 lines changed
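Before the diff, a minimal usage sketch of the new op, lifted from the f32 test case added by this commit (the aliases #gpu_global_addrspace = 1 and #gpu_lds_addrspace = 3 are defined in that test file):

  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
    : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>

Each lane reads one f32 from the global memref, and the subgroup's elements are written contiguously into LDS starting at the destination base indices.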

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 36 additions & 0 deletions
@@ -767,4 +767,40 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_GatherToLDSOp :
+    AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
+    Arguments<(ins
+                   Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
+                   Variadic<Index>:$srcIndices,
+                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Variadic<Index>:$dstIndices,
+                   TypeAttr:$transferType
+                 )>,
+    Results<(outs)> {
+  let summary = "MLIR wrapper for global_load_lds instructions";
+  let description = [{
+    The `amdgpu.gather_to_lds` op is a wrapper around the `global_load_lds`
+    instructions.
+
+    Operands:
+    * `$src`: global memory memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$dst`: LDS memory memref to write to.
+    * `$dstIndices`: base indices into `$dst` to write to for the subgroup of
+      this thread. The elements gathered by the subgroup are written
+      contiguously, in order of lane ID, starting at `$dst[$dstIndices]`.
+    * `$transferType`: type of the data to be transferred by each thread. This
+      is used to determine the size of the data to be transferred and the
+      number of threads in the subgroup. The transfer type must be a scalar
+      type or a vector type.
+
+    The `$dst` memref, together with its indices, points to the memory location
+    the subgroup of this thread will write to.
+
+    Note: only enabled for gfx942 and later.
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // AMDGPU
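To make the `$transferType` operand concrete: in the vector test case added below, each thread transfers a `vector<2 x i16>`, i.e. 4 bytes, so the op lowers to a 4-byte-wide `global_load_lds` even though the memref element type is i16:

  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
    : vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>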

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 51 additions & 2 deletions
@@ -1010,6 +1010,55 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx942)
+      return op.emitOpError("chipset not supported");
+
+    Location loc = op.getLoc();
+
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+
+    // TODO: instead of only transferring one element per thread, we could
+    // augment this to transfer multiple elements per thread by issuing
+    // multiple `global_load_lds` instructions.
+    Type transferType = op.getTransferType();
+    size_t loadWidth = [&]() -> size_t {
+      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
+        return transferVectorType.getNumElements() *
+               (transferVectorType.getElementTypeBitWidth() / 8);
+      } else {
+        return transferType.getIntOrFloatBitWidth() / 8;
+      }
+    }();
+
+    // Currently only 1, 2, and 4 byte loads are supported.
+    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
+      return op.emitOpError("unsupported element size");
+
+    Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+                                        adaptor.getSrcIndices(), rewriter);
+    Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+                                        adaptor.getDstIndices(), rewriter);
+
+    rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
+        op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
+        createI32Constant(rewriter, loc, 0),
+        createI32Constant(rewriter, loc, 0), ArrayAttr{}, ArrayAttr{},
+        ArrayAttr{});
+
+    return success();
+  }
+};
+
 namespace {
 struct ExtPackedFp8OpLowering final
     : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1393,6 +1442,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            ROCDL::RawPtrBufferAtomicCmpSwap>,
        AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
        MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
-       PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
-                                                                  chipset);
+       PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+       GatherToLDSOpLowering>(converter, chipset);
 }
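The net effect of the pattern, sketched on the f32 test case from this commit (SSA names are illustrative; the intrinsic's trailing offset and aux operands, both zero here, are elided):

  %width = llvm.mlir.constant(4 : i32) : i32
  rocdl.global.load.lds %global_ptr, %lds_ptr, %width, ...

where %global_ptr and %lds_ptr are the strided element pointers computed from the two memref descriptors and their indices.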

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 55 additions & 9 deletions
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
 
 #include <limits>
 #include <optional>
@@ -112,21 +114,31 @@ LogicalResult FatRawBufferCastOp::verify() {
   return success();
 }
 
+static bool hasGlobalMemorySpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return true;
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
+  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  return false;
+}
+
+static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 3;
+  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // RawBuffer*Op
 //===----------------------------------------------------------------------===//
 template <typename T>
 static LogicalResult verifyRawBufferOp(T &op) {
   MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
-  Attribute memorySpace = bufferType.getMemorySpace();
-  bool isGlobal = false;
-  if (!memorySpace)
-    isGlobal = true;
-  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
-    isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
-  else if (auto gpuMemorySpace =
-               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
-    isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
 
   if (!isGlobal)
     return op.emitOpError(
@@ -461,6 +473,40 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+LogicalResult GatherToLDSOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+  if (!memref::isStaticShapeAndContiguousRowMajor(dstType))
+    return emitOpError(
+        "destination type must have a static shape and be contiguous");
+
+  auto elemType = srcType.getElementType();
+  // Check that the $src and $dst element types match.
+  if (elemType != dstType.getElementType())
+    return emitOpError("source and destination element types must match");
+
+  // Copy sizes must be 1, 2, or 4 bytes.
+  auto transferType = getTransferType();
+  size_t transferSize;
+  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
+    transferSize = vectorTransfer.getNumElements() *
+                   vectorTransfer.getElementTypeBitWidth();
+  } else {
+    transferSize = transferType.getIntOrFloatBitWidth();
+  }
+  if (transferSize != 8 && transferSize != 16 && transferSize != 32)
+    return emitOpError("transfer type size must be 8, 16, or 32 bits");
+
+  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
+    return emitOpError("source memory address space must be Global");
+
+  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
+    return emitOpError("destination memory address space must be Workgroup");
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
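As a hypothetical illustration of the new verifier (this case is not among the commit's tests), an op whose source and destination element types differ is rejected:

  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
  // error: source and destination element types must match
  amdgpu.gather_to_lds %global[%c0, %c0], %alloc[%c0, %c0]
    : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf16, #gpu_lds_addrspace>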
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_i8
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
+func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C1]]
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : i8, memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+  func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi16, 1>)
+func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C128:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C128]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
+  func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices
+// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index)
+func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) {
+  // CHECK: %[[DSTIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[DST_IDX]]
+  // CHECK: %[[SRCIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC_IDX]]
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+  // CHECK: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[DSTIDX_CAST]], %[[C64]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL]], %{{.*}} : i64
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+  %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0]
+    : i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+  func.return
+}
