1 file changed: +18 −2

This change adds two new opt-outs for CUDA graph execution: a `LLAMACPP_DISABLE_CUDA_GRAPHS` environment variable and a one-time device check that disables graphs on GPUs with compute capability below 8.0, folded in alongside the existing multi-GPU opt-out.

```diff
@@ -2419,9 +2419,12 @@ struct ggml_cudaGraph {
     int softmax_ne0 = 0;
     cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
     cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
+    bool disableDueToGpuArch = false;
 };
 #endif
 
+const bool disableCudaGraphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr);
+
 GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
 
```
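The override is evaluated once at namespace scope, so setting `LLAMACPP_DISABLE_CUDA_GRAPHS` to any value before the process starts disables graph capture for the whole run. A minimal standalone sketch of that check follows; the `main` wrapper and the printed message are illustrative and not part of the patch:

```cpp
// Sketch of the environment-variable kill switch from the hunk above.
// getenv returns nullptr when the variable is unset, so any value
// (even an empty string) counts as "disable".
#include <cstdio>
#include <cstdlib>

static const bool disable_cuda_graphs =
    (std::getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr);

int main() {
    std::printf("CUDA graphs %s\n", disable_cuda_graphs ? "disabled" : "enabled");
    return 0;
}
```

In practice this means the feature can be switched off without rebuilding, e.g. by launching the binary with `LLAMACPP_DISABLE_CUDA_GRAPHS=1` in the environment.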
```diff
@@ -2437,8 +2440,21 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
     // kernel parameters which need updated in the graph for each token
     void * ggmlCudaCpyFn = nullptr;
 
-    if (ggml_backend_cuda_get_device_count() > 1) {
-        useCudaGraph = false; // disable CUDA graphs for multi-gpu for now. TO DO investigate
+
+    if (cudaGraph.count == 0) {
+        cudaDeviceProp prop;
+        int device;
+        cudaGetDevice(&device);
+        cudaGetDeviceProperties(&prop, device);
+        if (prop.major < 8) {
+            cudaGraph.disableDueToGpuArch = true;
+        }
+    }
+
+    // Disable CUDA graphs in presence of env var or old GPU.
+    // Also disable for multi-gpu for now. TO DO investigate
+    if (disableCudaGraphs || cudaGraph.disableDueToGpuArch || ggml_backend_cuda_get_device_count() > 1) {
+        useCudaGraph = false;
     }
 
     if (useCudaGraph) {
```
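The architecture gate runs only on the first graph evaluation (`cudaGraph.count == 0`) and latches the result in `disableDueToGpuArch`, so the device query is not repeated on every token. A standalone sketch of the same probe is below; the error handling and `main` wrapper are added for illustration and are not present in the patch:

```cpp
// Sketch of the one-time architecture check: CUDA graphs stay enabled
// only for devices with compute capability >= 8.0 (Ampere or newer).
// Compile with nvcc.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0;
    cudaDeviceProp prop;
    if (cudaGetDevice(&device) != cudaSuccess ||
        cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
        std::fprintf(stderr, "failed to query CUDA device\n");
        return 1;
    }
    const bool disable_due_to_gpu_arch = prop.major < 8;
    std::printf("compute capability %d.%d -> CUDA graphs %s\n",
                prop.major, prop.minor,
                disable_due_to_gpu_arch ? "disabled" : "enabled");
    return 0;
}
```

Latching the result into the struct means the (cheap but not free) `cudaGetDeviceProperties` call happens once per graph rather than once per compute call.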