
hparams : add SWA rope parameters #12374

Merged · 1 commit · Mar 14, 2025
14 changes: 5 additions & 9 deletions src/llama-context.cpp
@@ -537,16 +537,12 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
-        float freq_base_l  = cparams.rope_freq_base;
-        float freq_scale_l = cparams.rope_freq_scale;
+        const bool is_swa = hparams.is_swa(il);
 
-        // TODO: improve
-        if (model.arch == LLM_ARCH_GEMMA3) {
-            const bool is_sliding = hparams.is_sliding(il);
-
-            freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
-            freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
-        }
+        // note: the swa rope params could become part of the cparams in the future
+        //       if we decide to make them configurable, like the non-sliding ones
+        const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
+        const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
 
         ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
 
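For readers skimming the diff, the selection logic above boils down to: sliding-window (SWA) layers rope with the training-time parameters now stored in hparams, while full-attention layers keep using the user-configurable cparams values. A minimal standalone sketch, with simplified names rather than the actual llama.cpp signatures:

```cpp
// Minimal sketch, not the llama.cpp API: choose rope parameters for one layer.
struct rope_params {
    float freq_base;
    float freq_scale;
};

static rope_params pick_rope_params(
        bool  is_swa,               // hparams.is_swa(il)
        float cparams_freq_base,    // user-configurable, applies to full-attention layers
        float cparams_freq_scale,
        float swa_freq_base_train,  // hparams.rope_freq_base_train_swa (e.g. 10000.0f for Gemma 3)
        float swa_freq_scale_train) {
    if (is_swa) {
        // sliding-window layers are not user-configurable (yet): use the trained values
        return { swa_freq_base_train, swa_freq_scale_train };
    }
    return { cparams_freq_base, cparams_freq_scale };
}
```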
4 changes: 2 additions & 2 deletions src/llama-graph.cpp
@@ -1403,9 +1403,9 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
     }
 
-    const bool is_sliding = hparams.is_sliding(il);
+    const bool is_swa = hparams.is_swa(il);
 
-    const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     const auto n_kv = kv_self->n;
 
2 changes: 1 addition & 1 deletion src/llama-hparams.cpp
@@ -70,7 +70,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
-bool llama_hparams::is_sliding(uint32_t il) const {
+bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
     }
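To make the layer pattern concrete, here is a small standalone program that evaluates the same condition as is_swa() for a hypothetical 12-layer model with n_swa_pattern = 6 (the Gemma 3 value set later in this PR); the layer count and window size are illustrative only:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical model: 12 layers, sliding window enabled, pattern of 6
    const uint32_t n_layer = 12, n_swa = 4096, n_swa_pattern = 6;

    for (uint32_t il = 0; il < n_layer; ++il) {
        // same condition as llama_hparams::is_swa() above
        const bool is_swa = n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
        std::printf("layer %2u: %s\n", il, is_swa ? "SWA" : "full attention");
    }
    // layers 0-4 and 6-10 use sliding-window attention;
    // every 6th layer (il % 6 == 5, i.e. layers 5 and 11) uses full attention
    return 0;
}
```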
4 changes: 3 additions & 1 deletion src/llama-hparams.h
@@ -79,7 +79,9 @@ struct llama_hparams {
 
     float    rope_attn_factor = 1.0f;
     float    rope_freq_base_train;
+    float    rope_freq_base_train_swa;
     float    rope_freq_scale_train;
+    float    rope_freq_scale_train_swa;
     uint32_t n_ctx_orig_yarn;
     float    rope_yarn_log_mul;
 
@@ -135,7 +137,7 @@ struct llama_hparams {
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
 
-    bool is_sliding(uint32_t il) const;
+    bool is_swa(uint32_t il) const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");
22 changes: 15 additions & 7 deletions src/llama-model.cpp
@@ -475,6 +475,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     }
     hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
+    // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers
+    hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+    hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+
     ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
 
     // non-transformer models do not have attention heads
@@ -877,6 +881,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 hparams.n_swa_pattern = 6;
 
+                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_scale_train_swa = 1.0f;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
@@ -1346,13 +1353,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
     const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
     auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
+        const bool is_swa = il < (int) hparams.n_layer && hparams.is_swa(il);
         if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
-            LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev));
+            LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa);
             return {cpu_dev, &pimpl->cpu_buft_list};
         }
         const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
         auto * dev = devices.at(layer_gpu);
-        LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev));
+        LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa);
         return {dev, &pimpl->gpu_buft_list.at(dev)};
     };
 
@@ -7381,10 +7389,10 @@ struct llm_build_gemma3 : public llm_graph_context {
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_sliding = hparams.is_sliding(il);
+            const bool is_swa = hparams.is_swa(il);
 
-            const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
-            const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
+            const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
+            const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
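Why the per-layer base matters: in standard RoPE each dimension pair i rotates by theta_i = pos * freq_scale / freq_base^(2i/d), so a small base such as 10000.0f gives the fast rotations suited to a short local window, while a much larger base stretches the rotations over long contexts. The snippet below only evaluates that formula for two bases; the 1e6 value is an illustrative long-context base, not something taken from this PR:

```cpp
#include <cmath>
#include <cstdio>

// Evaluate the standard RoPE rotation angle for dimension pair i at position pos:
//   theta_i = pos * freq_scale / freq_base^(2i / d)
static double rope_theta(double pos, double freq_base, double freq_scale, int i, int d) {
    return pos * freq_scale / std::pow(freq_base, 2.0 * i / d);
}

int main() {
    const int    d   = 128;     // head dimension (illustrative)
    const double pos = 1024.0;  // token position

    for (int i : {0, 16, 32, 63}) {
        std::printf("i = %2d  base 1e4: %10.4f rad   base 1e6: %10.4f rad\n",
                    i,
                    rope_theta(pos, 1.0e4, 1.0, i, d),
                    rope_theta(pos, 1.0e6, 1.0, i, d));
    }
    // the larger base shrinks the rotation angles of the higher dimensions,
    // which is what lets full-attention layers cover much longer contexts
    return 0;
}
```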
@@ -7973,7 +7981,7 @@ struct llm_build_cohere2 : public llm_graph_context {
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_sliding = hparams.is_sliding(il);
+            const bool is_swa = hparams.is_swa(il);
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);
@@ -8007,7 +8015,7 @@
                 cb(Vcur, "Vcur", il);
             }
 
-            if (is_sliding) {
+            if (is_swa) {
                 Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
                                      n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
                                      beta_fast, beta_slow);