Skip to content

Commit 27f15ad

Browse files
plognjenoplavsic
andauthored
[MLIR][ROCDL] Add ops for LDS read transpose and global to LDS intrinsics (#123530)
This PR adds the missing ds.read.tr4.b64, ds.read.tr8.b64, ds.read.tr6.b96, ds.read.tr16.b64 and global.load.lds ops to the ROCDL dialect. The ops are converted to the corresponding intrinsic calls during the translation from MLIR to LLVM IR. --------- Co-authored-by: Ognjen Plavsic <[email protected]>
1 parent a79098b commit 27f15ad

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,36 @@ def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]
412412
def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
413413
def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;
414414

415+
//===---------------------------------------------------------------------===//
416+
// LDS transpose intrinsics (available in GFX950)
417+
418+
// Pointer type constraints shared by the ops in this section: LLVM pointers
// in address space 1 (used for the global-memory operand) and in address
// space 3 (used for the LDS operand).
def ROCDLGlobalBuffer : LLVM_PointerInAddressSpace<1>;
419+
def ROCDLBufferLDS : LLVM_PointerInAddressSpace<3>;
420+
421+
// Common base for the ds.read.tr* ops. Each derived op takes a single pointer
// operand constrained to address space 3 and annotated as a memory read, and
// its assembly format spells out both the pointer type and the result type.
// NOTE(review): the positional ROCDL_IntrOp arguments ([1], [], [], 1) are
// defined outside this section; presumably they mark result 0 as overloaded
// and request one result. Confirm against the ROCDL_IntrOp definition.
class ROCDL_LDS_Read_Tr_IntrOp<string mnemonic> :
422+
ROCDL_IntrOp<mnemonic, [1], [], [], 1>,
423+
Arguments<(ins Arg<ROCDLBufferLDS, "", [MemRead]>:$ptr)>{
424+
let assemblyFormat = "$ptr attr-dict `:` type($ptr) `->` type($res)";
425+
}
426+
427+
// Concrete LDS transpose-read ops. Each one only supplies its mnemonic; the
// operand, memory effect and assembly format come from
// ROCDL_LDS_Read_Tr_IntrOp.
def ROCDL_ds_read_tr4_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr4.b64">;
428+
def ROCDL_ds_read_tr8_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr8.b64">;
429+
def ROCDL_ds_read_tr6_b96 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr6.b96">;
430+
def ROCDL_ds_read_tr16_b64 : ROCDL_LDS_Read_Tr_IntrOp<"ds.read.tr16.b64">;
431+
432+
//===---------------------------------------------------------------------===//
433+
// Global load to LDS intrinsic (available in GFX950)
434+
435+
// rocdl.global.load.lds: takes a pointer in address space 1 that is read
// ($globalPtr), a pointer in address space 3 that is written ($ldsPtr), and
// three i32 operands ($size, $offset, $aux). The assembly format prints only
// the operands and the attribute dictionary, so operand types are not spelled
// out in the textual form.
// NOTE(review): $size, $offset and $aux are modelled as SSA operands here;
// the underlying LLVM intrinsic may require immediate constants for them.
// Confirm, and consider using attributes instead if so.
def ROCDL_GlobalLoadLDSOp :
436+
ROCDL_IntrOp<"global.load.lds", [], [], [], 0>,
437+
Arguments<(ins Arg<ROCDLGlobalBuffer, "", [MemRead]>:$globalPtr,
438+
Arg<ROCDLBufferLDS, "", [MemWrite]>:$ldsPtr,
439+
I32:$size,
440+
I32:$offset,
441+
I32:$aux)> {
442+
let assemblyFormat = "operands attr-dict";
443+
}
444+
415445
//===---------------------------------------------------------------------===//
416446
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
417447
// raw buffer mode).

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,32 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
227227
llvm.return
228228
}
229229

