@@ -2713,15 +2713,161 @@ kernel void kernel_rope_neox(
2713
2713
}
2714
2714
}
2715
2715
2716
+ template <typename T>
2717
+ kernel void kernel_rope_multi (
2718
+ constant ggml_metal_kargs_rope & args,
2719
+ device const char * src0,
2720
+ device const char * src1,
2721
+ device const char * src2,
2722
+ device char * dst,
2723
+ ushort tiitg[[thread_index_in_threadgroup]],
2724
+ ushort3 tptg [[threads_per_threadgroup]],
2725
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
2726
+ const int i3 = tgpig[2 ];
2727
+ const int i2 = tgpig[1 ];
2728
+ const int i1 = tgpig[0 ];
2729
+
2730
+ float corr_dims[2 ];
2731
+ rope_yarn_corr_dims (args.n_dims , args.n_ctx_orig , args.freq_base , args.beta_fast , args.beta_slow , corr_dims);
2732
+
2733
+ device const int32_t * pos = (device const int32_t *) src1;
2734
+
2735
+ const float inv_ndims = -1 .f /args.n_dims ;
2736
+
2737
+ float cos_theta;
2738
+ float sin_theta;
2739
+
2740
+ for (int i0 = 2 *tiitg; i0 < args.ne0 ; i0 += 2 *tptg.x ) {
2741
+ if (i0 < args.n_dims ) {
2742
+ const int ic = i0/2 ;
2743
+
2744
+ // mrope theta calculations
2745
+ // note: the rest is the same as kernel_rope_neox
2746
+ const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3 ;
2747
+ const int sec_w01 = args.sect_0 + args.sect_1 ; // end of section 1
2748
+ const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2 ; // end of section 2
2749
+ const int sector = ic % sect_dims;
2750
+
2751
+ float theta_base;
2752
+ if (sector < args.sect_0 ) {
2753
+ theta_base = (float ) pos[i2];
2754
+ } else if (sector < sec_w01) {
2755
+ theta_base = (float ) pos[i2 + args.ne02 ];
2756
+ } else if (sector < sec_w012) {
2757
+ theta_base = (float ) pos[i2 + args.ne02 * 2 ];
2758
+ } else {
2759
+ theta_base = (float ) pos[i2 + args.ne02 * 3 ];
2760
+ }
2761
+ // end of mrope
2762
+
2763
+ const float theta = theta_base * pow (args.freq_base , inv_ndims*i0);
2764
+
2765
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1 .0f ;
2766
+
2767
+ rope_yarn (theta/freq_factor, args.freq_scale , corr_dims, i0, args.ext_factor , args.attn_factor , &cos_theta, &sin_theta);
2768
+
2769
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00 );
2770
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0 );
2771
+
2772
+ const float x0 = src[0 ];
2773
+ const float x1 = src[args.n_dims /2 ];
2774
+
2775
+ dst_data[0 ] = x0*cos_theta - x1*sin_theta;
2776
+ dst_data[args.n_dims /2 ] = x0*sin_theta + x1*cos_theta;
2777
+ } else {
2778
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00 );
2779
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0 );
2780
+
2781
+ dst_data[0 ] = src[0 ];
2782
+ dst_data[1 ] = src[1 ];
2783
+ }
2784
+ }
2785
+ }
2786
+
2787
+ template <typename T>
2788
+ kernel void kernel_rope_vision (
2789
+ constant ggml_metal_kargs_rope & args,
2790
+ device const char * src0,
2791
+ device const char * src1,
2792
+ device const char * src2,
2793
+ device char * dst,
2794
+ ushort tiitg[[thread_index_in_threadgroup]],
2795
+ ushort3 tptg [[threads_per_threadgroup]],
2796
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
2797
+ const int i3 = tgpig[2 ];
2798
+ const int i2 = tgpig[1 ];
2799
+ const int i1 = tgpig[0 ];
2800
+
2801
+ float corr_dims[2 ];
2802
+ rope_yarn_corr_dims (args.n_dims , args.n_ctx_orig , args.freq_base , args.beta_fast , args.beta_slow , corr_dims);
2803
+
2804
+ device const int32_t * pos = (device const int32_t *) src1;
2805
+
2806
+ const float inv_ndims = -1 .f /args.n_dims ;
2807
+
2808
+ float cos_theta;
2809
+ float sin_theta;
2810
+
2811
+ for (int i0 = 2 *tiitg; i0 < args.ne0 ; i0 += 2 *tptg.x ) {
2812
+ if (i0 < 2 *args.n_dims ) { // different from kernel_rope_multi
2813
+ const int ic = i0/2 ;
2814
+
2815
+ // mrope theta calculations (only support 2 dimensions)
2816
+ const int sect_dims = args.sect_0 + args.sect_1 ;
2817
+ const int sector = ic % sect_dims;
2818
+
2819
+ float p;
2820
+ float theta_base;
2821
+ if (sector < args.sect_1 ) {
2822
+ p = (float ) sector;
2823
+ theta_base = (float ) pos[i2];
2824
+ } else {
2825
+ p = (float ) sector - args.sect_0 ;
2826
+ theta_base = (float ) pos[i2 + args.ne02 ];
2827
+ }
2828
+
2829
+ const float theta = theta_base * pow (args.freq_base , 2 .0f * inv_ndims * p);
2830
+ // end of mrope
2831
+
2832
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1 .0f ;
2833
+
2834
+ rope_yarn (theta/freq_factor, args.freq_scale , corr_dims, i0, args.ext_factor , args.attn_factor , &cos_theta, &sin_theta);
2835
+
2836
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00 );
2837
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0 );
2838
+
2839
+ const float x0 = src[0 ];
2840
+ const float x1 = src[args.n_dims ]; // different from kernel_rope_multi
2841
+
2842
+ dst_data[0 ] = x0*cos_theta - x1*sin_theta;
2843
+ dst_data[args.n_dims ] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
2844
+ } else {
2845
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00 );
2846
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0 );
2847
+
2848
+ dst_data[0 ] = src[0 ];
2849
+ dst_data[1 ] = src[1 ];
2850
+ }
2851
+ }
2852
+ }
2853
+
2716
2854
typedef decltype (kernel_rope_norm<float >) kernel_rope_norm_t;
2717
2855
typedef decltype (kernel_rope_neox<float >) kernel_rope_neox_t;
2856
+ typedef decltype (kernel_rope_multi<float >) kernel_rope_multi_t;
2857
+ typedef decltype (kernel_rope_vision<float >) kernel_rope_vision_t;
2718
2858
2719
2859
template [[host_name(" kernel_rope_norm_f32" )]] kernel kernel_rope_norm_t kernel_rope_norm<float >;
2720
2860
template [[host_name(" kernel_rope_norm_f16" )]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
2721
2861
2722
2862
template [[host_name(" kernel_rope_neox_f32" )]] kernel kernel_rope_neox_t kernel_rope_neox<float >;
2723
2863
template [[host_name(" kernel_rope_neox_f16" )]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
2724
2864
2865
+ template [[host_name(" kernel_rope_multi_f32" )]] kernel kernel_rope_multi_t kernel_rope_multi<float >;
2866
+ template [[host_name(" kernel_rope_multi_f16" )]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
2867
+
2868
+ template [[host_name(" kernel_rope_vision_f32" )]] kernel kernel_rope_vision_t kernel_rope_vision<float >;
2869
+ template [[host_name(" kernel_rope_vision_f16" )]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
2870
+
2725
2871
typedef void (im2col_t )(
2726
2872
device const float * x,
2727
2873
device char * dst,
0 commit comments