Skip to content

Commit d403b18

Browse files
committed
Addressed comments
1 parent c3d4ead commit d403b18

File tree

1 file changed

+66
-63
lines changed

1 file changed

+66
-63
lines changed

ggml-cuda.cu

Lines changed: 66 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2411,19 +2411,19 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
24112411

24122412
#ifdef USE_CUDA_GRAPH
24132413
#define MAX_NODES_IN_CUDA_GRAPH 10000
2414-
struct ggml_cudaGraph {
2415-
int count=0;
2414+
struct ggml_cuda_graph {
2415+
int count = 0;
24162416
cudaGraph_t graph = nullptr;
24172417
cudaGraphExec_t instance = nullptr;
2418-
size_t numNodes = 0;
2418+
size_t num_nodes = 0;
24192419
int softmax_ne0 = 0;
24202420
cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
24212421
cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
2422-
bool disableDueToGpuArch=false;
2422+
bool disable_due_to_gpu_arch = false;
24232423
};
24242424
#endif
24252425

2426-
const bool disableCudaGraphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr);
2426+
const bool disable_cuda_graphs = (getenv("LLAMACPP_DISABLE_CUDA_GRAPHS") != nullptr);
24272427

24282428
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
24292429
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
@@ -2432,33 +2432,29 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
24322432

24332433
#ifdef USE_CUDA_GRAPH
24342434
// Objects required for CUDA Graph
2435-
static ggml_cudaGraph cudaGraph;
2436-
bool useCudaGraph = (cudaGraph.count>=7); //avoid CUDA graphs on first few steps due to incompatible initialisations.
2437-
char** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH];
2438-
bool cudaGraphUpdateRequired = false;
2435+
static ggml_cuda_graph cuda_graph;
2436+
bool use_cuda_graph = (cuda_graph.count >= 7); //avoid CUDA graphs on first few steps due to incompatible initialisations.
2437+
char ** updated_kernel_arg[MAX_NODES_IN_CUDA_GRAPH];
2438+
bool cuda_graph_update_required = false;
24392439
// pointer to CUDA cpy kernel, which is required to identify
24402440
// kernel parameters which need to be updated in the graph for each token
2441-
void* ggmlCudaCpyFn = nullptr;
2441+
void * ggml_cuda_cpy_fn_ptr = nullptr;
24422442

