Skip to content

Commit 04c91d2

Browse files
committed
use ggml_format_name
1 parent 54f77e2 commit 04c91d2

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

examples/control-vector-generator/control-vector-generator.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ struct callback_data {
127127
// diff_filtered: [n_embd, n_nonzero_rows]
128128
struct ggml_tensor * diff_filtered = ggml_new_tensor_2d(
129129
ctx_ggml, GGML_TYPE_F32, n_embd, n_nonzero_rows);
130-
ggml_set_name(diff_filtered, (std::string("diff_filtered_") + a->name).c_str());
130+
ggml_format_name(diff_filtered, "diff_filtered_%s", a->name);
131131
diff_filtered->data = malloc(ggml_nbytes(diff_filtered));
132132

133133
// copy non-zero rows
@@ -245,7 +245,7 @@ struct train_context {
245245

246246
struct ctrl_params {
247247
/* default meta parameters */
248-
int n_completions = INT_MAX;
248+
int n_completions = 64;
249249
int n_pca_batch = 20;
250250
int n_pca_iterations = 1000;
251251

@@ -311,7 +311,7 @@ static void print_usage(const char * executable) {
311311
printf(" -cf, --completions-file completions file\n");
312312
printf(" default: %s\n", defaults.completions_file.c_str());
313313
printf(" -nc, --num-completions N number of lines of completions file to use\n");
314-
printf(" default: use all lines\n");
314+
printf(" default: %d\n", defaults.n_completions);
315315
printf(" --batch-pca N batch size used for PCA. Larger batch runs faster, but uses more memory\n");
316316
printf(" default: %d\n", defaults.n_pca_batch);
317317
printf(" --iter-pca N number of iterations used for PCA\n");
@@ -550,6 +550,11 @@ int main(int argc, char ** argv) {
550550
return 1;
551551
}
552552

553+
if (cparams.n_pca_iterations % cparams.n_pca_batch != 0) {
554+
fprintf(stderr, "PCA iterations must by multiply of PCA batch size\n");
555+
return 1;
556+
}
557+
553558
// load and prepare entries for training
554559
prepare_entries(cparams);
555560

examples/control-vector-generator/pca.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,14 @@ static struct ggml_cgraph * build_graph_piter(
181181
b_tensor,
182182
ggml_sqrt_inplace(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, b_tensor)))
183183
);
184-
ggml_set_name(b_tensor, ("b_tensor_norm_" + std::to_string(i)).c_str());
184+
ggml_format_name(b_tensor, "b_tensor_norm_%d", i);
185185

186186
// calculate distance(new eigenvector - old eigenvector)
187+
// we don't use ggml_sub because it may not be implemented on GPU backend
187188
struct ggml_tensor * new_sub_old = ggml_add(ctx0, old_eigen, ggml_scale(ctx0, b_tensor, -1));
188189
distance = ggml_sqrt_inplace(ctx0,
189190
ggml_sum_rows(ctx0, ggml_sqr_inplace(ctx0, new_sub_old)));
190-
ggml_set_name(distance, ("distance_" + std::to_string(i)).c_str());
191+
ggml_format_name(distance, "distance_%d", i);
191192

192193
old_eigen = b_tensor;
193194

@@ -317,22 +318,20 @@ static void run_pca(
317318
struct pca_params & params,
318319
const std::vector<struct ggml_tensor *> & v_input, // shape of v_input[0]: [n_samples, n_embd]
319320
const std::vector<struct ggml_tensor *> & v_output) {
320-
printf("Running PCA...\n");
321+
printf("%s: Running PCA...\n", __func__);
321322
for (size_t il = 0; il < v_input.size(); ++il) {
322323

323324
// prepare output vector
324325
struct ggml_tensor * ctrl_out = v_output[il];
325-
auto name = std::string("direction.") + std::to_string(il + 1);
326-
ggml_set_name(ctrl_out, name.c_str());
326+
ggml_format_name(ctrl_out, "direction.%ld", il+1);
327327

328328
// run power_iteration
329329
params.i_layer = il;
330330
params.n_layers = v_input.size();
331331
power_iteration(params, v_input[il], ctrl_out);
332-
printf("DONE layer %ld / %ld\n", il+1, v_input.size());
332+
printf("%s: Done layer %ld / %ld\n", __func__, il+1, v_input.size());
333333
//print_debug_tensor(ctrl_out);
334334
}
335-
printf("Done with PCA.\n");
336335
}
337336

338337
}

0 commit comments

Comments
 (0)