Skip to content

Commit 3efdd6d

Browse files
authored
gguf-split : update (LostRuins#444)
gguf-split : improve --split and --merge logic (ggml-org#9619)

* make sure params --split and --merge are not specified at same time
* update gguf-split params parse logic
* Update examples/gguf-split/gguf-split.cpp

Co-authored-by: Xuan Son Nguyen <[email protected]>
Co-authored-by: slaren <[email protected]>

---------

gguf-split : add basic checks (ggml-org#9499)

* gguf-split : do not overwrite existing files when merging
* gguf-split : error when too many arguments are passed

Authored-by: slaren <[email protected]>
1 parent ec45632 commit 3efdd6d

File tree

1 file changed

+58
-39
lines changed

1 file changed

+58
-39
lines changed

examples/gguf-split/gguf-split.cpp

Lines changed: 58 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -22,12 +22,20 @@
2222
#endif
2323

2424
enum split_operation : uint8_t {
25-
SPLIT_OP_SPLIT,
26-
SPLIT_OP_MERGE,
25+
OP_NONE,
26+
OP_SPLIT,
27+
OP_MERGE,
28+
};
29+
30+
enum split_mode : uint8_t {
31+
MODE_NONE,
32+
MODE_TENSOR,
33+
MODE_SIZE,
2734
};
2835

2936
struct split_params {
30-
split_operation operation = SPLIT_OP_SPLIT;
37+
split_operation operation = OP_NONE;
38+
split_mode mode = MODE_NONE;
3139
size_t n_bytes_split = 0;
3240
int n_split_tensors = 128;
3341
std::string input;
@@ -87,59 +95,52 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
8795
}
8896

8997
bool arg_found = false;
90-
bool is_op_set = false;
91-
bool is_mode_set = false;
9298
if (arg == "-h" || arg == "--help") {
9399
split_print_usage(argv[0]);
94100
exit(0);
95-
}
96-
if (arg == "--version") {
101+
} else if (arg == "--version") {
97102
fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
98103
fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
99104
exit(0);
100-
}
101-
if (arg == "--dry-run") {
105+
} else if (arg == "--dry-run") {
102106
arg_found = true;
103107
params.dry_run = true;
104-
}
105-
if (arg == "--no-tensor-first-split") {
108+
} else if (arg == "--no-tensor-first-split") {
106109
arg_found = true;
107110
params.no_tensor_first_split = true;
108-
}
109-
110-
if (is_op_set) {
111-
throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
112-
}
113-
if (arg == "--merge") {
111+
} else if (arg == "--merge") {
114112
arg_found = true;
115-
is_op_set = true;
116-
params.operation = SPLIT_OP_MERGE;
117-
}
118-
if (arg == "--split") {
113+
if (params.operation != OP_NONE && params.operation != OP_MERGE) {
114+
throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
115+
}
116+
params.operation = OP_MERGE;
117+
} else if (arg == "--split") {
119118
arg_found = true;
120-
is_op_set = true;
121-
params.operation = SPLIT_OP_SPLIT;
122-
}
123-
124-
if (is_mode_set) {
125-
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
126-
}
127-
if (arg == "--split-max-tensors") {
119+
if (params.operation != OP_NONE && params.operation != OP_SPLIT) {
120+
throw std::invalid_argument("error: either --split or --merge can be specified, but not both");
121+
}
122+
params.operation = OP_SPLIT;
123+
} else if (arg == "--split-max-tensors") {
128124
if (++arg_idx >= argc) {
129125
invalid_param = true;
130126
break;
131127
}
132128
arg_found = true;
133-
is_mode_set = true;
129+
if (params.mode != MODE_NONE && params.mode != MODE_TENSOR) {
130+
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
131+
}
132+
params.mode = MODE_TENSOR;
134133
params.n_split_tensors = atoi(argv[arg_idx]);
135-
}
136-
if (arg == "--split-max-size") {
134+
} else if (arg == "--split-max-size") {
137135
if (++arg_idx >= argc) {
138136
invalid_param = true;
139137
break;
140138
}
141139
arg_found = true;
142-
is_mode_set = true;
140+
if (params.mode != MODE_NONE && params.mode != MODE_SIZE) {
141+
throw std::invalid_argument("error: either --split-max-tensors or --split-max-size can be specified, but not both");
142+
}
143+
params.mode = MODE_SIZE;
143144
params.n_bytes_split = split_str_to_n_bytes(argv[arg_idx]);
144145
}
145146

@@ -148,11 +149,20 @@ static void split_params_parse_ex(int argc, const char ** argv, split_params & p
148149
}
149150
}
150151

152+
// the operation is split if not specified
153+
if (params.operation == OP_NONE) {
154+
params.operation = OP_SPLIT;
155+
}
156+
// the split mode is by tensor if not specified
157+
if (params.mode == MODE_NONE) {
158+
params.mode = MODE_TENSOR;
159+
}
160+
151161
if (invalid_param) {
152162
throw std::invalid_argument("error: invalid parameter for argument: " + arg);
153163
}
154164

155-
if (argc - arg_idx < 2) {
165+
if (argc - arg_idx != 2) {
156166
throw std::invalid_argument("error: bad arguments");
157167
}
158168

@@ -265,13 +275,15 @@ struct split_strategy {
265275
}
266276

267277
bool should_split(int i_tensor, size_t next_size) {
268-
if (params.n_bytes_split > 0) {
278+
if (params.mode == MODE_SIZE) {
269279
// split by max size per file
270280
return next_size > params.n_bytes_split;
271-
} else {
281+
} else if (params.mode == MODE_TENSOR) {
272282
// split by number of tensors per file
273283
return i_tensor > 0 && i_tensor < n_tensors && i_tensor % params.n_split_tensors == 0;
274284
}
285+
// should never happen
286+
GGML_ABORT("invalid mode");
275287
}
276288

277289
void print_info() {
@@ -389,10 +401,17 @@ static void gguf_merge(const split_params & split_params) {
389401
int n_split = 1;
390402
int total_tensors = 0;
391403

392-
auto * ctx_out = gguf_init_empty();
404+
// avoid overwriting existing output file
405+
if (std::ifstream(split_params.output.c_str())) {
406+
fprintf(stderr, "%s: output file %s already exists\n", __func__, split_params.output.c_str());
407+
exit(EXIT_FAILURE);
408+
}
409+
393410
std::ofstream fout(split_params.output.c_str(), std::ios::binary);
394411
fout.exceptions(std::ofstream::failbit); // fail fast on write errors
395412

413+
auto * ctx_out = gguf_init_empty();
414+
396415
std::vector<uint8_t> read_data;
397416
std::vector<ggml_context *> ctx_metas;
398417
std::vector<gguf_context *> ctx_ggufs;
@@ -552,9 +571,9 @@ int main(int argc, const char ** argv) {
552571
split_params_parse(argc, argv, params);
553572

554573
switch (params.operation) {
555-
case SPLIT_OP_SPLIT: gguf_split(params);
574+
case OP_SPLIT: gguf_split(params);
556575
break;
557-
case SPLIT_OP_MERGE: gguf_merge(params);
576+
case OP_MERGE: gguf_merge(params);
558577
break;
559578
default: split_print_usage(argv[0]);
560579
exit(EXIT_FAILURE);

0 commit comments

Comments
 (0)