@@ -748,7 +748,7 @@ int32x4_t vcvtnq_s32_f32(float32x4_t v) {
 
 #define QK4_0 32
 typedef struct {
-    float d;               // delta
+    float   d;             // delta
     uint8_t qs[QK4_0 / 2]; // nibbles / quants
 } block_q4_0;
 static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
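Annotation (not part of the diff): the `static_assert` above pins the q4_0 layout at 4 bytes of scale plus 16 bytes of packed nibbles, i.e. 20 bytes per 32 weights, or 5 bits per weight. A minimal standalone sketch of that arithmetic, mirroring the struct from the hunk:

```c
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

#define QK4_0 32

typedef struct {
    float   d;              // delta (per-block scale)
    uint8_t qs[QK4_0 / 2];  // 32 4-bit quants, packed two per byte
} block_q4_0;

int main(void) {
    static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2,
                  "wrong q4_0 block size/padding");
    // 20 bytes per 32 weights -> 5 bits per weight on average
    printf("%zu bytes/block, %.1f bits/weight\n",
           sizeof(block_q4_0), 8.0 * sizeof(block_q4_0) / QK4_0);
    return 0;
}
```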
@@ -3575,6 +3575,9 @@ struct ggml_compute_params {
     // work buffer for all threads
     size_t wsize;
     void * wdata;
+
+    // atomic counter used to distribute chunks of work
+    atomic_int * aic;
 };
 
 //
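Annotation (not part of the diff): the new `aic` pointer is the core of this change. Instead of statically assigning each thread a fixed range of rows, every thread repeatedly claims the next chunk of rows with `atomic_fetch_add`, so fast threads pick up the slack of slow ones. A minimal standalone sketch of the pattern under C11 atomics and pthreads, with a hypothetical `process_row` standing in for the real work:

```c
#include <pthread.h>
#include <stdatomic.h>
#include <stdbool.h>

#define NR  1000  // total rows (hypothetical)
#define NTH 4     // number of threads (hypothetical)

static atomic_int aic;                           // shared chunk counter
static const int dr = (NR + 8*NTH - 1)/(8*NTH);  // chunk size, as in the diff

static void process_row(int ir) { (void) ir; }   // stand-in for the real work

static void * worker(void * arg) {
    (void) arg;
    while (true) {
        const int ir0 = atomic_fetch_add(&aic, dr);  // claim the next chunk
        for (int ir = ir0; ir < ir0 + dr && ir < NR; ++ir) {
            process_row(ir);
        }
        if (ir0 + dr >= NR) {
            break;  // the claimed chunk reaches (or passes) the last row
        }
    }
    return NULL;
}

int main(void) {
    atomic_store(&aic, 0);  // reset before the parallel phase (cf. GGML_TASK_INIT)
    pthread_t th[NTH];
    for (int i = 0; i < NTH; ++i) pthread_create(&th[i], NULL, worker, NULL);
    for (int i = 0; i < NTH; ++i) pthread_join(th[i], NULL);
    return 0;
}
```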
@@ -9739,6 +9742,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
+    UNUSED(ith);
+
     GGML_ASSERT(ne02 == ne12);
     GGML_ASSERT(ne03 == ne13);
     GGML_ASSERT(ne2  == ne12);
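Annotation (not part of the diff): with the counter handing out work, the thread index `ith` no longer determines which rows this thread computes, so it is explicitly marked unused to keep compilers quiet. In ggml.c this macro is, to the best of my knowledge, the usual cast-to-void idiom:

```c
#define UNUSED(x) (void)(x)
```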
@@ -9852,50 +9857,57 @@ static void ggml_compute_forward_mul_mat_q_f32(
             }
         }
 
+        atomic_store(params->aic, 0);
+
         return;
     }
 
     if (params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    void * wdata = params->wdata;
+    const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+
     // parallelize by src0 rows using ggml_vec_dot_q
 
-    // total rows in src0
-    const int nr = ne01*ne02*ne03;
+    const int nr = ggml_nrows(src0);
+    const int dr = (nr + 8*nth - 1)/(8*nth);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    while (true) {
+        const int ir0 = atomic_fetch_add(params->aic, dr);
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+        for (int ir = ir0; ir < ir0 + dr; ++ir) {
+            if (ir >= nr) {
+                break;
+            }
 
-    void * wdata = params->wdata;
-    const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
+            // src0 indices
+            const int i03 = ir/(ne02*ne01);
+            const int i02 = (ir - i03*ne02*ne01)/ne01;
+            const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
 
-    for (int ir = ir0; ir < ir1; ++ir) {
-        // src0 indices
-        const int i03 = ir/(ne02*ne01);
-        const int i02 = (ir - i03*ne02*ne01)/ne01;
-        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
+            const int i13 = i03;
+            const int i12 = i02;
 
-        const int i13 = i03;
-        const int i12 = i02;
+            const int i0 = i01;
+            const int i2 = i02;
+            const int i3 = i03;
 
-        const int i0 = i01;
-        const int i2 = i02;
-        const int i3 = i03;
+            void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
+            char * src1_col = ((char *) wdata + ((0 + i12*ne11 + i13*ne12*ne11)*row_size));
 
-        void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
-        char * src1_col = ((char *) wdata + ((0 + i12*ne11 + i13*ne12*ne11)*row_size));
+            float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
 
-        float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
+            assert(ne00 % 32 == 0);
 
-        assert(ne00 % 32 == 0);
+            for (int64_t ic = 0; ic < ne11; ++ic) {
+                vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+            }
+        }
 
-        for (int64_t ic = 0; ic < ne11; ++ic) {
-            vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
+        if (ir0 + dr >= nr) {
+            break;
         }
     }
 
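Annotation (not part of the diff): note the chunk size. `dr = (nr + 8*nth - 1)/(8*nth)` is `ceil(nr/(8*nth))`, i.e. roughly eight chunks per thread rather than one, which is what buys the load balancing. For example, with a hypothetical nr = 4096 rows and nth = 8 threads, dr = 4159/64 = 64, so there are 64 chunks of 64 rows and a thread stalling on one chunk delays at most 1/64 of the work instead of its whole 1/8 share. Termination is handled per thread: each iteration claims `[ir0, ir0 + dr)`, clips it against `nr` inside the inner loop, and exits once a claimed chunk covers the last row; a thread whose fetched start index is already past `nr` runs an empty loop and exits the same way. The counter itself is reset to zero during `GGML_TASK_INIT` above, before any `GGML_TASK_COMPUTE` thread starts fetching.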
@@ -13734,6 +13746,7 @@ struct ggml_compute_state_shared {
 
     // synchronization primitives
     atomic_int  n_ready;
+    atomic_int  aic;
     atomic_bool has_work;
    atomic_bool stop; // stop all threads
 };
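Annotation (not part of the diff): the counter lives once, in the shared state next to the existing synchronization primitives, and each `ggml_compute_params` carries only a pointer to it. The remaining hunks below are exactly that plumbing: every place `ggml_graph_compute` constructs a `ggml_compute_params`, whether for the main thread or for the workers, now wires in `&state_shared.aic`.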
@@ -13802,6 +13815,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         /*.spin      =*/ GGML_LOCK_INITIALIZER,
         /*.n_threads =*/ n_threads,
         /*.n_ready   =*/ 0,
+        /*.aic       =*/ 0,
         /*.has_work  =*/ false,
         /*.stop      =*/ false,
     };
@@ -13822,6 +13836,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                 .nth   = n_threads,
                 .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
                 .wdata = cgraph->work ? cgraph->work->data : NULL,
+                .aic   = &state_shared.aic,
             },
             .node   = NULL,
             .shared = &state_shared,
@@ -14111,6 +14126,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
             /*.nth   =*/ node->n_tasks,
             /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
             /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
+            /*.aic   =*/ &state_shared.aic,
         };
 
         ggml_compute_forward(&params, node);
@@ -14134,6 +14150,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     .nth   = node->n_tasks,
                     .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
                     .wdata = cgraph->work ? cgraph->work->data : NULL,
+                    .aic   = &state_shared.aic,
                 };
                 workers[j].node = node;
             }
@@ -14149,6 +14166,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
         }
 
         params.type = GGML_TASK_COMPUTE;
+        params.aic  = &state_shared.aic;
         ggml_compute_forward(&params, node);
 
         // wait for thread pool
@@ -14189,6 +14207,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                     .nth   = node->n_tasks,
                     .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
                     .wdata = cgraph->work ? cgraph->work->data : NULL,
+                    .aic   = &state_shared.aic,
                 };
                 workers[j].node = node;
             }