2443-
if(cudaGraph.count==0){
2444-
cudaDeviceProp prop;
2445-
int device;
2446-
CUDA_CHECK(cudaGetDevice(&device));
2447-
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
2448-
if (prop.major < 8){
2449-
cudaGraph.disableDueToGpuArch=true;
2443+
if(cuda_graph.count == 0){
2444+
if (ggml_cuda_info().devices[cuda_ctx->device].cc < 800){
2445+
cuda_graph.disable_due_to_gpu_arch=true;
24502446
}
24512447
}
24522448

24532449
// Disable CUDA graphs in presence of env var or old GPU.
24542450
// Also disable for multi-gpu for now. TO DO investigate
2455-
if(disableCudaGraphs || cudaGraph.disableDueToGpuArch || ggml_backend_cuda_get_device_count() > 1){
2456-
useCudaGraph = false;
2451+
if(disable_cuda_graphs || cuda_graph.disable_due_to_gpu_arch || ggml_backend_cuda_get_device_count() > 1){
2452+
use_cuda_graph = false;
24572453
}
24582454

2459-
if(useCudaGraph) {
2455+
if(use_cuda_graph) {
24602456

2461-
if(cudaGraph.instance == nullptr) cudaGraphUpdateRequired=true;
2457+
if(cuda_graph.instance == nullptr) cuda_graph_update_required=true;
24622458

24632459
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
24642460
int k=0;
@@ -2468,36 +2464,36 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
24682464
// (identified by inspecting soft max op parameters)
24692465
if(node->op == GGML_OP_SOFT_MAX) {
24702466
if(node->src[1]->ne[1] > 1){
2471-
useCudaGraph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate
2467+
use_cuda_graph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate
24722468
}
2473-
if(node->src[0]->ne[0] != cudaGraph.softmax_ne0) {
2474-
cudaGraphUpdateRequired = true;
2475-
cudaGraph.softmax_ne0 = node->src[0]->ne[0];
2469+
if(node->src[0]->ne[0] != cuda_graph.softmax_ne0) {
2470+
cuda_graph_update_required = true;
2471+
cuda_graph.softmax_ne0 = node->src[0]->ne[0];
24762472
}
24772473
}
24782474
if(node->op == GGML_OP_CPY) {
24792475
// store the copy op parameter which changes with each token.
2480-
updatedKernelArg[k++]=(char**) &(node->src[1]->data);
2481-
if(ggmlCudaCpyFn == nullptr){
2476+
updated_kernel_arg[k++]=(char **) &(node->src[1]->data);
2477+
if(ggml_cuda_cpy_fn_ptr == nullptr){
24822478
// store a pointer to the copy op CUDA kernel to identify it later
2483-
ggmlCudaCpyFn = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
2479+
ggml_cuda_cpy_fn_ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
24842480
}
24852481
}
24862482
}
24872483
}
24882484

2489-
if(useCudaGraph && cudaGraphUpdateRequired) { // Start CUDA graph capture
2485+
if(use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
24902486
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal));
24912487
}
24922488

24932489
#else
2494-
bool useCudaGraph = false;
2495-
bool cudaGraphUpdateRequired = false;
2490+
bool use_cuda_graph = false;
2491+
bool cuda_graph_update_required = false;
24962492
#endif
24972493

24982494
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
24992495
// With use of CUDA graphs, the execution will be performed by the graph launch.
2500-
if(!useCudaGraph || cudaGraphUpdateRequired) {
2496+
if(!use_cuda_graph || cuda_graph_update_required) {
25012497
//temporarily avoid indenting here to make code review easier
25022498
for (int i = 0; i < cgraph->n_nodes; i++) {
25032499
ggml_tensor * node = cgraph->nodes[i];
@@ -2524,67 +2520,74 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
25242520
}
25252521

25262522
#ifdef USE_CUDA_GRAPH
2527-
if(useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture
2528-
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph));
2523+
if(use_cuda_graph && (cuda_graph_update_required)) { // End CUDA graph capture
2524+
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_graph.graph));
25292525
}
2530-
if(useCudaGraph){
2526+
if(use_cuda_graph){
25312527

2532-
if(cudaGraph.instance == nullptr) { // Create executable graph from captured graph.
2533-
CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0));
2528+
if(cuda_graph.instance == nullptr) { // Create executable graph from captured graph.
2529+
CUDA_CHECK(cudaGraphInstantiate(&cuda_graph.instance, cuda_graph.graph, NULL, NULL, 0));
25342530
}
25352531

25362532

25372533
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
25382534

2539-
if(cudaGraphUpdateRequired) {
2535+
if(cuda_graph_update_required) {
25402536
// Extract nodes from graph
2541-
if(cudaGraph.numNodes == 0) {
2542-
CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes));
2537+
if(cuda_graph.num_nodes == 0) {
2538+
// First call with null argument gets number of nodes in graph
2539+
CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, nullptr, &cuda_graph.num_nodes));
25432540
}
2544-
CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, cudaGraph.nodes, &cudaGraph.numNodes));
2541+
// Subsequent call with non-null argument gets nodes
2542+
CUDA_CHECK(cudaGraphGetNodes(cuda_graph.graph, cuda_graph.nodes, &cuda_graph.num_nodes));
25452543

25462544
// Loop over nodes, and extract kernel parameters from each node
2547-
for(size_t i=0; i<cudaGraph.numNodes; i++) {
2548-
cudaGraphNodeType nodeType;
2549-
CUDA_CHECK(cudaGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
2550-
if (nodeType == cudaGraphNodeTypeKernel) {
2551-
auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.params[i]); // Get params using runtime
2552-
if(statRT == cudaErrorInvalidDeviceFunction) {
2545+
for(size_t i=0; i<cuda_graph.num_nodes; i++) {
2546+
cudaGraphNodeType node_type;
2547+
CUDA_CHECK(cudaGraphNodeGetType(cuda_graph.nodes[i], &node_type));
2548+
if (node_type == cudaGraphNodeTypeKernel) {
2549+
auto stat = cudaGraphKernelNodeGetParams(cuda_graph.nodes[i], &cuda_graph.params[i]); // Get params using runtime
2550+
if(stat == cudaErrorInvalidDeviceFunction) {
25532551
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
25542552
// We don't need to update blas nodes, so clear error and move on.
25552553
cudaGetLastError();
25562554
}
2555+
else {
2556+
GGML_ASSERT(stat == cudaSuccess);
2557+
}
25572558
}
25582559
}
25592560
}
25602561

2561-
// Update copy kernel param (required every token)
2562-
if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
2562+
// One of the arguments to the copy kernel is updated for each token, hence we need to
2563+
// replace that argument with the updated value in the CUDA graph
2564+
if(!cuda_graph_update_required) { // on update steps, the live parameters will already be captured
25632565
int k=0;
2564-
for(size_t i=0; i<cudaGraph.numNodes; i++) {
2565-
if(cudaGraph.params[i].func == ggmlCudaCpyFn) {
2566-
char** updatedKernelArgPointer = updatedKernelArg[k++];
2567-
cudaGraph.params[i].kernelParams[1] = updatedKernelArgPointer;
2568-
CUDA_CHECK(cudaGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.params[i]));
2566+
for(size_t i=0; i<cuda_graph.num_nodes; i++) {
2567+
if(cuda_graph.params[i].func == ggml_cuda_cpy_fn_ptr) {
2568+
char ** updated_kernel_arg_ptr = updated_kernel_arg[k++];
2569+
cuda_graph.params[i].kernelParams[1] = updated_kernel_arg_ptr;
2570+
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_graph.nodes[i], &cuda_graph.params[i]));
25692571
}
25702572
}
25712573
}
25722574

25732575
// Update graph executable
2574-
cudaGraphExecUpdateResultInfo resultInfo;
2575-
auto stat = cudaGraphExecUpdate(cudaGraph.instance, cudaGraph.graph, &resultInfo);
2576-
if(stat == cudaErrorGraphExecUpdateFailure)
2577-
{
2576+
cudaGraphExecUpdateResultInfo result_info;
2577+
auto stat = cudaGraphExecUpdate(cuda_graph.instance, cuda_graph.graph, &result_info);
2578+
if(stat == cudaErrorGraphExecUpdateFailure) {
25782579
// The pre-existing graph exec cannot be updated due to violated constraints
2579-
// so instead clar error and re-instantiate
2580+
// so instead clear error and re-instantiate
25802581
cudaGetLastError();
2581-
CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0));
2582+
CUDA_CHECK(cudaGraphInstantiate(&cuda_graph.instance, cuda_graph.graph, NULL, NULL, 0));
2583+
}
2584+
else {
2585+
GGML_ASSERT(stat == cudaSuccess);
25822586
}
2583-
25842587
// Launch graph
2585-
CUDA_CHECK(cudaGraphLaunch(cudaGraph.instance, cuda_ctx->stream()));
2588+
CUDA_CHECK(cudaGraphLaunch(cuda_graph.instance, cuda_ctx->stream()));
25862589
}
2587-
cudaGraph.count++;
2590+
cuda_graph.count++;
25882591
#endif
25892592
return GGML_STATUS_SUCCESS;
25902593
}

0 commit comments

Comments
 (0)