@@ -8653,9 +8653,11 @@ static void ggml_compute_forward_rope_f32(
8653
8653
8654
8654
const float theta_scale = powf (10000.0 , -2.0f /n_dims );
8655
8655
8656
+ const bool is_neox = mode & 2 ;
8657
+
8656
8658
for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
8657
- for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
8658
- const int p = (mode == 0 ? n_past + i2 : i2 );
8659
+ for (int64_t i2 = (( mode & 1 ) == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
8660
+ const int p = (( mode & 1 ) == 0 ? n_past + i2 : i2 );
8659
8661
for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
8660
8662
if (ir ++ < ir0 ) continue ;
8661
8663
if (ir > ir1 ) break ;
@@ -8668,14 +8670,25 @@ static void ggml_compute_forward_rope_f32(
8668
8670
8669
8671
theta *= theta_scale ;
8670
8672
8671
- const float * const src = (float * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
8672
- float * dst_data = (float * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
8673
+ if (!is_neox ) {
8674
+ const float * const src = (float * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
8675
+ float * dst_data = (float * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
8676
+
8677
+ const float x0 = src [0 ];
8678
+ const float x1 = src [1 ];
8673
8679
8674
- const float x0 = src [0 ];
8675
- const float x1 = src [1 ];
8680
+ dst_data [0 ] = x0 * cos_theta - x1 * sin_theta ;
8681
+ dst_data [1 ] = x0 * sin_theta + x1 * cos_theta ;
8682
+ } else {
8683
+ const float * const src = (float * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + (i0 /2 )* nb0 );
8684
+ float * dst_data = (float * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + (i0 /2 )* nb0 );
8676
8685
8677
- dst_data [0 ] = x0 * cos_theta - x1 * sin_theta ;
8678
- dst_data [1 ] = x0 * sin_theta + x1 * cos_theta ;
8686
+ const float x0 = src [0 ];
8687
+ const float x1 = src [n_dims /2 ];
8688
+
8689
+ dst_data [0 ] = x0 * cos_theta - x1 * sin_theta ;
8690
+ dst_data [n_dims /2 ] = x0 * sin_theta + x1 * cos_theta ;
8691
+ }
8679
8692
}
8680
8693
}
8681
8694
}
@@ -8730,9 +8743,11 @@ static void ggml_compute_forward_rope_f16(
8730
8743
8731
8744
const float theta_scale = powf (10000.0 , -2.0f /n_dims );
8732
8745
8746
+ const bool is_neox = mode & 2 ;
8747
+
8733
8748
for (int64_t i3 = 0 ; i3 < ne3 ; i3 ++ ) {
8734
- for (int64_t i2 = (mode == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
8735
- const int p = (mode == 0 ? n_past + i2 : i2 );
8749
+ for (int64_t i2 = (( mode & 1 ) == 0 ? 0 : n_past ); i2 < ne2 ; i2 ++ ) {
8750
+ const int p = (( mode & 1 ) == 0 ? n_past + i2 : i2 );
8736
8751
for (int64_t i1 = 0 ; i1 < ne1 ; i1 ++ ) {
8737
8752
if (ir ++ < ir0 ) continue ;
8738
8753
if (ir > ir1 ) break ;
@@ -8745,14 +8760,25 @@ static void ggml_compute_forward_rope_f16(
8745
8760
8746
8761
theta *= theta_scale ;
8747
8762
8748
- const ggml_fp16_t * const src = (ggml_fp16_t * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
8749
- ggml_fp16_t * dst_data = (ggml_fp16_t * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
8763
+ if (!is_neox ) {
8764
+ const ggml_fp16_t * const src = (ggml_fp16_t * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
8765
+ ggml_fp16_t * dst_data = (ggml_fp16_t * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0 );
8766
+
8767
+ const float x0 = GGML_FP16_TO_FP32 (src [0 ]);
8768
+ const float x1 = GGML_FP16_TO_FP32 (src [1 ]);
8750
8769
8751
- const float x0 = GGML_FP16_TO_FP32 (src [0 ]);
8752
- const float x1 = GGML_FP16_TO_FP32 (src [1 ]);
8770
+ dst_data [0 ] = GGML_FP32_TO_FP16 (x0 * cos_theta - x1 * sin_theta );
8771
+ dst_data [1 ] = GGML_FP32_TO_FP16 (x0 * sin_theta + x1 * cos_theta );
8772
+ } else {
8773
+ const ggml_fp16_t * const src = (ggml_fp16_t * )((char * ) src0 -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + (i0 /2 )* nb0 );
8774
+ ggml_fp16_t * dst_data = (ggml_fp16_t * )((char * ) dst -> data + i3 * nb3 + i2 * nb2 + i1 * nb1 + (i0 /2 )* nb0 );
8753
8775
8754
- dst_data [0 ] = GGML_FP32_TO_FP16 (x0 * cos_theta - x1 * sin_theta );
8755
- dst_data [1 ] = GGML_FP32_TO_FP16 (x0 * sin_theta + x1 * cos_theta );
8776
+ const float x0 = GGML_FP16_TO_FP32 (src [0 ]);
8777
+ const float x1 = GGML_FP16_TO_FP32 (src [n_dims /2 ]);
8778
+
8779
+ dst_data [0 ] = GGML_FP32_TO_FP16 (x0 * cos_theta - x1 * sin_theta );
8780
+ dst_data [n_dims /2 ] = GGML_FP32_TO_FP16 (x0 * sin_theta + x1 * cos_theta );
8781
+ }
8756
8782
}
8757
8783
}
8758
8784
}
0 commit comments