@@ -35,8 +35,11 @@ static void usage(char *prog) {
35
35
" default 7B\n" ,
36
36
"--type TYPE Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 | F32 | F16\n" ,
37
37
" default Q4_0\n" ,
38
- "--m_num M_NUM number of M, max M = 16 * M_NUM\n" ,
39
- " requires: M_NUM in range [8, 16]\n" ,
38
+ "--m_step M_STEP the initial value of M and delta between adjacent M\n" ,
39
+ " requires: multiple of 2, and in range[2, 32]\n" ,
40
+ " default 16\n" ,
41
+ "--m_num M_NUM number of M, the max M = M_STEP * M_NUM\n" ,
42
+ " requires: in range [2, 32]\n" ,
40
43
" default 8\n" ,
41
44
"--file FILE data file to write\n" ,
42
45
" default stdout\n" ,
@@ -91,6 +94,7 @@ int main(int argc, char **argv) {
91
94
92
95
const char * arg_model = NULL ;
93
96
const char * arg_type = NULL ;
97
+ const char * arg_m_step = NULL ;
94
98
const char * arg_m_num = NULL ;
95
99
const char * arg_file = NULL ;
96
100
bool always_yes = false;
@@ -106,6 +110,11 @@ int main(int argc, char **argv) {
106
110
arg_type = argv [i + 1 ];
107
111
++ i ;
108
112
}
113
+ } else if (strcmp (argv [i ], "--m_step" ) == 0 ) {
114
+ if (i + 1 < argc ) {
115
+ arg_m_step = argv [i + 1 ];
116
+ ++ i ;
117
+ }
109
118
} else if (strcmp (argv [i ], "--m_num" ) == 0 ) {
110
119
if (i + 1 < argc ) {
111
120
arg_m_num = argv [i + 1 ];
@@ -194,12 +203,25 @@ int main(int argc, char **argv) {
194
203
tune .n_profiles =
195
204
ggml_mulmat_get_task_profiles (& tune .profiles , type , GGML_TYPE_F32 );
196
205
206
+ if (arg_m_step != NULL ) {
207
+ int v = atoi (arg_m_step );
208
+ tune .m_step = v ;
209
+ }
210
+ if (tune .m_step <= 0 || tune .m_step > 32 || tune .m_step % 2 != 0 ) {
211
+ fprintf (stderr ,
212
+ "invalid m_step: %d, expect multiple of 2 and in range "
213
+ "[2, 32]\n" ,
214
+ tune .m_step );
215
+ usage (argv [0 ]);
216
+ exit (1 );
217
+ }
218
+
197
219
if (arg_m_num != NULL ) {
198
220
int v = atoi (arg_m_num );
199
221
tune .m_num = v ;
200
222
}
201
- if (tune .m_num < 8 || tune .m_num > 16 ) {
202
- fprintf (stderr , "invalid m_num: %d, expect in range [8, 16 ]\n" ,
223
+ if (tune .m_num < 2 || tune .m_num > 32 ) {
224
+ fprintf (stderr , "invalid m_num: %d, expect in range [2, 32 ]\n" ,
203
225
tune .m_num );
204
226
usage (argv [0 ]);
205
227
exit (1 );
0 commit comments