230+
// Dialect-level test for the ds.read.tr* ops: every variant is applied to the
// same address-space-3 pointer and its printed form is matched, including the
// result vector type (tr16.b64 is exercised with both f16 and bf16 results).
llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
231+
// CHECK-LABEL: rocdl.ds.read.tr
232+
// CHECK: rocdl.ds.read.tr4.b64 {{.*}} : <3> -> vector<2xi32>
233+
%r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
234+
// CHECK: rocdl.ds.read.tr6.b96 {{.*}} : <3> -> vector<3xi32>
235+
%r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
236+
// CHECK: rocdl.ds.read.tr8.b64 {{.*}} : <3> -> vector<2xi32>
237+
%r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
238+
// CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xf16>
239+
%r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
240+
// CHECK: rocdl.ds.read.tr16.b64 {{.*}} : <3> -> vector<4xbf16>
241+
%r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
242+
llvm.return %r3 : vector<4xf16>
243+
}
244+
245+
// Dialect-level test for rocdl.global.load.lds: builds three i32 constants
// for size, offset and aux, then matches the op printed with five operands
// (source pointer, destination pointer, size, offset, aux).
llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
246+
%aux = llvm.mlir.constant(0 : i32) : i32
247+
%offset = llvm.mlir.constant(0 : i32) : i32
248+
%size = llvm.mlir.constant(10 : i32) : i32
249+
250+
//CHECK: rocdl.global.load.lds %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
251+
rocdl.global.load.lds %src, %dst, %size, %offset, %aux
252+
253+
llvm.return
254+
}
255+
230256
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
231257
%stride : i16,
232258
%numRecords : i32,

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,30 @@ llvm.func @rocdl.wmma(%arg0 : vector<8xf32>, %arg1 : vector<16 x f16>, %arg2 : v
424424
llvm.return %r0 : vector<8xf32>
425425
}
426426

427+
// Translation test for the ds.read.tr* ops: each op is expected to become a
// call to the matching llvm.amdgcn.ds.read.tr* intrinsic, with the result
// vector type appearing as a suffix of the intrinsic name and the argument
// being an addrspace(3) pointer.
llvm.func @rocdl.ds.read.tr(%ptr : !llvm.ptr<3>) -> vector<4xf16> {
428+
// CHECK-LABEL: rocdl.ds.read.tr
429+
// CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr4.b64.v2i32(ptr addrspace(3) %0)
430+
%r0 = rocdl.ds.read.tr4.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
431+
// CHECK: call <3 x i32> @llvm.amdgcn.ds.read.tr6.b96.v3i32(ptr addrspace(3) %0)
432+
%r1 = rocdl.ds.read.tr6.b96 %ptr : !llvm.ptr<3> -> vector<3xi32>
433+
// CHECK: call <2 x i32> @llvm.amdgcn.ds.read.tr8.b64.v2i32(ptr addrspace(3) %0)
434+
%r2 = rocdl.ds.read.tr8.b64 %ptr : !llvm.ptr<3> -> vector<2xi32>
435+
// CHECK: call <4 x half> @llvm.amdgcn.ds.read.tr16.b64.v4f16(ptr addrspace(3) %0)
436+
%r3 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xf16>
437+
// CHECK: call <4 x bfloat> @llvm.amdgcn.ds.read.tr16.b64.v4bf16(ptr addrspace(3) %0)
438+
%r4 = rocdl.ds.read.tr16.b64 %ptr : !llvm.ptr<3> -> vector<4xbf16>
439+
llvm.return %r3 : vector<4xf16>
440+
}
441+
442+
// Translation test for rocdl.global.load.lds: the op is expected to become a
// void call to the llvm.amdgcn.global.load.lds intrinsic. Only the callee
// name is matched; the call arguments are not verified here.
llvm.func @rocdl.global.load.lds(%src : !llvm.ptr<1>, %dst: !llvm.ptr<3>) {
443+
%aux = llvm.mlir.constant(0 : i32) : i32
444+
%offset = llvm.mlir.constant(0 : i32) : i32
445+
%size = llvm.mlir.constant(10 : i32) : i32
446+
//CHECK: call void @llvm.amdgcn.global.load.lds
447+
rocdl.global.load.lds %src, %dst, %size, %offset, %aux
448+
llvm.return
449+
}
450+
427451
llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
428452
%stride : i16,
429453
%numRecords : i32,

0 commit comments

Comments
 (0)