
Commit f8f1bd4

llm_graph_input_attn_temp
1 parent ee06e9b commit f8f1bd4


4 files changed (+45, -7 lines)


src/llama-graph.cpp

Lines changed: 16 additions & 0 deletions

@@ -59,6 +59,22 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->pos && attn_scale) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        std::vector<float> attn_scale_data(n_tokens, 0.0f);
+        for (int i = 0; i < n_tokens; ++i) {
+            const float pos = ubatch->pos[i];
+            attn_scale_data[i] = std::log(
+                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+            ) * f_attn_temp_scale + 1.0;
+        }
+
+        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*n_pos_per_token*ggml_element_size(attn_scale));
+    }
+}
+
 void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
         const int64_t n_tokens = ubatch->n_tokens;
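For reference, here is a minimal standalone sketch (not part of this commit) that evaluates the same per-token scale that set_input() writes, using the llama4 defaults from llama-hparams.h below (n_attn_temp_floor_scale = 8192, f_attn_temp_scale = 0.1). Positions with pos + 1 < 8192 fall into window 0 and keep a scale of 1.0, so short contexts are unaffected; each later 8192-position window raises the query scale logarithmically.

// Standalone sketch, not part of the commit: the per-token scale written by
// llm_graph_input_attn_temp::set_input(), reproduced with the llama4 defaults.
#include <cmath>
#include <cstdint>
#include <cstdio>

static float attn_temp_scale(int32_t pos, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale) {
    // same expression as in set_input() above
    return std::log(std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0f) * f_attn_temp_scale + 1.0f;
}

int main() {
    const uint32_t n_attn_temp_floor_scale = 8192;
    const float    f_attn_temp_scale       = 0.1f;

    for (int32_t pos : {0, 8190, 8191, 16383, 65535}) {
        printf("pos %6d -> scale %.4f\n", pos, attn_temp_scale(pos, n_attn_temp_floor_scale, f_attn_temp_scale));
    }
    // pos      0 -> scale 1.0000   (window 0: log(1) = 0)
    // pos   8191 -> scale 1.0693   (window 1: log(2)*0.1 + 1)
    // pos  65535 -> scale 1.2197   (window 8: log(9)*0.1 + 1)
    return 0;
}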

src/llama-graph.h

Lines changed: 17 additions & 0 deletions

@@ -100,6 +100,23 @@ class llm_graph_input_pos : public llm_graph_input_i {
     const int64_t n_pos_per_token = 1;
 };
 
+// temperature tuning, used by llama4
+class llm_graph_input_attn_temp : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_temp(int64_t n_pos_per_token, uint32_t n_attn_temp_floor_scale, float f_attn_temp_scale)
+        : n_pos_per_token(n_pos_per_token), n_attn_temp_floor_scale(n_attn_temp_floor_scale), f_attn_temp_scale(f_attn_temp_scale) {}
+    virtual ~llm_graph_input_attn_temp() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * attn_scale = nullptr; // F32 [n_batch]
+
+    const int64_t n_pos_per_token = 1;
+
+    const uint32_t n_attn_temp_floor_scale;
+    const float    f_attn_temp_scale;
+};
+
 class llm_graph_input_pos_bucket : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket(const llama_hparams & hparams) : hparams(hparams) {}
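To sanity-check the new class in isolation, the sketch below fills attn_scale for a toy batch through set_input() on the CPU backend. This is a hedged illustration, not code from the commit: it assumes the internal headers llama-graph.h and llama-batch.h are on the include path and that llama_ubatch exposes the n_tokens and pos fields used in the diff above.

// Hedged sketch: exercise llm_graph_input_attn_temp::set_input() directly.
// Internal llama.cpp types (llm_graph_input_attn_temp, llama_ubatch) are assumed available.
#include <cstdio>
#include <vector>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"

#include "llama-batch.h"   // llama_ubatch (internal header, assumed)
#include "llama-graph.h"   // llm_graph_input_attn_temp (internal header, assumed)

int main() {
    const int n_tokens = 4;

    // metadata-only ggml context; the tensor data lives in a backend buffer
    ggml_init_params params = { ggml_tensor_overhead()*8, nullptr, /*no_alloc =*/ true };
    ggml_context * ctx = ggml_init(params);

    llm_graph_input_attn_temp inp(/*n_pos_per_token         =*/ 1,
                                  /*n_attn_temp_floor_scale =*/ 8192,
                                  /*f_attn_temp_scale       =*/ 0.1f);
    inp.attn_scale = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_tokens);

    ggml_backend_t        backend = ggml_backend_cpu_init();
    ggml_backend_buffer_t buf     = ggml_backend_alloc_ctx_tensors(ctx, backend);

    std::vector<llama_pos> pos = { 0, 4096, 8191, 16384 };

    llama_ubatch ubatch = {};
    ubatch.n_tokens = n_tokens;
    ubatch.pos      = pos.data();

    inp.set_input(&ubatch); // writes one F32 scale per token into attn_scale

    std::vector<float> out(n_tokens);
    ggml_backend_tensor_get(inp.attn_scale, out.data(), 0, n_tokens*sizeof(float));
    for (int i = 0; i < n_tokens; ++i) {
        printf("pos %5d -> scale %.4f\n", (int) pos[i], out[i]);
    }

    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}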

src/llama-hparams.h

Lines changed: 0 additions & 1 deletion

@@ -116,7 +116,6 @@ struct llama_hparams {
     bool use_kq_norm = true;
     // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step = 4;
-    uint32_t n_attn_temp_tuning   = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale = 0.1;
 

src/llama-model.cpp

Lines changed: 12 additions & 6 deletions

@@ -4271,6 +4271,16 @@ struct llm_build_llama : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
+        // temperature tuning
+        ggml_tensor * inp_attn_scale = nullptr;
+        if (arch == LLM_ARCH_LLAMA4) {
+            auto inp = std::make_unique<llm_graph_input_attn_temp>(n_pos_per_token(), hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+            inp_attn_scale = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens*n_pos_per_token());
+            ggml_set_input(inp_attn_scale);
+            inp->attn_scale = inp_attn_scale;
+            res->add_input(std::move(inp));
+        }
+
         auto * inp_attn = build_attn_inp_kv_unified();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

@@ -4330,12 +4340,8 @@ struct llm_build_llama : public llm_graph_context {
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
-            } else {
-                // TODO: support temperature tuning (attn_temperature_tuning)
-                // Problem: we are missing 2 things:
-                //   - ggml_cast from I32 to F32
-                //   - ggml_floor
-                // Ref implementation: https://github.com/ml-explore/mlx-lm/blob/9df43c9863c28065fecf87c9be2c5fd7e6f3864c/mlx_lm/models/llama4.py#L122-L130
+            } else if (inp_attn_scale) {
+                Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
            }
 
             cb(Qcur, "Qcur", il);
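Two notes on the llm_build_llama changes. First, because the per-token scale is precomputed on the CPU in set_input() (see src/llama-graph.cpp above), the ggml ops listed as missing in the removed TODO (ggml_floor and an I32 to F32 cast) are no longer needed inside the graph. Second, inp_attn_scale is deliberately created with shape [1, 1, n_tokens] while Qcur has already been reshaped to [n_embd_head, n_head, n_tokens] by the time of the multiply, so ggml_mul() broadcasts one scale per token across all heads and head dimensions. A minimal sketch of that broadcast using only the public ggml API (not from the commit; the tensor sizes are made up):

// Hedged sketch: a [1, 1, n_tokens] tensor multiplied into a
// [n_embd_head, n_head, n_tokens] tensor broadcasts one scale per token,
// which is what Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale) relies on.
#include <cstdio>
#include <cstring>
#include <vector>

#include "ggml.h"
#include "ggml-cpu.h"

int main() {
    const int n_embd_head = 4, n_head = 2, n_tokens = 3;

    ggml_init_params params = { 16*1024*1024, nullptr, /*no_alloc =*/ false };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * q     = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tokens);
    ggml_tensor * scale = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 1, 1, n_tokens);

    // q is all ones; each token gets a distinct scale
    std::vector<float> q_data(n_embd_head*n_head*n_tokens, 1.0f);
    std::vector<float> s_data = { 1.00f, 1.07f, 1.11f };
    memcpy(q->data,     q_data.data(), ggml_nbytes(q));
    memcpy(scale->data, s_data.data(), ggml_nbytes(scale));

    ggml_tensor * out = ggml_mul(ctx, q, scale); // scale is broadcast over dims 0 and 1

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    // element (dim 0, head 0, token 1) picks up token 1's scale
    const float * res = (const float *) out->data;
    printf("out[0, 0, 1] = %.2f\n", res[1*n_head*n_embd_head]); // 1.07

    ggml_free(ctx);
    return 0;
}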
