
Commit 5db268c

cuda : add asserts for rope/norm + fix DS2
ggml-ci
Parent: 1e41f2f

File tree: 5 files changed (+51, −14 lines)

ggml-cuda.cu

Lines changed: 2 additions & 0 deletions

```diff
@@ -2886,7 +2886,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
+            return true;
         case GGML_OP_ROPE:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
```
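For context: `ggml_backend_cuda_supports_op` now reports `GGML_OP_ROPE` as supported only for contiguous inputs, so the scheduler can route the op to another backend instead of tripping the new kernel asserts. A minimal sketch of what a contiguity check over ggml's `ne`/`nb` arrays looks like (illustrative struct and function names, not the library's; the real `ggml_is_contiguous` additionally divides by the block size for quantized types):

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative stand-in for a ggml tensor: ne[] holds elements per
// dimension, nb[] holds the stride in bytes per dimension.
struct tensor_sketch {
    int64_t ne[4];
    size_t  nb[4];
    size_t  elem_size; // bytes per element (non-quantized types)
};

// Contiguous means the strides describe a densely packed layout: nb[0] is
// the element size and each higher stride is the previous stride times the
// previous extent.
static bool is_contiguous_sketch(const tensor_sketch & t) {
    return t.nb[0] == t.elem_size &&
           t.nb[1] == t.nb[0] * (size_t) t.ne[0] &&
           t.nb[2] == t.nb[1] * (size_t) t.ne[1] &&
           t.nb[3] == t.nb[2] * (size_t) t.ne[2];
}
```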

ggml-cuda/norm.cu

Lines changed: 6 additions & 0 deletions

```diff
@@ -170,6 +170,8 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -188,6 +190,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
@@ -202,6 +206,8 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
```
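These asserts make the CUDA norm paths fail fast on strided views instead of reading at wrong offsets. On the graph side the fix is to materialize the view with `ggml_cont` first, as the DS2 (DeepSeek-V2) changes in llama.cpp below do. A minimal sketch under assumed inputs (an initialized `ctx` and a 2D tensor `kv_pe` of shape `{kv_lora_rank + n_rope, n_tokens}`; the helper name and eps value are illustrative, while the `ggml_*` calls are the actual API):

```cpp
#include "ggml.h"

// View the first kv_lora_rank elements of each column. The view keeps the
// parent's row stride nb[1], so it is NOT contiguous when the parent rows
// are wider than kv_lora_rank.
static ggml_tensor * rms_norm_first_rows(ggml_context * ctx, ggml_tensor * kv_pe,
                                         int64_t kv_lora_rank, int64_t n_tokens) {
    ggml_tensor * kv = ggml_view_2d(ctx, kv_pe, kv_lora_rank, n_tokens,
            kv_pe->nb[1], 0);
    kv = ggml_cont(ctx, kv);               // pack it; satisfies the new asserts
    return ggml_rms_norm(ctx, kv, 1e-6f);  // eps value is illustrative
}
```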

ggml-cuda/rope.cu

Lines changed: 1 addition & 0 deletions

```diff
@@ -251,6 +251,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
+    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
     GGML_ASSERT(src0->type == dst->type);
```
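Why the kernels need this: they index the input as a flat, packed array. A simplified CUDA sketch of the flat-indexing assumption (not the real RoPE kernel, which also handles head geometry and frequency scaling):

```cuda
// Element (i0, i1) is assumed to live at x[i1*ne0 + i0]. For a
// non-contiguous view the true address is base + i1*nb1 + i0*nb0, so flat
// indexing would read the wrong memory -- hence the new contiguity assert.
__global__ void scale_rows(float * x, const float s, const int ne0, const int ne1) {
    const int i0 = blockIdx.x*blockDim.x + threadIdx.x;
    const int i1 = blockIdx.y;
    if (i0 >= ne0 || i1 >= ne1) {
        return;
    }
    x[i1*ne0 + i0] *= s; // packed-layout assumption baked in
}
```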

ggml-metal.m

Lines changed: 4 additions & 0 deletions

```diff
@@ -2187,6 +2187,7 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_RMS_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3
 
                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));
@@ -2214,6 +2215,7 @@ static enum ggml_status ggml_metal_graph_compute(
             case GGML_OP_GROUP_NORM:
                 {
                     GGML_ASSERT(ne00 % 4 == 0);
+                    GGML_ASSERT(ggml_is_contiguous(src0));
 
                     //float eps;
                     //memcpy(&eps, dst->op_params, sizeof(float));
@@ -2247,6 +2249,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 } break;
             case GGML_OP_NORM:
                 {
+                    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: only requires contiguous dim 1, 2, 3
+
                     float eps;
                     memcpy(&eps, dst->op_params, sizeof(float));
```
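The TODOs note that full contiguity is stricter than these Metal norm kernels actually need: only dims 1–3 have to stack densely relative to each other, while the within-row (dim 0) layout is handled by the kernel. A hypothetical relaxed predicate in that spirit (illustrative only; ggml had no such helper at the time of this commit, and this is one reading of the TODO):

```cpp
#include "ggml.h"

// Hypothetical relaxed check per the TODO: dims 2 and 3 must stack densely
// on top of dim 1 (nb[2] = nb[1]*ne[1], nb[3] = nb[2]*ne[2]), while the row
// stride nb[1] itself may be padded.
static bool is_contiguous_dims_123(const struct ggml_tensor * t) {
    return t->nb[2] == t->nb[1] * (size_t) t->ne[1] &&
           t->nb[3] == t->nb[2] * (size_t) t->ne[2];
}
```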

llama.cpp

Lines changed: 38 additions & 14 deletions

```diff
@@ -11187,46 +11187,69 @@ struct llm_build_context {
             }
 
             // split into {n_head * n_embd_head_qk_nope, n_tokens}
-            struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0);
+            struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
+                    ggml_row_size(q->type, hparams.n_embd_head_k),
+                    ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                    0);
             cb(q_nope, "q_nope", il);
+
             // and {n_head * n_embd_head_qk_rope, n_tokens}
-            struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * n_embd_head_qk_nope);
+            struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
+                    ggml_row_size(q->type, hparams.n_embd_head_k),
+                    ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
+                    ggml_row_size(q->type, n_embd_head_qk_nope));
             cb(q_pe, "q_pe", il);
 
             // {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
-            struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
-            cb(compressed_kv_pe, "compressed_kv_pe", il);
+            struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
+            cb(kv_pe_compresseed, "kv_pe_compresseed", il);
 
             // split into {kv_lora_rank, n_tokens}
-            struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0);
-            cb(compressed_kv, "compressed_kv", il);
+            struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
+                    kv_pe_compresseed->nb[1],
+                    0);
+            cb(kv_compressed, "kv_compressed", il);
+
             // and {n_embd_head_qk_rope, n_tokens}
-            struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, n_embd_head_qk_rope, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank);
+            struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
+                    kv_pe_compresseed->nb[1],
+                    kv_pe_compresseed->nb[1],
+                    ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
             cb(k_pe, "k_pe", il);
 
-            compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams,
+            kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+            kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                     model.layers[il].attn_kv_a_norm, NULL,
                     LLM_NORM_RMS, cb, il);
-            cb(compressed_kv, "compressed_kv", il);
+            cb(kv_compressed, "kv_compressed", il);
 
             // {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
-            struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv);
+            struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
             cb(kv, "kv", il);
 
             // split into {n_head * n_embd_head_qk_nope, n_tokens}
-            struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
+            struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                    ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
+                    ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                    0);
             cb(k_nope, "k_nope", il);
 
             // and {n_head * n_embd_head_v, n_tokens}
-            struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_embd_head_qk_nope);
+            struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
+                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
+                    ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
+                    ggml_row_size(kv->type, (n_embd_head_qk_nope)));
             cb(v_states, "v_states", il);
 
             v_states = ggml_cont(ctx0, v_states);
             cb(v_states, "v_states", il);
 
-            v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0);
+            v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
+                    ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
+                    0);
             cb(v_states, "v_states", il);
 
+            q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
             q_pe = ggml_rope_ext(
                 ctx0, q_pe, inp_pos, nullptr,
                 n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@@ -11235,8 +11258,9 @@ struct llm_build_context {
             cb(q_pe, "q_pe", il);
 
             // shared RoPE key
+            k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
             k_pe = ggml_rope_ext(
-                ctx0, ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr,
+                ctx0, k_pe, inp_pos, nullptr,
                 n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
                 ext_factor, attn_factor_scaled, beta_fast, beta_slow
             );
```
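On the stride arithmetic: the DS2 changes replace `ggml_element_size(t) * n` with `ggml_row_size(t->type, n)` throughout the view setup. The two coincide for the F32 activations here, but `ggml_row_size` also stays correct for quantized types, where `n` elements occupy `n / block_size` blocks of `type_size` bytes rather than `n` full elements. A simplified model of the relationship (ggml derives these numbers from its type-traits table; `row_size` and `type_info` are illustrative stand-ins):

```cpp
#include <cstddef>
#include <cstdint>

// Simplified per-type layout: block_size elements pack into type_size bytes.
struct type_info { size_t type_size; int64_t block_size; };

// ggml_row_size-style computation: bytes occupied by ne elements of a row.
// ggml asserts that ne is a multiple of the block size, as assumed here.
static size_t row_size(const type_info & ti, int64_t ne) {
    return ti.type_size * (size_t)(ne / ti.block_size);
}

// F32 is {4, 1}, so row_size == 4*ne, matching the old element-size form.
// Q4_0 packs 32 elements into 18-byte blocks, where per-element arithmetic
// would not give the true byte width of a row.
```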
