Commit dc0d0eb

Implement customizable RoPE

The original RoPE has the pre-defined parameters

    theta_i = 10000^(-2(i-1)/d), for i in [1, 2, ..., d/2]

Our customizable RoPE, ggml_rope_custom_inplace, uses

    theta_i = scale * base^(-2(i-1)/d), for i in [1, 2, ..., d/2]

where the defaults match the original: scale = 1.0, base = 10000.

The new command line arguments --rope-freq-base and --rope-freq-scale set the
two new RoPE parameters. Recent research shows that changing these two
parameters extends the context limit with minimal loss:

1. Extending Context to 8K
   kaiokendev
   https://kaiokendev.github.io/til#extending-context-to-8k

2. Extending Context Window of Large Language Models via Positional Interpolation
   Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian
   https://arxiv.org/abs/2306.15595

3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context
   size without any fine-tuning and minimal perplexity degradation.
   https://www.reddit.com/user/bloc97
   https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

For the bold, try adding the following command line parameters to your
favorite model:

    -c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5
1 parent dfd9fce commit dc0d0eb
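
To see what the two parameters do numerically, here is a minimal standalone
sketch (not part of the commit) that mirrors the per-pair angle computation in
the patched kernels: theta starts at freq_scale * p and is multiplied by
theta_scale = freq_base^(-2/d) for each successive dimension pair.

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const int   n_dims     = 128;      // rotated dimensions per head (d)
        const int   p          = 100;      // token position
        const float freq_base  = 10000.0f; // --rope-freq-base
        const float freq_scale = 1.0f;     // --rope-freq-scale

        // same schedule as the patched ggml.c / ggml-metal.metal kernels
        const float theta_scale = powf(freq_base, -2.0f/n_dims);

        float theta = freq_scale * (float)p;
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            printf("pair %3d: theta = %g\n", i0/2, theta);
            theta *= theta_scale; // theta_i = freq_scale * p * freq_base^(-2(i-1)/d)
        }
        return 0;
    }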

File tree

10 files changed: +131 −28 lines

examples/common.cpp

Lines changed: 16 additions & 0 deletions

@@ -168,6 +168,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--rope-freq-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_base = std::stof(argv[i]);
+        } else if (arg == "--rope-freq-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_scale = std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
         } else if (arg == "--top-p") {
@@ -469,6 +481,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
     fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -549,6 +563,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(
     lparams.use_mlock  = params.use_mlock;
     lparams.logits_all = params.perplexity;
     lparams.embedding  = params.embedding;
+    lparams.rope_freq_base  = params.rope_freq_base;
+    lparams.rope_freq_scale = params.rope_freq_scale;

     llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
     if (model == NULL) {

examples/common.h

Lines changed: 2 additions & 0 deletions

@@ -32,6 +32,8 @@ struct gpt_params {
     int32_t main_gpu                        = 0;   // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
     int32_t n_probs                         = 0;   // if greater than 0, output the probabilities of top n_probs tokens.
+    float   rope_freq_base                  = 10000.0f; // RoPE base frequency
+    float   rope_freq_scale                 = 1.0f;     // RoPE frequency scaling factor

     // sampling parameters
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

examples/main/main.cpp

Lines changed: 10 additions & 2 deletions

@@ -84,9 +84,17 @@ int main(int argc, char ** argv) {
         return 0;
     }

+    if (params.rope_freq_base != 10000.0) {
+        fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+    }
+
+    if (params.rope_freq_scale != 1.0) {
+        fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+    }
+
     if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
+        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
+                " you are on your own\n", __func__, params.n_ctx);
     } else if (params.n_ctx < 8) {
         fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;

examples/server/server.cpp

Lines changed: 18 additions & 0 deletions

@@ -608,6 +608,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stderr, "  -v, --verbose         verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stderr, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stderr, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -722,6 +724,22 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
        }
        params.n_ctx = std::stoi(argv[i]);
    }
+    else if (arg == "--rope-freq-base")
+    {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        params.rope_freq_base = std::stof(argv[i]);
+    }
+    else if (arg == "--rope-freq-scale")
+    {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        params.rope_freq_scale = std::stof(argv[i]);
+    }
    else if (arg == "--memory-f32" || arg == "--memory_f32")
    {
        params.memory_f16 = false;

ggml-metal.m

Lines changed: 6 additions & 0 deletions

@@ -874,6 +874,10 @@ void ggml_metal_graph_compute(
            const int n_past = ((int32_t *)(src1->data))[0];

+           float freq_base, freq_scale;
+           memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+           memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+
            [encoder setComputePipelineState:ctx->pipeline_rope];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -896,6 +900,8 @@
            [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
            [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
            [encoder setBytes:&mode   length:sizeof( int) atIndex:20];
+           [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
+           [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];

            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
        } break;

ggml-metal.metal

Lines changed: 4 additions & 2 deletions

@@ -615,17 +615,19 @@ kernel void kernel_rope(
        constant       int & n_past,
        constant       int & n_dims,
        constant       int & mode,
+       constant     float & freq_base,
+       constant     float & freq_scale,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig[2];
    const int64_t i2 = tpig[1];
    const int64_t i1 = tpig[0];

    const bool is_neox = mode & 2;
-   const float theta_scale = pow(10000.0, -2.0f/n_dims);
+   const float theta_scale = pow(freq_base, -2.0f/n_dims);

    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

-   float theta = (float)p;
+   float theta = freq_scale * (float)p;

    if (!is_neox) {
        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {

ggml.c

Lines changed: 31 additions & 9 deletions

@@ -6943,6 +6943,8 @@ struct ggml_tensor * ggml_rope_impl(
        int                   n_past,
        int                   n_dims,
        int                   mode,
+       float                 freq_base,
+       float                 freq_scale,
        int                   n_ctx,
        bool                  inplace) {
    GGML_ASSERT(n_past >= 0);
@@ -6956,12 +6958,14 @@

    ggml_scratch_save(ctx);

-   struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+   struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);

    ((int32_t *) b->data)[0] = n_past;
    ((int32_t *) b->data)[1] = n_dims;
    ((int32_t *) b->data)[2] = mode;
    ((int32_t *) b->data)[3] = n_ctx;
+   memcpy((int32_t *) b->data + 4, &freq_base,  sizeof(float));
+   memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));

    ggml_scratch_load(ctx);

@@ -6980,7 +6984,7 @@ struct ggml_tensor * ggml_rope(
        int                   n_dims,
        int                   mode,
        int                   n_ctx) {
-   return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
+   return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
}

struct ggml_tensor * ggml_rope_inplace(
@@ -6990,7 +6994,19 @@
        int                   n_dims,
        int                   mode,
        int                   n_ctx) {
-   return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
+   return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+       struct ggml_context * ctx,
+       struct ggml_tensor  * a,
+       int                   n_past,
+       int                   n_dims,
+       int                   mode,
+       float                 freq_base,
+       float                 freq_scale,
+       int                   n_ctx) {
+   return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
}

// ggml_rope_back
@@ -11948,7 +11964,7 @@
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst) {
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-   GGML_ASSERT(ggml_nelements(src1) == 4);
+   GGML_ASSERT(ggml_nelements(src1) == 6);

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
@@ -11958,6 +11974,9 @@
    const int n_dims = ((int32_t *) src1->data)[1];
    const int mode   = ((int32_t *) src1->data)[2];
    const int n_ctx  = ((int32_t *) src1->data)[3];
+   float freq_base, freq_scale;
+   memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+   memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

    assert(n_past >= 0);

@@ -11986,7 +12005,7 @@
    // row index used to determine which thread to use
    int ir = 0;

-   const float theta_scale = powf(10000.0, -2.0f/n_dims);
+   const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const bool is_neox = mode & 2;
    const bool is_glm  = mode & 4;
@@ -11998,7 +12017,7 @@
            if (ir++ < ir0) continue;
            if (ir   > ir1) break;

-           float theta = (float)p;
+           float theta = freq_scale * (float)p;

            if (is_glm) {
                theta = MIN(p, n_ctx - 2);
@@ -12075,7 +12094,7 @@
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst) {
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-   GGML_ASSERT(ggml_nelements(src1) == 4);
+   GGML_ASSERT(ggml_nelements(src1) == 6);

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
@@ -12085,6 +12104,9 @@
    const int n_dims = ((int32_t *) src1->data)[1];
    const int mode   = ((int32_t *) src1->data)[2];
    const int n_ctx  = ((int32_t *) src1->data)[3];
+   float freq_base, freq_scale;
+   memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+   memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

    assert(n_past >= 0);

@@ -12113,7 +12135,7 @@
    // row index used to determine which thread to use
    int ir = 0;

-   const float theta_scale = powf(10000.0, -2.0f/n_dims);
+   const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const bool is_neox = mode & 2;
    const bool is_glm  = mode & 4;
@@ -12125,7 +12147,7 @@
            if (ir++ < ir0) continue;
            if (ir   > ir1) break;

-           float theta = (float)p;
+           float theta = freq_scale * (float)p;

            if (is_glm) {
                theta = MIN(p, n_ctx - 2);
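
The float parameters travel in the same GGML_TYPE_I32 tensor as the integer
arguments, bit-copied with memcpy rather than cast. Isolated as a sketch, with
hypothetical helper names and assuming sizeof(float) == sizeof(int32_t) (true
on the platforms ggml targets), the convention is:

    #include <stdint.h>
    #include <string.h>

    // Pack the six RoPE arguments into the I32 parameter tensor's data buffer.
    static void pack_rope_params(int32_t * b, int n_past, int n_dims, int mode,
                                 int n_ctx, float freq_base, float freq_scale) {
        b[0] = n_past;
        b[1] = n_dims;
        b[2] = mode;
        b[3] = n_ctx;
        // memcpy stores the raw float bit patterns in slots 4 and 5,
        // avoiding the undefined behavior of type-punning through a cast
        memcpy(b + 4, &freq_base,  sizeof(float));
        memcpy(b + 5, &freq_scale, sizeof(float));
    }

    // The compute kernels recover the floats the same way.
    static void unpack_rope_freqs(const int32_t * b, float * freq_base, float * freq_scale) {
        memcpy(freq_base,  b + 4, sizeof(float));
        memcpy(freq_scale, b + 5, sizeof(float));
    }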

ggml.h

Lines changed: 11 additions & 0 deletions

@@ -1107,6 +1107,17 @@ extern "C" {
            int                   mode,
            int                   n_ctx);

+   // custom RoPE, in-place, returns view(a)
+   GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * a,
+           int                   n_past,
+           int                   n_dims,
+           int                   mode,
+           float                 freq_base,
+           float                 freq_scale,
+           int                   n_ctx);
+
    // rotary position embedding backward, i.e compute dx from dy
    // a - dy
    GGML_API struct ggml_tensor * ggml_rope_back(
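
A hypothetical call site, sketched for illustration (only
ggml_rope_custom_inplace itself comes from this commit; the helper name and
its arguments are assumptions):

    #include "ggml.h"

    // Sketch: apply the customizable RoPE to a tensor during graph construction.
    // Passing freq_base = 10000.0f and freq_scale = 1.0f reproduces ggml_rope_inplace.
    static struct ggml_tensor * apply_custom_rope(
            struct ggml_context * ctx0,
            struct ggml_tensor  * cur,  // e.g. the per-head query or key tensor
            int n_past,                 // tokens already in the KV cache
            int n_rot,                  // rotated dimensions per head
            int n_ctx,
            float freq_base,            // e.g. 80000.0f for --rope-freq-base 80000
            float freq_scale) {         // e.g. 0.5f  for --rope-freq-scale 0.5
        // mode 0 selects the plain (LLaMA-style) rotation
        return ggml_rope_custom_inplace(ctx0, cur, n_past, n_rot, 0,
                                        freq_base, freq_scale, n_ctx);
    }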
