Skip to content

Commit f271c01

Browse files
committed
update AMDGPU description.
1 parent 564ebc8 commit f271c01

File tree

3 files changed

+46
-25
lines changed

3 files changed

+46
-25
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -784,15 +784,19 @@ def AMDGPU_GlobalLoadLDSOp :
784784
Results<(outs)> {
785785
let summary = "MLIR wrapper for CDNA mfma instructions";
786786
let description = [{
787-
The `amdgpu.mfma` op is an MLIR wrapper around intrinsics
788-
for various `mfma` instructions in the CDNA architecture, which perform
789-
multiple outer products in order to allow fast matrix multiplication.
790-
791-
The `amdgpu.global_load` op is a wrapper around the various `global_load_lds` instructions.
792-
793-
The `$src`, along with its indices, points to the memory location this thread reads from.
787+
The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
788+
789+
Operands:
790+
* `$src`: global memory memref to read from.
791+
* `$srcIndices`: indices into `$src` to read from for this thread.
792+
* `$dst`: LDS memory memref to write to.
793+
* `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
794+
A number of elements equal to the subgroup size will be written contiguously to `$dst[$dstIndices]`.
795+
794796
The `$dst`, along with its indices, points to the memory location the subgroup of this thread
795797
will write to.
798+
799+
Note: only enabled for gfx942 and later.
796800
}];
797801
let assemblyFormat = [{
798802
$src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` type($src) `,` type($dst)

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -925,23 +925,16 @@ struct GlobalLoadLDSOpLowering
925925
// `global_load_lds` instructions.
926926
auto loadWidth = elemSizeInBits / 8;
927927

928-
// TODO: add chipset support check
929-
if (chipset.majorVersion >= 12)
930-
return op.emitOpError("TODO");
928+
const Chipset GlobalLoadEnabled{9, 0x4, 0x0};
929+
if (chipset < GlobalLoadEnabled)
930+
return op.emitOpError("chipset not supported");
931931

932-
// TODO: fold this into chipset check.
933932
// Currently only 1, 2, and 4 byte loads are supported.
934933
if (!(loadWidth == 1 || loadWidth == 2 || loadWidth == 4))
935-
return op.emitOpError("unsupported element size");
934+
return op.emitOpError("chipset unsupported element size");
936935

937-
Value src = adaptor.getSrc();
938-
Value dst = adaptor.getDst();
939-
Value memrefSrc = op.getSrc();
940-
Value memrefDst = op.getDst();
941-
942-
// Collapse src memref with indices, returns the base pointer and linearized
943-
// index.
944-
auto flattenIndex =
936+
// Return pair of {base pointer, linearized index}.
937+
auto getBasePtrAndLinearizedIndex =
945938
[&](Value memref, MemRefType memrefType,
946939
ValueRange indices) -> std::optional<std::pair<Value, Value>> {
947940
MemRefDescriptor memRefDescriptor(memref);
@@ -955,13 +948,14 @@ struct GlobalLoadLDSOpLowering
955948
getLinearIndexI32(rewriter, loc, memRefDescriptor, indices, strides));
956949
};
957950

958-
// Source
959-
auto optSrcBuffer = flattenIndex(src, cast<MemRefType>(memrefSrc.getType()),
960-
op.getSrcIndices());
951+
auto optSrcBuffer = getBasePtrAndLinearizedIndex(
952+
adaptor.getSrc(), cast<MemRefType>(op.getSrc().getType()),
953+
op.getSrcIndices());
961954
if (!optSrcBuffer)
962955
return op.emitOpError("failed to flatten source memref indices");
963-
auto optDstBuffer = flattenIndex(dst, cast<MemRefType>(memrefDst.getType()),
964-
op.getDstIndices());
956+
auto optDstBuffer = getBasePtrAndLinearizedIndex(
957+
adaptor.getDst(), cast<MemRefType>(op.getDst().getType()),
958+
op.getDstIndices());
965959
if (!optDstBuffer)
966960
return op.emitOpError("failed to flatten destination memref indices");
967961

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// test pass doesn't set up the GPU address space conversions.
1010

1111
#gpu_global_addrspace = 1
12+
#gpu_lds_addrspace = 3
1213

1314
// CHECK-LABEL: func @fat_raw_buffer_cast
1415
func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
@@ -461,3 +462,25 @@ func.func @sched_barrier() {
461462
amdgpu.sched_barrier allow = <valu|all_vmem>
462463
func.return
463464
}
465+
466+
// CHECK-LABEL: func @global_load_to_rocdl_f32
467+
// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
468+
func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
469+
%c0 = arith.constant 0 : i32
470+
%c12 = arith.constant 12 : i32
471+
%c32 = arith.constant 32 : i32
472+
%alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
473+
// GFX942: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %arg0 : memref<128x72xf32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
474+
// GFX942: %[[ALLOC:.*]] = memref.alloc() : memref<64x64xf32, 3>
475+
// GFX942: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]] : memref<64x64xf32, 3> to !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
476+
// GFX942: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1] : !llvm.struct<(ptr<1>, ptr<1>, i64, array<2 x i64>, array<2 x i64>)>
477+
// GFX942: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
478+
// GFX942: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[GLOBAL_OFFSET:.*]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f32
479+
// GFX942: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[LDS_OFFSET:.*]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f32
480+
// GFX942: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
481+
// GFX942: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
482+
// GFX942: %[[C0_2:.*]] = llvm.mlir.constant(0 : i32) : i32
483+
// GFX942: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]], %[[C0]], %[[C0_2]]
484+
amdgpu.global_load %global[%c12, %c0], %alloc[%c32, %c0] : memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
485+
func.return
486+
}

0 commit comments

Comments
 (0)