Skip to content

Commit e643fa1

Browse files
committed
smaller default values for baby llama model parameters
1 parent ee565f3 commit e643fa1

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

examples/baby-llama/baby-llama.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ void get_example_targets(int example_id, struct ggml_tensor * tokens_input, stru
10621062
ggml_set_f32(targets, -1.0f);
10631063
ggml_set_i32_1d(tokens_input, 0, 0);
10641064
for (int i=1; i<n_tokens+1; ++i) {
1065-
float x = example_id + i * 3.14159f * 2.0f * 1.0f / n_tokens;
1065+
float x = example_id + i * 3.14159f * 2.0f * 1.0f * 0.5f / n_tokens;
10661066
float y = sinf(x);//*cosf(x*1.1f+1.0f);
10671067
float z = (y+1.0f)*0.5f; // scale to [0..1]
10681068
z += (frand()-0.5f)*(randomness/n_vocab);
@@ -1113,12 +1113,12 @@ int main(int argc, char ** argv) {
11131113

11141114
struct llama_model model;
11151115
model.hparams.n_vocab = 8;
1116-
model.hparams.n_ctx = 32;
1116+
model.hparams.n_ctx = 8;
11171117
model.hparams.n_embd = 32;
11181118
model.hparams.n_mult = 2;
11191119
model.hparams.n_head = 8;
1120-
model.hparams.n_layer = 8;
1121-
model.hparams.n_rot = model.hparams.n_embd / model.hparams.n_head;
1120+
model.hparams.n_layer = 1;
1121+
model.hparams.n_rot = MIN(16, model.hparams.n_embd / model.hparams.n_head);
11221122

11231123
// model.hparams.n_embd = 32;
11241124
// model.hparams.n_mult = 2;
@@ -1177,7 +1177,7 @@ int main(int argc, char ** argv) {
11771177
size_t compute_size = 1024ll*1024ll*1024ll;
11781178
uint8_t * compute_addr = new uint8_t[compute_size];
11791179

1180-
int n_examples = 128;
1180+
int n_examples = 256;
11811181
int n_tokens = model.hparams.n_ctx;
11821182
int n_vocab = model.hparams.n_vocab;
11831183

@@ -1285,7 +1285,7 @@ int main(int argc, char ** argv) {
12851285

12861286
{
12871287
int n_gen = 128;
1288-
int sample_ctx = n_tokens/2-n_tokens/16;
1288+
int sample_ctx = n_tokens-n_tokens/8;
12891289

12901290
printf("Generating %d tokens.\n", n_gen);
12911291

0 commit comments

Comments
 (0)