Commit a4f5362

metal : better rope implementation
ggml-ci
1 parent 19d74d3 commit a4f5362

4 files changed: +119 additions, -96 deletions

ggml-metal.m
Lines changed: 28 additions & 21 deletions

@@ -172,8 +172,10 @@
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,

@@ -626,8 +628,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);

@@ -2297,17 +2301,21 @@ static enum ggml_status ggml_metal_graph_compute(

                const bool is_neox = mode & 2;

-               if (!is_neox) {
-                   GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
-               }
-
                id<MTLComputePipelineState> pipeline = nil;

-               switch (src0->type) {
-                   case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
-                   case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
-                   default: GGML_ASSERT(false);
-               };
+               if (!is_neox) {
+                   switch (src0->type) {
+                       case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                       case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                       default: GGML_ASSERT(false);
+                   };
+               } else {
+                   switch (src0->type) {
+                       case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                       case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                       default: GGML_ASSERT(false);
+                   };
+               }

                [encoder setComputePipelineState:pipeline];
                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];

@@ -2336,14 +2344,13 @@ static enum ggml_status ggml_metal_graph_compute(
                [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
                [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
                [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
-               [encoder setBytes:&mode length:sizeof( int) atIndex:22];
-               [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
-               [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
-               [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
-               [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
-               [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
-               [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
-               [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
+               [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
+               [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+               [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+               [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+               [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+               [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+               [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];

                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
            } break;
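
The host-side change above moves the norm/NeoX decision out of the kernel: instead of binding a single rope pipeline and passing mode as kernel argument 22, ggml_metal_graph_compute now checks mode & 2 once and binds either the rope_norm_* or the rope_neox_* pipeline, which is why the mode argument disappears and the remaining scalar arguments shift down to indices 22-28. A minimal sketch of that selection, using a hypothetical helper name (the real code indexes ctx->kernels[...] directly, as the hunk shows):

// Hypothetical illustration of the host-side kernel selection; the names
// rope_kernel_name and is_f16 are made up for this sketch.
#include <cstdio>

static const char * rope_kernel_name(int mode, bool is_f16) {
    const bool is_neox = (mode & 2) != 0; // bit 1 of mode selects the NeoX layout
    if (is_neox) {
        return is_f16 ? "kernel_rope_neox_f16" : "kernel_rope_neox_f32";
    }
    return is_f16 ? "kernel_rope_norm_f16" : "kernel_rope_norm_f32";
}

int main() {
    printf("%s\n", rope_kernel_name(0, false)); // kernel_rope_norm_f32 (llama-style)
    printf("%s\n", rope_kernel_name(2, true));  // kernel_rope_neox_f16 (falcon-style)
    return 0;
}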

ggml-metal.metal
Lines changed: 75 additions & 51 deletions

@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta
-) {
+    thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;

@@ -1684,7 +1683,8 @@ static void rope_yarn_corr_dims(
     dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
 }

-typedef void (rope_t)(
+template<typename T>
+kernel void kernel_rope_norm(
        device const void * src0,
        device const int32_t * src1,
        device const float * src2,

@@ -1707,7 +1707,6 @@ typedef void (rope_t)(
        constant uint64_t & nb3,
        constant int & n_past,
        constant int & n_dims,
-       constant int & mode,
        constant int & n_orig_ctx,
        constant float & freq_base,
        constant float & freq_scale,

@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
        constant float & beta_slow,
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg[[threads_per_threadgroup]],
-       uint3 tgpig[[threadgroup_position_in_grid]]);
+       uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}

 template<typename T>
-kernel void kernel_rope(
+kernel void kernel_rope_neox(
        device const void * src0,
        device const int32_t * src1,
        device const float * src2,

@@ -1743,7 +1784,6 @@ kernel void kernel_rope(
        constant uint64_t & nb3,
        constant int & n_past,
        constant int & n_dims,
-       constant int & mode,
        constant int & n_orig_ctx,
        constant float & freq_base,
        constant float & freq_scale,

@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
     const int64_t i2 = tgpig[1];
     const int64_t i1 = tgpig[0];

-    const bool is_neox = mode & 2;
-
     float corr_dims[2];
     rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

     device const int32_t * pos = src1;

-    const int64_t p = pos[i2];
-
-    const float theta_base = (float)p;
+    const float theta_base = (float) pos[i2];
     const float inv_ndims = -1.f/n_dims;

-    if (!is_neox) {
-        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+    float cos_theta;
+    float sin_theta;

-            float cos_theta, sin_theta;
-            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;

-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-            const T x0 = src[0];
-            const T x1 = src[1];
-
-            dst_data[0] = x0*cos_theta - x1*sin_theta;
-            dst_data[1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
-            if (ic < n_dims) {
-                const int64_t i0 = ic/2;
-
-                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
-
-                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);

-                float cos_theta, sin_theta;
-                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

-                const float x0 = src[0];
-                const float x1 = src[n_dims/2];
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);

-                dst_data[0] = x0*cos_theta - x1*sin_theta;
-                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-            } else {
-                const int64_t i0 = ic;
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];

-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

-                dst_data[0] = src[0];
-                dst_data[1] = src[1];
-            }
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
         }
     }
 }

-template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
-template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;

 typedef void (im2col_t)(
        device const float * x,
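
After this change the two Metal kernels differ only in which elements form a rotation pair: kernel_rope_norm rotates adjacent elements (i0, i0+1), while kernel_rope_neox rotates elements half a rotated block apart (i0/2, i0/2 + n_dims/2). Both divide theta by an optional per-frequency factor from src2 and copy dimensions at or beyond n_dims through unchanged. A minimal CPU sketch of just the pairing and the theta computation (the rope_ref helper is made up for illustration; the YaRN correction done by rope_yarn and the pass-through of trailing dimensions are omitted):

#include <cmath>
#include <cstdio>

// CPU sketch of the two pairing schemes; not part of ggml.
static void rope_ref(float * x, int n_dims, int pos, float freq_base,
                     const float * freq_factors /* may be nullptr */, bool is_neox) {
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const int ic = i0/2;

        float theta = pos * powf(freq_base, -(float) i0/n_dims);
        if (freq_factors) {
            theta /= freq_factors[ic]; // per-frequency factor, as read from src2
        }

        const float c = cosf(theta);
        const float s = sinf(theta);

        // norm: rotate neighbours (i0, i0+1); neox: rotate (ic, ic + n_dims/2)
        const int ia = is_neox ? ic            : i0;
        const int ib = is_neox ? ic + n_dims/2 : i0 + 1;

        const float x0 = x[ia];
        const float x1 = x[ib];

        x[ia] = x0*c - x1*s;
        x[ib] = x0*s + x1*c;
    }
}

int main() {
    float a[4] = {1, 2, 3, 4};
    float b[4] = {1, 2, 3, 4};

    rope_ref(a, 4, /*pos =*/ 3, 10000.0f, nullptr, /*is_neox =*/ false); // pairs (0,1), (2,3)
    rope_ref(b, 4, /*pos =*/ 3, 10000.0f, nullptr, /*is_neox =*/ true);  // pairs (0,2), (1,3)

    printf("norm: %.3f %.3f %.3f %.3f\n", a[0], a[1], a[2], a[3]);
    printf("neox: %.3f %.3f %.3f %.3f\n", b[0], b[1], b[2], b[3]);
    return 0;
}

For n_dims = 4 the norm variant rotates the pairs (0,1) and (2,3) while the neox variant rotates (0,2) and (1,3); that layout difference is all the split into two kernels encodes.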

ggml.c
Lines changed: 8 additions & 16 deletions

@@ -14340,14 +14340,10 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & 2;

     const float * freq_factors = NULL;
-    if (is_neox) {
-        if (src2 != NULL) {
-            GGML_ASSERT(src2->type == GGML_TYPE_F32);
-            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
     }

     // backward process uses inverse rotation by cos and sin.

@@ -14473,14 +14469,10 @@ static void ggml_compute_forward_rope_f16(
     const bool is_neox = mode & 2;

     const float * freq_factors = NULL;
-    if (is_neox) {
-        if (src2 != NULL) {
-            GGML_ASSERT(src2->type == GGML_TYPE_F32);
-            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
     }

     // backward process uses inverse rotation by cos and sin.
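
The ggml.c change is only about when the optional src2 tensor (per-frequency factors) is accepted: it is now honored for both rope variants, it must be F32 with at least n_dims/2 entries, and each factor divides the theta of the corresponding pair, matching theta/freq_factor in the Metal kernels above. A small worked example of that scaling; the numbers are illustrative only and not taken from any model:

#include <cmath>
#include <cstdio>

int main() {
    const int   n_dims    = 128;
    const float freq_base = 10000.0f;
    const int   pos       = 10;
    const int   i0        = 64;   // the pair whose angle we compute
    const float ff        = 4.0f; // hypothetical freq_factors[i0/2]

    const float theta        = pos * powf(freq_base, -(float) i0/n_dims); // 10 * 10000^-0.5 = 0.1
    const float theta_scaled = theta / ff;                                // 0.025 with the factor applied

    printf("theta = %.4f, theta/freq_factor = %.4f\n", theta, theta_scaled);
    return 0;
}

Dividing theta by a factor greater than one slows the rotation of that frequency band, which is how frequency-factor rope variants stretch the effective wavelength for long contexts.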

tests/test-backend-ops.cpp
Lines changed: 8 additions & 8 deletions

@@ -2232,15 +2232,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
    for (float ef : { 0.0f, 0.7465f }) {
        for (float af : { 1.0f, 1.4245f }) {
            for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                // TODO: ff not supported yet for !neox
-                test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B
-                if (all) {
-                    test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B
-                    test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B
-                    test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B
-                }
-
                for (bool ff : {false, true}) { // freq_factors
+                    test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
+
+                    if (all) {
+                        test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
+                        test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
+                        test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
+                    }
+
                    if (all) {
                        test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
                        test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
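
The test change folds the non-NeoX ("norm") shapes into the existing freq_factors loop, so both rope layouts are now exercised with and without a factors tensor instead of only the NeoX cases. A standalone sketch of the resulting coverage matrix (not part of test-backend-ops.cpp; the type/mode/ff axes simply mirror the loops above):

#include <cstdio>

int main() {
    const char * types[] = { "f32", "f16" };
    const int    modes[] = { 0, 2 };        // 0 = norm (llama-style), 2 = neox (falcon-style)
    const bool   ffs[]   = { false, true }; // without and with freq_factors

    for (const char * type : types) {
        for (int mode : modes) {
            for (bool ff : ffs) {
                printf("rope type=%s mode=%d freq_factors=%d\n", type, mode, (int) ff);
            }
        }
    }
    return 0;
}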
