Skip to content

Commit 561fbe0

Browse files
committed
Replace in-place operations for training with copying operations to allow gradient propagation
1 parent 956511b commit 561fbe0

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

examples/baby-llama/baby-llama.cpp

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -360,8 +360,8 @@ struct ggml_tensor * forward(
360360
// wk shape [n_embd, n_embd, 1, 1]
361361
// Qcur shape [n_embd/n_head, n_head, N, 1]
362362
// Kcur shape [n_embd/n_head, n_head, N, 1]
363-
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
364-
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
363+
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
364+
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
365365

366366
// store key and value to memory
367367
{
@@ -414,17 +414,17 @@ struct ggml_tensor * forward(
414414
// KQ_scaled = KQ / sqrt(n_embd/n_head)
415415
// KQ_scaled shape [n_past + N, N, n_head, 1]
416416
struct ggml_tensor * KQ_scaled =
417-
ggml_scale_inplace(ctx0,
417+
ggml_scale(ctx0,
418418
KQ,
419419
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
420420

421421
// KQ_masked = mask_past(KQ_scaled)
422422
// KQ_masked shape [n_past + N, N, n_head, 1]
423-
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
423+
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
424424

425425
// KQ = soft_max(KQ_masked)
426426
// KQ_soft_max shape [n_past + N, N, n_head, 1]
427-
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
427+
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
428428

429429
// split cached V into n_head heads
430430
//// V shape [n_past + N, n_embd/n_head, n_head, 1]
@@ -446,9 +446,10 @@ struct ggml_tensor * forward(
446446

447447
// cur = KQV_merged.contiguous().view(n_embd, N)
448448
// cur shape [n_embd,N,1,1]
449-
cur = ggml_cpy(ctx0,
450-
KQV_merged,
451-
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
449+
cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, KQV_merged), n_embd, N);
450+
// cur = ggml_cpy(ctx0,
451+
// KQV_merged,
452+
// ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
452453

453454
// projection (no bias)
454455
cur = ggml_mul_mat(ctx0,

0 commit comments

Comments
 (0)