@@ -3590,6 +3590,9 @@ struct ggml_compute_params {
3590
3590
// work buffer for all threads
3591
3591
size_t wsize ;
3592
3592
void * wdata ;
3593
+
3594
+ // atomic counter used to distribute chunks of work
3595
+ atomic_int * aic ;
3593
3596
};
3594
3597
3595
3598
//
@@ -9754,6 +9757,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
9754
9757
const int ith = params -> ith ;
9755
9758
const int nth = params -> nth ;
9756
9759
9760
+ UNUSED (ith );
9761
+
9757
9762
GGML_ASSERT (ne02 == ne12 );
9758
9763
GGML_ASSERT (ne03 == ne13 );
9759
9764
GGML_ASSERT (ne2 == ne12 );
@@ -9867,50 +9872,57 @@ static void ggml_compute_forward_mul_mat_q_f32(
9867
9872
}
9868
9873
}
9869
9874
9875
+ atomic_store (params -> aic , 0 );
9876
+
9870
9877
return ;
9871
9878
}
9872
9879
9873
9880
if (params -> type == GGML_TASK_FINALIZE ) {
9874
9881
return ;
9875
9882
}
9876
9883
9884
+ void * wdata = params -> wdata ;
9885
+ const size_t row_size = ne00 * GGML_TYPE_SIZE [vec_dot_type ]/GGML_BLCK_SIZE [vec_dot_type ];
9886
+
9877
9887
// parallelize by src0 rows using ggml_vec_dot_q
9878
9888
9879
- // total rows in src0
9880
- const int nr = ne01 * ne02 * ne03 ;
9889
+ const int nr = ggml_nrows ( src0 );
9890
+ const int dr = ( nr + 8 * nth - 1 )/( 8 * nth ) ;
9881
9891
9882
- // rows per thread
9883
- const int dr = ( nr + nth - 1 )/ nth ;
9892
+ while (true) {
9893
+ const int ir0 = atomic_fetch_add ( params -> aic , dr ) ;
9884
9894
9885
- // row range for this thread
9886
- const int ir0 = dr * ith ;
9887
- const int ir1 = MIN (ir0 + dr , nr );
9895
+ for (int ir = ir0 ; ir < ir0 + dr ; ++ ir ) {
9896
+ if (ir >= nr ) {
9897
+ break ;
9898
+ }
9888
9899
9889
- void * wdata = params -> wdata ;
9890
- const size_t row_size = ne00 * GGML_TYPE_SIZE [vec_dot_type ]/GGML_BLCK_SIZE [vec_dot_type ];
9900
+ // src0 indices
9901
+ const int i03 = ir /(ne02 * ne01 );
9902
+ const int i02 = (ir - i03 * ne02 * ne01 )/ne01 ;
9903
+ const int i01 = (ir - i03 * ne02 * ne01 - i02 * ne01 );
9891
9904
9892
- for (int ir = ir0 ; ir < ir1 ; ++ ir ) {
9893
- // src0 indices
9894
- const int i03 = ir /(ne02 * ne01 );
9895
- const int i02 = (ir - i03 * ne02 * ne01 )/ne01 ;
9896
- const int i01 = (ir - i03 * ne02 * ne01 - i02 * ne01 );
9905
+ const int i13 = i03 ;
9906
+ const int i12 = i02 ;
9897
9907
9898
- const int i13 = i03 ;
9899
- const int i12 = i02 ;
9908
+ const int i0 = i01 ;
9909
+ const int i2 = i02 ;
9910
+ const int i3 = i03 ;
9900
9911
9901
- const int i0 = i01 ;
9902
- const int i2 = i02 ;
9903
- const int i3 = i03 ;
9912
+ void * src0_row = (void * ) ((char * ) src0 -> data + (i01 * nb01 + i02 * nb02 + i03 * nb03 ));
9913
+ char * src1_col = ((char * ) wdata + ( (0 + i12 * ne11 + i13 * ne12 * ne11 )* row_size ));
9904
9914
9905
- void * src0_row = (void * ) ((char * ) src0 -> data + (i01 * nb01 + i02 * nb02 + i03 * nb03 ));
9906
- char * src1_col = ((char * ) wdata + ( (0 + i12 * ne11 + i13 * ne12 * ne11 )* row_size ));
9915
+ float * dst_col = (float * ) ((char * ) dst -> data + (i0 * nb0 + 0 * nb1 + i2 * nb2 + i3 * nb3 ));
9907
9916
9908
- float * dst_col = ( float * ) (( char * ) dst -> data + ( i0 * nb0 + 0 * nb1 + i2 * nb2 + i3 * nb3 ) );
9917
+ assert ( ne00 % 32 == 0 );
9909
9918
9910
- assert (ne00 % 32 == 0 );
9919
+ for (int64_t ic = 0 ; ic < ne11 ; ++ ic ) {
9920
+ vec_dot_q (ne00 , & dst_col [ic * ne0 ], src0_row , (void * ) (src1_col + ic * row_size ));
9921
+ }
9922
+ }
9911
9923
9912
- for ( int64_t ic = 0 ; ic < ne11 ; ++ ic ) {
9913
- vec_dot_q ( ne00 , & dst_col [ ic * ne0 ], src0_row , ( void * ) ( src1_col + ic * row_size )) ;
9924
+ if ( ir0 + dr >= nr ) {
9925
+ break ;
9914
9926
}
9915
9927
}
9916
9928
@@ -13749,6 +13761,7 @@ struct ggml_compute_state_shared {
13749
13761
13750
13762
// synchronization primitives
13751
13763
atomic_int n_ready ;
13764
+ atomic_int aic ;
13752
13765
atomic_bool has_work ;
13753
13766
atomic_bool stop ; // stop all threads
13754
13767
};
@@ -13817,6 +13830,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
13817
13830
/*.spin =*/ GGML_LOCK_INITIALIZER ,
13818
13831
/*.n_threads =*/ n_threads ,
13819
13832
/*.n_ready =*/ 0 ,
13833
+ /*.aic =*/ 0 ,
13820
13834
/*.has_work =*/ false,
13821
13835
/*.stop =*/ false,
13822
13836
};
@@ -13837,6 +13851,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
13837
13851
.nth = n_threads ,
13838
13852
.wsize = cgraph -> work ? ggml_nbytes (cgraph -> work ) : 0 ,
13839
13853
.wdata = cgraph -> work ? cgraph -> work -> data : NULL ,
13854
+ .aic = & state_shared .aic ,
13840
13855
},
13841
13856
.node = NULL ,
13842
13857
.shared = & state_shared ,
@@ -14126,6 +14141,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14126
14141
/*.nth =*/ node -> n_tasks ,
14127
14142
/*.wsize =*/ cgraph -> work ? ggml_nbytes (cgraph -> work ) : 0 ,
14128
14143
/*.wdata =*/ cgraph -> work ? cgraph -> work -> data : NULL ,
14144
+ /*.aic =*/ & state_shared .aic ,
14129
14145
};
14130
14146
14131
14147
ggml_compute_forward (& params , node );
@@ -14149,6 +14165,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14149
14165
.nth = node -> n_tasks ,
14150
14166
.wsize = cgraph -> work ? ggml_nbytes (cgraph -> work ) : 0 ,
14151
14167
.wdata = cgraph -> work ? cgraph -> work -> data : NULL ,
14168
+ .aic = & state_shared .aic ,
14152
14169
};
14153
14170
workers [j ].node = node ;
14154
14171
}
@@ -14164,6 +14181,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14164
14181
}
14165
14182
14166
14183
params .type = GGML_TASK_COMPUTE ;
14184
+ params .aic = & state_shared .aic ;
14167
14185
ggml_compute_forward (& params , node );
14168
14186
14169
14187
// wait for thread pool
@@ -14204,6 +14222,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
14204
14222
.nth = node -> n_tasks ,
14205
14223
.wsize = cgraph -> work ? ggml_nbytes (cgraph -> work ) : 0 ,
14206
14224
.wdata = cgraph -> work ? cgraph -> work -> data : NULL ,
14225
+ .aic = & state_shared .aic ,
14207
14226
};
14208
14227
workers [j ].node = node ;
14209
14228
}
0 commit comments