
Commit 12b5900

ggml : sync ggml (add GPT-NeoX RoPE implementation)

1 parent 9ff334f

File tree

3 files changed: 49 additions, 17 deletions

ggml.c

Lines changed: 42 additions & 16 deletions
@@ -8653,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(
 
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
+    const bool is_neox = mode & 2;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -8668,14 +8670,25 @@ static void ggml_compute_forward_rope_f32(
 
                     theta *= theta_scale;
 
-                    const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                    float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                    if (!is_neox) {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        const float x0 = src[0];
+                        const float x1 = src[1];
 
-                    const float x0 = src[0];
-                    const float x1 = src[1];
+                        dst_data[0] = x0*cos_theta - x1*sin_theta;
+                        dst_data[1] = x0*sin_theta + x1*cos_theta;
+                    } else {
+                        const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                        float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
 
-                    dst_data[0] = x0*cos_theta - x1*sin_theta;
-                    dst_data[1] = x0*sin_theta + x1*cos_theta;
+                        const float x0 = src[0];
+                        const float x1 = src[n_dims/2];
+
+                        dst_data[0]        = x0*cos_theta - x1*sin_theta;
+                        dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+                    }
                 }
             }
         }
@@ -8730,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(
 
     const float theta_scale = powf(10000.0, -2.0f/n_dims);
 
+    const bool is_neox = mode & 2;
+
     for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = (mode == 0 ? 0 : n_past); i2 < ne2; i2++) {
-            const int p = (mode == 0 ? n_past + i2 : i2);
+        for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) {
+            const int p = ((mode & 1) == 0 ? n_past + i2 : i2);
             for (int64_t i1 = 0; i1 < ne1; i1++) {
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
@@ -8745,14 +8760,25 @@ static void ggml_compute_forward_rope_f16(
 
                     theta *= theta_scale;
 
-                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-                    ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                    if (!is_neox) {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[1]);
 
-                    const float x0 = GGML_FP16_TO_FP32(src[0]);
-                    const float x1 = GGML_FP16_TO_FP32(src[1]);
+                        dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    } else {
+                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
+                        ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (i0/2)*nb0);
 
-                    dst_data[0] = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                    dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                        const float x0 = GGML_FP16_TO_FP32(src[0]);
+                        const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);
+
+                        dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
+                        dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
+                    }
                 }
             }
         }
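The hunks above change only which two elements of a row form a rotation pair: the original mode rotates adjacent elements (i0, i0 + 1), while the GPT-NeoX mode rotates elements half the rotary dimension apart, (i0/2, i0/2 + n_dims/2). Below is a minimal standalone sketch of that difference on a single row, not the ggml kernel itself; it assumes theta starts at the token position p and reuses the theta_scale expression from the hunks above.

#include <math.h>
#include <stdbool.h>

// Sketch only: rotate one row of length n_dims in place.
// p is the absolute token position; is_neox selects the GPT-NeoX pairing.
static void rope_row_sketch(float * x, int n_dims, int p, bool is_neox) {
    const float theta_scale = powf(10000.0f, -2.0f/n_dims);

    float theta = (float) p;   // assumed starting angle for position p
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float cos_theta = cosf(theta);
        const float sin_theta = sinf(theta);

        theta *= theta_scale;

        // original mode: adjacent elements (i0, i0 + 1) form a pair
        // GPT-NeoX mode: elements (i0/2, i0/2 + n_dims/2) form a pair
        const int ia = is_neox ? i0/2            : i0;
        const int ib = is_neox ? i0/2 + n_dims/2 : i0 + 1;

        const float x0 = x[ia];
        const float x1 = x[ib];

        x[ia] = x0*cos_theta - x1*sin_theta;
        x[ib] = x0*sin_theta + x1*cos_theta;
    }
}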

ggml.h

Lines changed: 2 additions & 1 deletion
@@ -630,7 +630,8 @@ struct ggml_tensor * ggml_soft_max(
 
 // rotary position embedding
 // in-place, returns view(a)
-// if mode == 1, skip n_past elements
+// if mode & 1 == 1, skip n_past elements
+// if mode & 2 == 1, GPT-NeoX style
 // TODO: avoid creating a new tensor every time
 struct ggml_tensor * ggml_rope(
         struct ggml_context * ctx,
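With this comment, mode becomes a bit mask rather than a plain flag. The following is a hedged sketch of how a caller might combine the two bits, assuming the ggml_rope(ctx, a, n_past, n_dims, mode) signature declared above; the wrapper name and its parameters are illustrative only.

#include <stdbool.h>

#include "ggml.h"

// Sketch: pick the RoPE variant via the mode bit mask described above.
// Bit 0 (value 1) skips the first n_past positions; bit 1 (value 2)
// selects the GPT-NeoX style rotation.
static struct ggml_tensor * apply_rope_sketch(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,      // hypothetical query/key tensor
        int                   n_past,
        int                   n_rot,
        bool                  use_neox) {
    const int mode = (use_neox ? 2 : 0) | 1;   // always skip n_past here
    return ggml_rope(ctx, cur, n_past, n_rot, mode);
}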

llama.cpp

Lines changed: 5 additions & 0 deletions
@@ -1618,6 +1618,11 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         // quantize only 2D tensors
         quantize &= (tensor.ne.size() == 2);
 
+        // GG: uncomment this to keep the output layer in FP16
+        //if (tensor.name.rfind("output")) {
+        //    quantize = false;
+        //}
+
         enum ggml_type new_type;
         void * new_data;
         size_t new_size;
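The added comment block is a manual switch for leaving the output projection unquantized. A minimal sketch of that intent in plain C follows; the helper name, the n_dims parameter, and the "output" substring match are assumptions for illustration, not the llama.cpp code itself.

#include <string.h>

// Hypothetical helper mirroring the intent of the commented-out check above:
// quantize only 2D tensors, and optionally keep the output layer in FP16.
static int should_quantize_sketch(const char * name, int n_dims, int keep_output_fp16) {
    if (n_dims != 2) {
        return 0;   // quantize only 2D tensors
    }
    if (keep_output_fp16 && strstr(name, "output") != NULL) {
        return 0;   // assumed match for the output tensor name
    }
    return 1;
}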
