@@ -9249,16 +9249,10 @@ typedef int ggml_lock_t;
 
 #endif
 
-struct ggml_compute_state_shared {
-    int n_threads;
-};
-
 struct ggml_compute_state {
 
     struct ggml_compute_params params;
     struct ggml_tensor * node;
-
-    struct ggml_compute_state_shared * shared;
 };
 
 static void ggml_graph_compute_thread(void * data) {
@@ -9284,9 +9278,6 @@ static void ggml_graph_compute_thread(void * data) {
 
 void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
     const int n_threads = cgraph->n_threads;
-    struct ggml_compute_state_shared state_shared = {
-        /*.n_threads =*/ n_threads,
-    };
     struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
 
     // create thread pool
@@ -9302,7 +9293,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     .wdata = cgraph->work ? cgraph->work->data : NULL,
                 },
                 .node   = NULL,
-                .shared = &state_shared,
             };
         }
     }
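
The three hunks above are a single refactor: once scheduling goes through the thread pool, workers no longer coordinate through shared state, so `ggml_compute_state_shared`, the `state_shared` local, and the `.shared` back-pointer can all be deleted. What remains of a worker's state is self-contained; restated with comments as a sketch of the post-patch layout, not a verbatim quote of the file:

// Sketch: each pool task carries everything it needs.
struct ggml_compute_state {
    struct ggml_compute_params params; // task type, thread index/count, work buffer
    struct ggml_tensor * node;         // the graph node this task computes
};
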
@@ -9520,6 +9510,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        if (node->n_tasks == 0)
+        {
+            // no work needs to be done
+            continue;
+        }
         // TODO: this could be used to avoid unnecessary computations, but it needs to be improved
         //if (node->grad == NULL && node->perf_runs > 0) {
         //    continue;
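
This early-out pairs with the look-ahead dispatcher added in the next hunk: when a node is handed to the pool ahead of its turn, its `n_tasks` is cleared to 0, and the main scheduling loop must then skip it rather than queue it a second time. The two sides of the handshake, assembled from this patch's own lines:

// dispatcher (next hunk): claim the node when queueing it on the pool
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
next->n_tasks = 0;   // claimed: the main loop must not schedule it again

// main loop (this hunk): skip nodes that were already claimed
if (node->n_tasks == 0)
{
    continue;
}
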
@@ -9558,46 +9553,45 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             }
             else
             {
-                if (i + 1 < cgraph->n_nodes)
+                int start = i;
+                int end   = i + 1;
+                while (end < cgraph->n_nodes && next_task < n_threads && (end - start) < n_threads * 2)
                 {
-                    struct ggml_tensor * next = cgraph->nodes[i + 1];
-                    if (next->src0 != node && next->src1 != node && next->n_tasks == 1)
+                    struct ggml_tensor * next = cgraph->nodes[end];
+                    end++;
+
+                    if (next->n_tasks != 1)
+                        continue;
+
+                    // check src dependency
+                    bool is_dep = false;
+                    for (int k = start; k < end; k++)
                     {
-                        workers[next_task].params = (struct ggml_compute_params) {
-                            .type  = GGML_TASK_COMPUTE | GGML_TASK_INIT,
-                            .ith   = 0,
-                            .nth   = 1,
-                            .wsize = 0,
-                            .wdata = NULL,
-                        };
-                        workers[next_task].node = next;
-                        thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
-                        next_task++;
-
-                        if (i + 2 < cgraph->n_nodes)
+                        struct ggml_tensor * prev = cgraph->nodes[k];
+                        if (next->src0 == prev || next->src1 == prev)
                         {
-                            struct ggml_tensor * prev = cgraph->nodes[i + 1];
-                            struct ggml_tensor * next = cgraph->nodes[i + 2];
-                            if (next->src0 != node && next->src1 != node && next->n_tasks == 1 &&
-                                next->src0 != prev && next->src1 != prev
-                            )
-                            {
-                                workers[next_task].params = (struct ggml_compute_params) {
-                                    .type  = GGML_TASK_COMPUTE | GGML_TASK_INIT,
-                                    .ith   = 0,
-                                    .nth   = 1,
-                                    .wsize = 0,
-                                    .wdata = NULL,
-                                };
-                                workers[next_task].node = next;
-                                thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
-                                next_task++;
-                            }
+                            is_dep = true;
+                            break;
                         }
                     }
-                }
 
+                    if (is_dep)
+                        continue;
 
+                    workers[next_task].params = (struct ggml_compute_params) {
+                        .type  = GGML_TASK_COMPUTE | GGML_TASK_INIT,
+                        .ith   = 0,
+                        .nth   = 1,
+                        .wsize = 0,
+                        .wdata = NULL,
+                    };
+                    workers[next_task].node = next;
+
+                    thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
+                    next->n_tasks = 0; // mark this node as already scheduled
+                    next_task++;
+                    //printf("Combine task [%d, %d]\n", start, end);
+                }
             }
 
             params.type = GGML_TASK_COMPUTE;
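
Net effect of the last hunk: the hard-coded look-ahead of at most two nodes becomes a sliding window. Starting at node `i`, the scan walks up to `n_threads * 2` following nodes and queues every single-task node (`n_tasks == 1`) whose `src0`/`src1` inputs are not produced anywhere in the window, marking it with `n_tasks = 0` so the main loop skips it later. The dependency test is the core of the change; distilled here into a standalone helper with a hypothetical name that is not part of the patch:

// Returns true when `next` consumes the output of any node in
// cgraph->nodes[start..end), i.e. it cannot run concurrently with them.
static bool depends_on_window(const struct ggml_cgraph * cgraph,
                              const struct ggml_tensor * next,
                              int start, int end) {
    for (int k = start; k < end; k++) {
        const struct ggml_tensor * prev = cgraph->nodes[k];
        if (next->src0 == prev || next->src1 == prev) {
            return true; // input produced inside the window: must wait
        }
    }
    return false;
}

One caveat worth flagging: the guard `next_task < n_threads` appears to allow `workers[n_threads - 1]` to be written, while the array allocated earlier holds only `n_threads - 1` states; a tighter bound such as `next_task < n_threads - 1` may be intended.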