Commit 7d55b91

[mlir][nvgpu] Support strided memref when creating TMA descriptor (llvm#85652)
1 parent 1261c02 commit 7d55b91
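
Summary of the change, as reflected in the diff below: mgpuTensorMapEncodeTiledMemref previously recomputed the global strides from the tensor shape, which implicitly assumes a contiguous row-major buffer. With this commit, mgpuGetMemRefDataAndShape reads the strides directly out of the strided MemRef descriptor and scales them by the element size, so TMA descriptors can also be created for non-contiguous views such as a transposed memref. For the memref<2x2x64x64xf16, strided<[8192, 64, 128, 1]>> used in the new integration test, for example, the second-fastest dimension now gets a byte stride of 128 * 2 = 256, whereas the old shape-derived computation would have produced 64 * 2 = 128.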

File tree

2 files changed: 173 additions, 24 deletions


mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Lines changed: 25 additions & 24 deletions
@@ -423,24 +423,27 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled(
       elementStrides[4], interleave, swizzle, l2Promotion, oobFill);
 }
 
-namespace {
-
-template <int rank>
-void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
-                               uint64_t *globalDim) {
+template <int Rank>
+void mgpuGetMemRefDataAndShape(void *rawDescriptor, char **addr,
+                               uint64_t *globalDim, uint64_t *globalStrides,
+                               const CUtensorMapDataType tensorDataType) {
   auto descriptor =
-      reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
+      reinterpret_cast<StridedMemRefType<char, Rank> *>(rawDescriptor);
   *addr = descriptor->data;
-  for (int i = 0; i < rank; ++i) {
-    globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
+  for (int i = 0; i < Rank; ++i) {
+    globalDim[i] = static_cast<uint64_t>(descriptor->sizes[Rank - i - 1]);
+  }
+  static constexpr int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
+                                               4, 8, 2, 4, 4, 4};
+  for (int i = 0; i < Rank - 1; ++i) {
+    globalStrides[i] = static_cast<uint64_t>(
+        descriptor->strides[Rank - i - 2] * elementSizeInBytes[tensorDataType]);
   }
 }
 
-} // namespace
-
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
     int64_t tensorRank,                       // Dimensionality of tensor
-    void *ranked_descriptor,                  // Ranked MemRef descriptor
+    void *rankedDescriptor,                   // Ranked MemRef descriptor
     const CUtensorMapDataType tensorDataType, // Stride size (in bytes)
     CUtensorMapInterleave interleave,         // Type of interleaved layout
     CUtensorMapSwizzle swizzle,               // Bank swizzling pattern
@@ -457,38 +460,36 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
   char *globalAddress = nullptr;
   switch (tensorRank) {
   case 1:
-    mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<1>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 2:
-    mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<2>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 3:
-    mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<3>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 4:
-    mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<4>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   case 5:
-    mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
+    mgpuGetMemRefDataAndShape<5>(rankedDescriptor, &globalAddress, globalDim,
+                                 globalStrides, tensorDataType);
     break;
   default:
     fprintf(
         stderr,
         "'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
-    return NULL;
+    return nullptr;
   }
 
-  static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
-                                           4, 8, 2, 4, 4, 4};
   for (int64_t r = 0; r < tensorRank; ++r) {
-    elementStrides[r] = uint32_t(1);
     boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
   }
 
-  globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];
-  for (int r = 1; r < tensorRank - 1; r++)
-    globalStrides[r] = globalStrides[r - 1] * globalDim[r];
-
   ScopedContext scopedContext;
   mgpuTensorMapEncodeTiled(&tensorMap, tensorDataType, tensorRank32,
                            globalAddress, globalDim, globalStrides, boxDim,
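
For illustration, the following is a minimal standalone sketch (not part of the commit) of the stride computation now done in mgpuGetMemRefDataAndShape, applied to the transposed memref<2x2x64x64xf16, strided<[8192, 64, 128, 1]>> used by the new test below. The sizes, strides, and element size are taken from that test; the local arrays stand in for descriptor->sizes and descriptor->strides of the real StridedMemRefType, and nothing here calls the CUDA driver API.

// Mirrors the new globalDim/globalStrides computation for one concrete memref.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int Rank = 4;
  const int64_t sizes[Rank] = {2, 2, 64, 64};       // memref shape
  const int64_t strides[Rank] = {8192, 64, 128, 1}; // element strides (transposed view)
  const int elementSizeInBytes = 2;                 // f16

  uint64_t globalDim[Rank];
  uint64_t globalStrides[Rank - 1];
  // TMA expects the fastest-varying dimension first, so the shape is reversed.
  for (int i = 0; i < Rank; ++i)
    globalDim[i] = static_cast<uint64_t>(sizes[Rank - i - 1]);
  // Byte strides for all but the innermost dimension come straight from the
  // memref's stride list instead of being recomputed from the shape; this is
  // what lets non-contiguous (e.g. transposed) views be described correctly.
  for (int i = 0; i < Rank - 1; ++i)
    globalStrides[i] =
        static_cast<uint64_t>(strides[Rank - i - 2] * elementSizeInBytes);

  for (int i = 0; i < Rank; ++i)
    std::printf("globalDim[%d] = %llu\n", i,
                static_cast<unsigned long long>(globalDim[i]));
  for (int i = 0; i < Rank - 1; ++i)
    std::printf("globalStrides[%d] = %llu\n", i,
                static_cast<unsigned long long>(globalStrides[i]));
  return 0;
}

For this input the sketch prints globalDim = {64, 64, 2, 2} and globalStrides = {256, 128, 16384} bytes; the old shape-derived code could not produce these strides because 256 is not 64 elements times 2 bytes. The new integration test below exercises exactly this layout.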
Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
+// RUN: mlir-opt %s \
+// RUN:  -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_90 cubin-features=+ptx80 opt-level=3" \
+// RUN:  | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN:  | FileCheck %s
+
+// CHECK: Correct Results :8192
+// CHECK: Incorrect Results :0
+
+module {
+  func.func @main() {
+    %c10000000 = arith.constant 10000000 : index
+    %false = arith.constant false
+    %c32768 = arith.constant 32768 : index
+    %c31_i32 = arith.constant 31 : i32
+    %c-1_i32 = arith.constant -1 : i32
+    %c5_i32 = arith.constant 5 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c0 = arith.constant 0 : index
+    %c8 = arith.constant 8 : index
+    %c64 = arith.constant 64 : index
+    %c2 = arith.constant 2 : index
+    %c32768_i32 = arith.constant 32768 : i32
+    %c128 = arith.constant 128 : index
+    %c1 = arith.constant 1 : index
+    %0 = llvm.mlir.constant(1 : i64) : i64
+    %1 = llvm.mlir.constant(128 : i64) : i64
+    %2 = llvm.mlir.constant(0 : i64) : i64
+    %f0 = arith.constant 0.0 : f16
+    %f123 = arith.constant 1.123 : f16
+
+    %srcMemref_host = memref.alloc() : memref<128x128xf16>
+    %dstMemref_host = memref.alloc() : memref<128x128xf16>
+    scf.for %arg0 = %c0 to %c128 step %c1 {
+      scf.for %arg1 = %c0 to %c64 step %c1 {
+        %d1 = arith.index_cast %arg0 : index to i32
+        %d2 = arith.index_cast %arg1 : index to i32
+        %d3 = arith.sitofp %d1 : i32 to f16
+        %d4 = arith.sitofp %d2 : i32 to f16
+        %d5 = arith.addf %d3, %f123 : f16
+        %d6 = arith.constant 3.12 : f16
+        %d7 = arith.mulf %d5, %d6 : f16
+        %d8 = arith.addf %d7, %d5 : f16
+        %d9 = arith.constant 0.178 : f16
+        %d10 = arith.divf %d9, %d8 : f16
+        memref.store %d10, %srcMemref_host[%arg0, %arg1] : memref<128x128xf16>
+        memref.store %f0, %dstMemref_host[%arg0, %arg1] : memref<128x128xf16>
+      }
+    }
+
+    %s1 = gpu.wait async
+    %srcMemref, %s2 = gpu.alloc async [%s1] () : memref<128x128xf16>
+    %dstMemref, %s3 = gpu.alloc async [%s2] () : memref<128x128xf16>
+    %s4 = gpu.memcpy async [%s3] %srcMemref, %srcMemref_host : memref<128x128xf16>, memref<128x128xf16>
+    %s5 = gpu.memcpy async [%s4] %dstMemref, %dstMemref_host : memref<128x128xf16>, memref<128x128xf16>
+
+    %expand_shape = memref.expand_shape %srcMemref [[0, 1], [2, 3]] : memref<128x128xf16> into memref<2x64x2x64xf16>
+    %transpose = memref.transpose %expand_shape (d0, d1, d2, d3) -> (d0, d2, d1, d3) : memref<2x64x2x64xf16> to memref<2x2x64x64xf16, strided<[8192, 64, 128, 1]>>
+    %cast = memref.cast %transpose : memref<2x2x64x64xf16, strided<[8192, 64, 128, 1]>> to memref<*xf16>
+    %24 = nvgpu.tma.create.descriptor %cast box[%c2, %c2, %c64, %c64] : memref<*xf16> -> <tensor = memref<2x2x64x64xf16, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>
+
+    gpu.launch
+      blocks(%arg2, %arg3, %arg4) in (%arg8 = %c1, %arg9 = %c1, %arg10 = %c1)
+      threads(%arg5, %arg6, %arg7) in (%arg11 = %c128, %arg12 = %c1, %arg13 = %c1)
+      dynamic_shared_memory_size %c32768_i32
+    {
+      %26 = gpu.dynamic_shared_memory : memref<?xi8, #gpu.address_space<workgroup>>
+      %view = memref.view %26[%c0][] : memref<?xi8, #gpu.address_space<workgroup>> to memref<2x2x64x64xf16, #gpu.address_space<workgroup>>
+      %27 = nvgpu.mbarrier.create -> <memorySpace = #gpu.address_space<workgroup>>
+      %thread_id_x = gpu.thread_id x
+      %28 = arith.index_cast %thread_id_x : index to i32
+      %29 = arith.shrui %28, %c5_i32 : i32
+      %30 = nvvm.shfl.sync idx %c-1_i32, %29, %c0_i32, %c31_i32 : i32 -> i32
+      %31 = arith.cmpi eq, %30, %c0_i32 : i32
+      %32 = nvvm.elect.sync -> i1
+      %33 = arith.andi %31, %32 : i1
+      scf.if %33 {
+        nvgpu.mbarrier.init %27[%c0], %c1 : <memorySpace = #gpu.address_space<workgroup>>
+      }
+      %34 = nvvm.shfl.sync idx %c-1_i32, %29, %c0_i32, %c31_i32 : i32 -> i32
+      %35 = arith.cmpi eq, %34, %c0_i32 : i32
+      %36 = nvvm.elect.sync -> i1
+      %37 = arith.andi %35, %36 : i1
+      scf.if %37 {
+        nvgpu.mbarrier.arrive.expect_tx %27[%c0], %c32768 : <memorySpace = #gpu.address_space<workgroup>>
+        nvgpu.tma.async.load %24[%c0, %c0, %c0, %c0], %27[%c0] to %view : <tensor = memref<2x2x64x64xf16, 3>, swizzle = none, l2promo = none, oob = zero, interleave = none>, <memorySpace = #gpu.address_space<workgroup>> -> memref<2x2x64x64xf16, #gpu.address_space<workgroup>>
+      }
+      nvgpu.mbarrier.try_wait.parity %27[%c0], %false, %c10000000 : <memorySpace = #gpu.address_space<workgroup>>
+      scf.for %arg14 = %c0 to %c2 step %c1 {
+        scf.for %arg15 = %c0 to %c2 step %c1 {
+          %38 = arith.muli %arg14, %c64 : index
+          %39 = arith.muli %arg15, %c64 : index
+          %subview = memref.subview %view[%arg14, %arg15, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<2x2x64x64xf16, #gpu.address_space<workgroup>> to memref<64x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>
+          %subview_0 = memref.subview %dstMemref[%38, %39] [64, 64] [1, 1] : memref<128x128xf16> to memref<64x64xf16, strided<[128, 1], offset: ?>>
+          %block_dim_x = gpu.block_dim x
+          %thread_id_y = gpu.thread_id y
+          %40 = arith.muli %thread_id_y, %block_dim_x : index
+          %41 = arith.addi %thread_id_x, %40 : index
+          %block_dim_y = gpu.block_dim y
+          %42 = arith.muli %block_dim_x, %block_dim_y : index
+          %thread_id_z = gpu.thread_id z
+          %43 = arith.muli %thread_id_z, %42 : index
+          %44 = arith.addi %41, %43 : index
+          %45 = arith.cmpi eq, %44, %c0 : index
+          scf.if %45 {
+            scf.for %arg16 = %c0 to %c64 step %c1 {
+              scf.for %arg17 = %c0 to %c64 step %c1 {
+                %46 = memref.load %subview[%arg16, %arg17] : memref<64x64xf16, strided<[64, 1], offset: ?>, #gpu.address_space<workgroup>>
+                memref.store %46, %subview_0[%arg16, %arg17] : memref<64x64xf16, strided<[128, 1], offset: ?>>
+              }
+            }
+          }
+          gpu.barrier
+        }
+      }
+      gpu.terminator
+    }
+
+    %s6 = gpu.memcpy async [%s5] %dstMemref_host, %dstMemref : memref<128x128xf16>, memref<128x128xf16>
+    gpu.wait [%s6]
+
+    %errorCount, %correctCount = scf.for %arg0 = %c0 to %c128 step %c1 iter_args(%ec1 = %c0, %cc1 = %c0) -> (index,index) {
+      %ec2, %cc2 = scf.for %arg1 = %c0 to %c64 step %c1 iter_args(%ec2 = %ec1, %cc2 = %cc1) -> (index, index) {
+        %v1 = memref.load %dstMemref_host[%arg0, %arg1] : memref<128x128xf16>
+        %v2 = memref.load %srcMemref_host[%arg0, %arg1] : memref<128x128xf16>
+        %p = arith.cmpf one, %v1, %v2 : f16
+        %ec3, %cc3 = scf.if %p -> (index, index) {
+          %ec3 = arith.addi %ec2, %c1 : index
+          scf.yield %ec3, %cc2 : index, index
+        } else {
+          %cc3 = arith.addi %cc2, %c1 : index
+          scf.yield %ec2, %cc3 : index, index
+        }
+        scf.yield %ec3, %cc3 : index,index
+      }
+      scf.yield %ec2, %cc2 : index,index
+    }
+
+    vector.print str "Correct Results :"
+    vector.print %correctCount : index
+    vector.print str "Incorrect Results :"
+    vector.print %errorCount : index
+    return
+  }
+}
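
The FileCheck expectations follow from the verification loop at the end of the test: it compares 128 x 64 elements of the destination buffer against the source, so a correct run reports 8192 matching values and 0 mismatches.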
