Skip to content

Commit e64f5b5

Browse files
authored
examples : make n_ctx warning work again (#3066)
This was broken by commit e36ecdc ("build : on Mac OS enable Metal by default (#2901)").
1 parent 94f10b9 commit e64f5b5

File tree

5 files changed

+33
-19
lines changed

5 files changed

+33
-19
lines changed

examples/embedding/embedding.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@ int main(int argc, char ** argv) {
1717

1818
params.embedding = true;
1919

20-
if (params.n_ctx > 2048) {
21-
fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
22-
"expect poor results\n", __func__, params.n_ctx);
23-
}
24-
2520
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
2621

2722
if (params.seed == LLAMA_DEFAULT_SEED) {
@@ -47,6 +42,12 @@ int main(int argc, char ** argv) {
4742
return 1;
4843
}
4944

45+
const int n_ctx_train = llama_n_ctx_train(ctx);
46+
if (params.n_ctx > n_ctx_train) {
47+
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
48+
__func__, n_ctx_train, params.n_ctx);
49+
}
50+
5051
// print system information
5152
{
5253
fprintf(stderr, "\n");

examples/main/main.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,10 @@ int main(int argc, char ** argv) {
182182
return 1;
183183
}
184184

185-
if (params.n_ctx > llama_n_ctx(ctx)) {
186-
LOG_TEE("%s: warning: base model only supports context sizes no greater than %d tokens (%d specified)\n", __func__, llama_n_ctx(ctx), params.n_ctx);
185+
const int n_ctx_train = llama_n_ctx_train(ctx);
186+
if (params.n_ctx > n_ctx_train) {
187+
LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
188+
__func__, n_ctx_train, params.n_ctx);
187189
} else if (params.n_ctx < 8) {
188190
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
189191
params.n_ctx = 8;

examples/perplexity/perplexity.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -693,9 +693,10 @@ int main(int argc, char ** argv) {
693693
return 1;
694694
}
695695

696-
if (params.n_ctx > llama_n_ctx(ctx)) {
697-
fprintf(stderr, "%s: warning: model might not support context sizes greater than %d tokens (%d specified);"
698-
"expect poor results\n", __func__, llama_n_ctx(ctx), params.n_ctx);
696+
const int n_ctx_train = llama_n_ctx_train(ctx);
697+
if (params.n_ctx > n_ctx_train) {
698+
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
699+
__func__, n_ctx_train, params.n_ctx);
699700
}
700701

701702
// print system information

llama.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5633,15 +5633,19 @@ void llama_free(struct llama_context * ctx) {
56335633
}
56345634

56355635
int llama_n_vocab(const struct llama_context * ctx) {
5636-
return ctx->model.vocab.id_to_token.size();
5636+
return llama_model_n_vocab(&ctx->model);
56375637
}
56385638

56395639
int llama_n_ctx(const struct llama_context * ctx) {
5640-
return ctx->model.hparams.n_ctx;
5640+
return llama_model_n_ctx(&ctx->model);
5641+
}
5642+
5643+
int llama_n_ctx_train(const struct llama_context * ctx) {
5644+
return llama_model_n_ctx_train(&ctx->model);
56415645
}
56425646

56435647
int llama_n_embd(const struct llama_context * ctx) {
5644-
return ctx->model.hparams.n_embd;
5648+
return llama_model_n_embd(&ctx->model);
56455649
}
56465650

56475651
enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) {
@@ -5656,6 +5660,10 @@ int llama_model_n_ctx(const struct llama_model * model) {
56565660
return model->hparams.n_ctx;
56575661
}
56585662

5663+
int llama_model_n_ctx_train(const struct llama_model * model) {
5664+
return model->hparams.n_ctx_train;
5665+
}
5666+
56595667
int llama_model_n_embd(const struct llama_model * model) {
56605668
return model->hparams.n_embd;
56615669
}

llama.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,15 +245,17 @@ extern "C" {
245245
LLAMA_API bool llama_mmap_supported (void);
246246
LLAMA_API bool llama_mlock_supported(void);
247247

248-
LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
249-
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
250-
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
248+
LLAMA_API int llama_n_vocab (const struct llama_context * ctx);
249+
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
250+
LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
251+
LLAMA_API int llama_n_embd (const struct llama_context * ctx);
251252

252253
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);
253254

254-
LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
255-
LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
256-
LLAMA_API int llama_model_n_embd (const struct llama_model * model);
255+
LLAMA_API int llama_model_n_vocab (const struct llama_model * model);
256+
LLAMA_API int llama_model_n_ctx (const struct llama_model * model);
257+
LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
258+
LLAMA_API int llama_model_n_embd (const struct llama_model * model);
257259

258260
// Get a string describing the model type
259261
LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);

0 commit comments

Comments
 (0)