Skip to content

Commit e91b83b

Browse files
committed
add GGML_ASSERT to catch ggml_rope and back value errors
1 parent 561fbe0 commit e91b83b

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

ggml.c

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11430,8 +11430,8 @@ static void ggml_compute_forward_rope_f32(
1143011430
const struct ggml_tensor * src0,
1143111431
const struct ggml_tensor * src1,
1143211432
struct ggml_tensor * dst) {
11433-
assert(src1->type == GGML_TYPE_I32);
11434-
assert(ggml_nelements(src1) == 3);
11433+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
11434+
GGML_ASSERT(ggml_nelements(src1) == 3);
1143511435

1143611436
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
1143711437
return;
@@ -11454,12 +11454,16 @@ static void ggml_compute_forward_rope_f32(
1145411454
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
1145511455
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
1145611456

11457-
assert(nb0 == sizeof(float));
11457+
GGML_ASSERT(nb0 == sizeof(float));
1145811458

1145911459
const int ith = params->ith;
1146011460
const int nth = params->nth;
1146111461

1146211462
const int nr = ggml_nrows(src0);
11463+
const int nc = src0->ne[0];
11464+
11465+
GGML_ASSERT(n_dims <= nc);
11466+
GGML_ASSERT(n_dims % 2 == 0);
1146311467

1146411468
// rows per thread
1146511469
const int dr = (nr + nth - 1)/nth;
@@ -11520,8 +11524,8 @@ static void ggml_compute_forward_rope_f16(
1152011524
const struct ggml_tensor * src0,
1152111525
const struct ggml_tensor * src1,
1152211526
struct ggml_tensor * dst) {
11523-
assert(src1->type == GGML_TYPE_I32);
11524-
assert(ggml_nelements(src1) == 3);
11527+
GGML_ASSERT(src1->type == GGML_TYPE_I32);
11528+
GGML_ASSERT(ggml_nelements(src1) == 3);
1152511529

1152611530
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
1152711531
return;
@@ -11544,12 +11548,16 @@ static void ggml_compute_forward_rope_f16(
1154411548
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
1154511549
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
1154611550

11547-
assert(nb0 == sizeof(ggml_fp16_t));
11551+
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
1154811552

1154911553
const int ith = params->ith;
1155011554
const int nth = params->nth;
1155111555

1155211556
const int nr = ggml_nrows(src0);
11557+
const int nc = src0->ne[0];
11558+
11559+
GGML_ASSERT(n_dims <= nc);
11560+
GGML_ASSERT(n_dims % 2 == 0);
1155311561

1155411562
// rows per thread
1155511563
const int dr = (nr + nth - 1)/nth;

0 commit comments

Comments
 (0)