@@ -7507,19 +7507,20 @@ static void ggml_compute_forward_rope_f32(
7507
7507
// row index used to determine which thread to use
7508
7508
int ir = 0 ;
7509
7509
7510
+ const float theta_scale = powf (10000.0 , ((float )-2 )/n_dims );
7511
+
7510
7512
for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
7511
7513
for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
7512
7514
const int p = (mode == 0 ? n_past + i2 : i2 );
7513
7515
for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
7514
7516
if (ir ++ < ir0 ) continue ;
7515
7517
if (ir > ir1 ) break ;
7516
-
7518
+ float theta = ( float ) p ;
7517
7519
for (int i0 = 0 ; i0 < n_dims ; i0 += 2 ) {
7518
- const float theta = powf (10000.0 , ((float )- i0 )/n_dims );
7519
-
7520
- const float cos_theta = cosf (p * theta );
7521
- const float sin_theta = sinf (p * theta );
7520
+ const float cos_theta = cosf (theta );
7521
+ const float sin_theta = sinf (theta );
7522
7522
7523
+ theta *= theta_scale ;
7523
7524
const float * const src = (float * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
7524
7525
float * dst_data = (float * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
7525
7526
@@ -7580,19 +7581,20 @@ static void ggml_compute_forward_rope_f16(
7580
7581
// row index used to determine which thread to use
7581
7582
int ir = 0 ;
7582
7583
7584
+ const float theta_scale = powf (10000.0 , ((float )-2 )/n_dims );
7585
+
7583
7586
for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
7584
7587
for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
7585
7588
const int p = (mode == 0 ? n_past + i2 : i2 );
7586
7589
for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
7587
7590
if (ir ++ < ir0 ) continue ;
7588
7591
if (ir > ir1 ) break ;
7589
-
7592
+ float theta = ( float ) p ;
7590
7593
for (int i0 = 0 ; i0 < n_dims ; i0 += 2 ) {
7591
- const float theta = powf (10000.0 , ((float )- i0 )/n_dims );
7592
-
7593
- const float cos_theta = cosf (p * theta );
7594
- const float sin_theta = sinf (p * theta );
7594
+ const float cos_theta = cosf (theta );
7595
+ const float sin_theta = sinf (theta );
7595
7596
7597
+ theta *= theta_scale ;
7596
7598
const ggml_fp16_t * const src = (ggml_fp16_t * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
7597
7599
ggml_fp16_t * dst_data = (ggml_fp16_t * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
7598
7600
0 commit comments