Skip to content

Commit 1d8966e

Browse files
authored
[flang][cuda] Use the provided stream in kernel launch (#135267)
1 parent 1cd5926 commit 1d8966e

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

flang-rt/lib/cuda/kernel.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "flang/Runtime/CUDA/kernel.h"
10+
#include "flang-rt/runtime/descriptor.h"
1011
#include "flang-rt/runtime/terminator.h"
1112
#include "flang/Runtime/CUDA/common.h"
1213

@@ -74,9 +75,9 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
7475
Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
7576
terminator.Crash("Too many invalid grid dimensions");
7677
}
77-
cudaStream_t cuStream = 0; // TODO stream managment
78-
CUDA_REPORT_IF_ERROR(
79-
cudaLaunchKernel(kernel, gridDim, blockDim, params, smem, cuStream));
78+
cudaStream_t defaultStream = 0;
79+
CUDA_REPORT_IF_ERROR(cudaLaunchKernel(kernel, gridDim, blockDim, params, smem,
80+
stream != kNoAsyncId ? (cudaStream_t)stream : defaultStream));
8081
}
8182

8283
void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
@@ -140,7 +141,11 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
140141
terminator.Crash("Too many invalid grid dimensions");
141142
}
142143
config.dynamicSmemBytes = smem;
143-
config.stream = 0; // TODO stream managment
144+
if (stream != kNoAsyncId) {
145+
config.stream = (cudaStream_t)stream;
146+
} else {
147+
config.stream = 0;
148+
}
144149
cudaLaunchAttribute launchAttr[1];
145150
launchAttr[0].id = cudaLaunchAttributeClusterDimension;
146151
launchAttr[0].val.clusterDim.x = clusterX;
@@ -212,9 +217,10 @@ void RTDEF(CUFLaunchCooperativeKernel)(const void *kernel, intptr_t gridX,
212217
Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
213218
terminator.Crash("Too many invalid grid dimensions");
214219
}
215-
cudaStream_t cuStream = 0; // TODO stream managment
216-
CUDA_REPORT_IF_ERROR(cudaLaunchCooperativeKernel(
217-
kernel, gridDim, blockDim, params, smem, cuStream));
220+
cudaStream_t defaultStream = 0;
221+
CUDA_REPORT_IF_ERROR(
222+
cudaLaunchCooperativeKernel(kernel, gridDim, blockDim, params, smem,
223+
stream != kNoAsyncId ? (cudaStream_t)stream : defaultStream));
218224
}
219225

220226
} // extern "C"

0 commit comments

Comments
 (0)