@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
1654
1654
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
1655
1655
static void rope_yarn (
1656
1656
float theta_extrap, float freq_scale, float corr_dims[2 ], int64_t i0, float ext_factor, float mscale,
1657
- thread float * cos_theta, thread float * sin_theta
1658
- ) {
1657
+ thread float * cos_theta, thread float * sin_theta) {
1659
1658
// Get n-d rotational scaling corrected for extrapolation
1660
1659
float theta_interp = freq_scale * theta_extrap;
1661
1660
float theta = theta_interp;
@@ -1684,7 +1683,8 @@ static void rope_yarn_corr_dims(
1684
1683
dims[1 ] = min (n_dims - 1 .0f , ceil (rope_yarn_corr_factor (n_dims, n_orig_ctx, beta_slow, freq_base)));
1685
1684
}
1686
1685
1687
- typedef void (rope_t )(
1686
+ template <typename T>
1687
+ kernel void kernel_rope_norm (
1688
1688
device const void * src0,
1689
1689
device const int32_t * src1,
1690
1690
device const float * src2,
@@ -1707,7 +1707,6 @@ typedef void (rope_t)(
1707
1707
constant uint64_t & nb3,
1708
1708
constant int & n_past,
1709
1709
constant int & n_dims,
1710
- constant int & mode,
1711
1710
constant int & n_orig_ctx,
1712
1711
constant float & freq_base,
1713
1712
constant float & freq_scale,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
1717
1716
constant float & beta_slow,
1718
1717
uint tiitg[[thread_index_in_threadgroup]],
1719
1718
uint3 tptg[[threads_per_threadgroup]],
1720
- uint3 tgpig[[threadgroup_position_in_grid]]);
1719
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
1720
+ const int64_t i3 = tgpig[2 ];
1721
+ const int64_t i2 = tgpig[1 ];
1722
+ const int64_t i1 = tgpig[0 ];
1723
+
1724
+ float corr_dims[2 ];
1725
+ rope_yarn_corr_dims (n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1726
+
1727
+ device const int32_t * pos = src1;
1728
+
1729
+ const float theta_base = (float ) pos[i2];
1730
+ const float inv_ndims = -1 .f /n_dims;
1731
+
1732
+ float cos_theta;
1733
+ float sin_theta;
1734
+
1735
+ for (int64_t i0 = 2 *tiitg; i0 < ne0; i0 += 2 *tptg.x ) {
1736
+ if (i0 < n_dims) {
1737
+ const int64_t ic = i0/2 ;
1738
+
1739
+ const float theta = theta_base * pow (freq_base, inv_ndims*i0);
1740
+
1741
+ const float freq_factor = src2 != src0 ? src2[ic] : 1 .0f ;
1742
+
1743
+ rope_yarn (theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1744
+
1745
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1746
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1747
+
1748
+ const float x0 = src[0 ];
1749
+ const float x1 = src[1 ];
1750
+
1751
+ dst_data[0 ] = x0*cos_theta - x1*sin_theta;
1752
+ dst_data[1 ] = x0*sin_theta + x1*cos_theta;
1753
+ } else {
1754
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1755
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1756
+
1757
+ dst_data[0 ] = src[0 ];
1758
+ dst_data[1 ] = src[1 ];
1759
+ }
1760
+ }
1761
+ }
1721
1762
1722
1763
template <typename T>
1723
- kernel void kernel_rope (
1764
+ kernel void kernel_rope_neox (
1724
1765
device const void * src0,
1725
1766
device const int32_t * src1,
1726
1767
device const float * src2,
@@ -1743,7 +1784,6 @@ kernel void kernel_rope(
1743
1784
constant uint64_t & nb3,
1744
1785
constant int & n_past,
1745
1786
constant int & n_dims,
1746
- constant int & mode,
1747
1787
constant int & n_orig_ctx,
1748
1788
constant float & freq_base,
1749
1789
constant float & freq_scale,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
1758
1798
const int64_t i2 = tgpig[1 ];
1759
1799
const int64_t i1 = tgpig[0 ];
1760
1800
1761
- const bool is_neox = mode & 2 ;
1762
-
1763
1801
float corr_dims[2 ];
1764
1802
rope_yarn_corr_dims (n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1765
1803
1766
1804
device const int32_t * pos = src1;
1767
1805
1768
- const int64_t p = pos[i2];
1769
-
1770
- const float theta_base = (float )p;
1806
+ const float theta_base = (float ) pos[i2];
1771
1807
const float inv_ndims = -1 .f /n_dims;
1772
1808
1773
- if (!is_neox) {
1774
- for (int64_t i0 = 2 *tiitg; i0 < ne0; i0 += 2 *tptg.x ) {
1775
- const float theta = theta_base * pow (freq_base, inv_ndims*i0);
1809
+ float cos_theta;
1810
+ float sin_theta;
1776
1811
1777
- float cos_theta, sin_theta;
1778
- rope_yarn (theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1812
+ for (int64_t i0 = 2 *tiitg; i0 < ne0; i0 += 2 *tptg.x ) {
1813
+ if (i0 < n_dims) {
1814
+ const int64_t ic = i0/2 ;
1779
1815
1780
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1781
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1782
-
1783
- const T x0 = src[0 ];
1784
- const T x1 = src[1 ];
1785
-
1786
- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
1787
- dst_data[1 ] = x0*sin_theta + x1*cos_theta;
1788
- }
1789
- } else {
1790
- for (int64_t ic = 2 *tiitg; ic < ne0; ic += 2 *tptg.x ) {
1791
- if (ic < n_dims) {
1792
- const int64_t i0 = ic/2 ;
1793
-
1794
- const float freq_factor = src2 != src0 ? src2[i0] : 1 .0f ;
1795
-
1796
- const float theta = theta_base * pow (freq_base, inv_ndims*ic);
1816
+ const float theta = theta_base * pow (freq_base, inv_ndims*i0);
1797
1817
1798
- float cos_theta, sin_theta;
1799
- rope_yarn (theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
1818
+ const float freq_factor = src2 != src0 ? src2[ic] : 1 .0f ;
1800
1819
1801
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1802
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1820
+ rope_yarn (theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
1803
1821
1804
- const float x0 = src[ 0 ] ;
1805
- const float x1 = src[n_dims/ 2 ] ;
1822
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00) ;
1823
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0) ;
1806
1824
1807
- dst_data[0 ] = x0*cos_theta - x1*sin_theta;
1808
- dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
1809
- } else {
1810
- const int64_t i0 = ic;
1825
+ const float x0 = src[0 ];
1826
+ const float x1 = src[n_dims/2 ];
1811
1827
1812
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1813
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1828
+ dst_data[0 ] = x0*cos_theta - x1*sin_theta;
1829
+ dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
1830
+ } else {
1831
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1832
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1814
1833
1815
- dst_data[0 ] = src[0 ];
1816
- dst_data[1 ] = src[1 ];
1817
- }
1834
+ dst_data[0 ] = src[0 ];
1835
+ dst_data[1 ] = src[1 ];
1818
1836
}
1819
1837
}
1820
1838
}
1821
1839
1822
- template [[host_name(" kernel_rope_f32" )]] kernel rope_t kernel_rope<float >;
1823
- template [[host_name(" kernel_rope_f16" )]] kernel rope_t kernel_rope<half>;
1840
+ typedef decltype (kernel_rope_norm<float >) kernel_rope_norm_t;
1841
+ typedef decltype (kernel_rope_neox<float >) kernel_rope_neox_t;
1842
+
1843
+ template [[host_name(" kernel_rope_norm_f32" )]] kernel kernel_rope_norm_t kernel_rope_norm<float >;
1844
+ template [[host_name(" kernel_rope_norm_f16" )]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
1845
+
1846
+ template [[host_name(" kernel_rope_neox_f32" )]] kernel kernel_rope_neox_t kernel_rope_neox<float >;
1847
+ template [[host_name(" kernel_rope_neox_f16" )]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
1824
1848
1825
1849
typedef void (im2col_t )(
1826
1850
device const float * x,
0 commit comments