Skip to content

Commit d23a4a6

Browse files
committed
mulmat-tune-tool: allow explictly set m_step
1 parent 3c39fdf commit d23a4a6

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

examples/mulmat-tune/mulmat-tune-tool.c

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ static void usage(char *prog) {
3535
" default 7B\n",
3636
"--type TYPE Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 | F32 | F16\n",
3737
" default Q4_0\n",
38-
"--m_num M_NUM number of M, max M = 16 * M_NUM\n",
39-
" requires: M_NUM in range [8, 16]\n",
38+
"--m_step M_STEP the initial value of M and delta between adjacent M\n",
39+
" requires: multiple of 2, and in range[2, 32]\n",
40+
" default 16\n",
41+
"--m_num M_NUM number of M, the max M = M_STEP * M_NUM\n",
42+
" requires: in range [2, 32]\n",
4043
" default 8\n",
4144
"--file FILE data file to write\n",
4245
" default stdout\n",
@@ -91,6 +94,7 @@ int main(int argc, char **argv) {
9194

9295
const char *arg_model = NULL;
9396
const char *arg_type = NULL;
97+
const char *arg_m_step = NULL;
9498
const char *arg_m_num = NULL;
9599
const char *arg_file = NULL;
96100
bool always_yes = false;
@@ -106,6 +110,11 @@ int main(int argc, char **argv) {
106110
arg_type = argv[i + 1];
107111
++i;
108112
}
113+
} else if (strcmp(argv[i], "--m_step") == 0) {
114+
if (i + 1 < argc) {
115+
arg_m_step = argv[i + 1];
116+
++i;
117+
}
109118
} else if (strcmp(argv[i], "--m_num") == 0) {
110119
if (i + 1 < argc) {
111120
arg_m_num = argv[i + 1];
@@ -194,12 +203,25 @@ int main(int argc, char **argv) {
194203
tune.n_profiles =
195204
ggml_mulmat_get_task_profiles(&tune.profiles, type, GGML_TYPE_F32);
196205

206+
if (arg_m_step != NULL) {
207+
int v = atoi(arg_m_step);
208+
tune.m_step = v;
209+
}
210+
if (tune.m_step <= 0 || tune.m_step > 32 || tune.m_step % 2 != 0) {
211+
fprintf(stderr,
212+
"invalid m_step: %d, expect multiple of 2 and in range "
213+
"[2, 32]\n",
214+
tune.m_step);
215+
usage(argv[0]);
216+
exit(1);
217+
}
218+
197219
if (arg_m_num != NULL) {
198220
int v = atoi(arg_m_num);
199221
tune.m_num = v;
200222
}
201-
if (tune.m_num < 8 || tune.m_num > 16) {
202-
fprintf(stderr, "invalid m_num: %d, expect in range [8, 16]\n",
223+
if (tune.m_num < 2 || tune.m_num > 32) {
224+
fprintf(stderr, "invalid m_num: %d, expect in range [2, 32]\n",
203225
tune.m_num);
204226
usage(argv[0]);
205227
exit(1);

0 commit comments

Comments
 (0)