Skip to content

Commit c2691d9

Browse files
committed
disable for multi-gpu and batch size > 1
1 parent 800f4fe commit c2691d9

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

ggml-cuda.cu

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2436,6 +2436,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
24362436
// pointer to CUDA cpy kernel, which is required to identify
24372437
// kernel parameters which need to be updated in the graph for each token
24382438
void* ggmlCudaCpyFn = nullptr;
2439+
2440+
if(ggml_backend_cuda_get_device_count() > 1){
2441+
useCudaGraph = false; // disable CUDA graphs for multi-GPU for now. TODO: investigate
2442+
}
2443+
24392444
if(useCudaGraph) {
24402445

24412446
if(cudaGraph.instance == nullptr) cudaGraphUpdateRequired=true;
@@ -2447,6 +2452,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
24472452
// Identify if the graph needs updated for this token due to the number of elements changing
24482453
// (identified by inspecting soft max op parameters)
24492454
if(node->op == GGML_OP_SOFT_MAX) {
2455+
if(node->src[1]->ne[1] > 1){
2456+
useCudaGraph = false; // disable CUDA graphs for batch size > 1 for now. TODO: investigate
2457+
}
24502458
if(node->src[0]->ne[0] != cudaGraph.softmax_ne0) {
24512459
cudaGraphUpdateRequired = true;
24522460
cudaGraph.softmax_ne0 = node->src[0]->ne[0];

0 commit comments

Comments
 (0)