Skip to content

Commit 20861f1

Browse files
authored
[mlir][gpu] Use alloc OP's host_shared in cuda runtime (#99035)
1 parent 3fe50b6 commit 20861f1

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,11 +237,18 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventRecord(CUevent event,
237237
}
238238

239239
/// Allocates `sizeBytes` bytes through the CUDA driver API and returns the
/// resulting device pointer reinterpreted as `void *`.
///
/// - `sizeBytes`: requested size in bytes. A zero-byte request never reaches
///   the driver and returns a null pointer.
/// - stream argument: accepted for interface compatibility but not consulted;
///   the allocation calls used here take no stream. The name is left
///   commented out so the parameter does not trigger unused-parameter
///   warnings.
/// - `isHostShared`: when true the buffer comes from `cuMemAllocManaged` with
///   `CU_MEM_ATTACH_GLOBAL` (CUDA managed memory), otherwise from plain
///   `cuMemAlloc`.
///
/// Driver errors are routed through `CUDA_REPORT_IF_ERROR`; the function then
/// returns whatever `ptr` holds (still 0 if the driver did not write it).
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/, bool isHostShared) {
  ScopedContext scopedContext;
  CUdeviceptr ptr = 0;
  // Only talk to the driver for non-empty requests; ptr stays 0 otherwise.
  if (sizeBytes != 0) {
    if (isHostShared) {
      CUDA_REPORT_IF_ERROR(
          cuMemAllocManaged(&ptr, sizeBytes, CU_MEM_ATTACH_GLOBAL));
    } else {
      CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
    }
  }
  return reinterpret_cast<void *>(ptr);
}
247254

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt %s \
2+
// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-format=%gpu_compilation_format" \
3+
// RUN: | mlir-cpu-runner \
4+
// RUN: --shared-libs=%mlir_cuda_runtime \
5+
// RUN: --shared-libs=%mlir_runner_utils \
6+
// RUN: --entry-point-result=void \
7+
// RUN: | FileCheck %s
8+
9+
// CHECK: 2000
10+
module attributes {gpu.container_module} {
  func.func @main() {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %c1000_i32 = arith.constant 1000 : i32
    // One-element buffer requested with `host_shared`, so the host code below
    // and the kernel both access it through the same memref.
    %memref = gpu.alloc host_shared () : memref<1xi32>
    // Host-side initialization. memref<1xi32> has a single element, so the
    // only in-bounds index is 0.
    memref.store %c1000_i32, %memref[%c0] : memref<1xi32>
    // 1x1x1 grid of 1x1x1 blocks: a single thread doubles the value in place.
    gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1) threads(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) {
      %1 = memref.load %memref[%c0] : memref<1xi32>
      %2 = arith.addi %1, %1 : i32
      memref.store %2, %memref[%c0] : memref<1xi32>
      gpu.terminator
    }
    // Host-side read-back of the value written by the kernel; the printed
    // result (1000 + 1000) is what the CHECK line matches.
    %0 = memref.load %memref[%c0] : memref<1xi32>
    vector.print %0 : i32
    return
  }
}

0 commit comments

Comments
 (0)