@@ -9279,23 +9279,22 @@ static void ggml_graph_compute_thread(void * data) {
 void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
     const int n_threads = cgraph->n_threads;
-    struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
+    const int max_requests = n_threads * 5;
+    struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*(max_requests));
 
     // create thread pool
-    if (n_threads > 1) {
-        ctx->tpool = thpool_init(n_threads);
-        for (int j = 0; j < n_threads - 1; j++) {
-            workers[j] = (struct ggml_compute_state) {
-                .params = {
-                    .type  = GGML_TASK_COMPUTE,
-                    .ith   = j + 1,
-                    .nth   = n_threads,
-                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                    .wdata = cgraph->work ? cgraph->work->data : NULL,
-                },
-                .node = NULL,
-            };
-        }
+    ctx->tpool = thpool_init(n_threads);
+    for (int j = 0; j < n_threads - 1; j++) {
+        workers[j] = (struct ggml_compute_state) {
+            .params = {
+                .type  = GGML_TASK_COMPUTE,
+                .ith   = j + 1,
+                .nth   = n_threads,
+                .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
+                .wdata = cgraph->work ? cgraph->work->data : NULL,
+            },
+            .node = NULL,
+        };
     }
 
     // initialize tasks + work buffer
@@ -9505,6 +9504,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
     const int64_t perf_start_cycles  = ggml_perf_cycles();
     const int64_t perf_start_time_us = ggml_perf_time_us();
+    const size_t wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0;
+    const void * wdata = cgraph->work ? cgraph->work->data : NULL;
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
@@ -9524,52 +9525,31 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         const int64_t perf_node_start_cycles  = ggml_perf_cycles();
         const int64_t perf_node_start_time_us = ggml_perf_time_us();
 
-        // INIT
-        struct ggml_compute_params params = {
-            .type  = GGML_TASK_INIT,
-            .ith   = 0,
-            .nth   = node->n_tasks,
-            .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-            .wdata = cgraph->work ? cgraph->work->data : NULL,
-        };
-
-        ggml_compute_forward(&params, node);
-
         int next_task = 0;
 
         // COMPUTE
-        if (node->n_tasks > 1) {
-            // launch thread pool
-            for (int j = 0; j < n_threads - 1; j++) {
-                workers[j].params = (struct ggml_compute_params) {
-                    .type  = GGML_TASK_COMPUTE,
-                    .ith   = j + 1,
-                    .nth   = node->n_tasks,
-                    .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
-                    .wdata = cgraph->work ? cgraph->work->data : NULL,
-                };
-                workers[j].node = node;
-                thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]);
-            }
-        }
-        else
         {
             int start = i;
-            int end   = i + 1;
-            while (end < cgraph->n_nodes && next_task < n_threads && (end - start) < n_threads * 2)
+            int end   = i;
+            while (end < cgraph->n_nodes && (end - start) < n_threads * 2)
             {
                 struct ggml_tensor * next = cgraph->nodes[end];
                 end++;
 
-                if (next->n_tasks != 1)
+                // already scheduled
+                if (next->n_tasks == 0)
+                    continue;
+
+                // skip if there are not enough free slots
+                if (next_task + next->n_tasks > max_requests)
                     continue;
 
                 // check src dependency
                 bool is_dep = false;
                 for (int k = start; k < end; k++)
                 {
-                    struct ggml_tensor * node = cgraph->nodes[k];
-                    if (next->src0 == node || next->src1 == node)
+                    struct ggml_tensor * prev = cgraph->nodes[k];
+                    if (next->src0 == prev || next->src1 == prev)
                     {
                         is_dep = true;
                         break;
@@ -9579,29 +9559,42 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 if (is_dep)
                     continue;
 
-                workers[next_task].params = (struct ggml_compute_params) {
-                    .type  = GGML_TASK_INIT,
-                    .ith   = 0,
-                    .nth   = 1,
-                    .wsize = 0,
-                    .wdata = NULL,
-                };
-                workers[next_task].node = next;
+                if (next->n_tasks > 1)
+                {
+                    // run INIT in the main thread if this is a multi-threaded operator
+                    struct ggml_compute_params params = {
+                        .type  = GGML_TASK_INIT,
+                        .ith   = 0,
+                        .nth   = next->n_tasks,
+                        .wsize = wsize,
+                        .wdata = wdata,
+                    };
+
+                    ggml_compute_forward(&params, next);
+                }
 
-                thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
+                for (int j = 0; j < next->n_tasks; j++) {
+                    workers[next_task].params = (struct ggml_compute_params){
+                        // a single-threaded operator runs INIT in the worker thread
+                        .type  = next->n_tasks == 1 ? GGML_TASK_INIT : GGML_TASK_COMPUTE,
+                        .ith   = j,
+                        .nth   = next->n_tasks,
+
+                        // TODO: Potential race on wdata
+                        .wsize = wsize,
+                        .wdata = wdata,
+                    };
+                    workers[next_task].node = next;
+
+                    thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
+                    next_task++;
+                }
                 next->n_tasks = 0; // indicate this node is calculated
-                next_task++;
-                //printf("Combine task [%d, %d]\n", start, end);
             }
         }
 
-        params.type = GGML_TASK_COMPUTE;
-        ggml_compute_forward(&params, node);
-
         // wait for thread pool
-        if (node->n_tasks > 1 || next_task != 0) {
-            thpool_wait(ctx->tpool);
-        }
+        thpool_wait(ctx->tpool);
 #if 0
         // FINALIZE
         if (node->n_tasks > 1) {
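For readers skimming the diff: the new scheduling loop scans a window of graph nodes ahead of the current one, skips nodes already marked computed (n_tasks == 0), refuses to batch a node whose src0/src1 is produced inside the same window, enqueues one request per task on the thread pool, and then blocks on thpool_wait. The following is a minimal, self-contained sketch of that idea only; toy_node, window_is_independent, and the window constant are invented for illustration and are not part of the patch or of ggml.

```c
#include <stdbool.h>
#include <stdio.h>

// Toy stand-in for a graph node: n_tasks == 0 means "already computed/scheduled".
struct toy_node {
    int n_tasks;
    struct toy_node * src0;
    struct toy_node * src1;
};

// True if `next` does not consume the output of any node in nodes[start..end).
static bool window_is_independent(const struct toy_node * next,
                                  struct toy_node ** nodes,
                                  int start, int end) {
    for (int k = start; k < end; k++) {
        if (next->src0 == nodes[k] || next->src1 == nodes[k]) {
            return false;
        }
    }
    return true;
}

int main(void) {
    // a -> b -> c is a dependency chain; d is independent of all of them
    struct toy_node a = { 1, NULL, NULL };
    struct toy_node b = { 1, &a,   NULL };
    struct toy_node c = { 1, &b,   NULL };
    struct toy_node d = { 1, NULL, NULL };
    struct toy_node * nodes[] = { &a, &b, &c, &d };

    const int n_nodes = 4;
    const int window  = 3; // plays the role of n_threads*2 in the patch

    for (int i = 0; i < n_nodes; i++) {
        if (nodes[i]->n_tasks == 0) {
            continue; // already handled by an earlier window
        }

        int start = i;
        int end   = i;
        while (end < n_nodes && (end - start) < window) {
            struct toy_node * next = nodes[end];
            end++;

            if (next->n_tasks == 0) {
                continue; // already scheduled
            }
            if (!window_is_independent(next, nodes, start, end)) {
                continue; // depends on a node in this window; leave it for a later pass
            }

            // here the real code would enqueue next->n_tasks requests on the pool
            printf("dispatch node %d with %d task(s)\n", end - 1, next->n_tasks);
            next->n_tasks = 0; // mark as scheduled
        }

        // here the real code waits for the pool (thpool_wait) before moving on
    }

    return 0;
}
```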