
Commit 3035c2d

metal : better rope implementation
ggml-ci
1 parent cbe4f5f commit 3035c2d

File tree

4 files changed: +119, -96 lines changed
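
This commit splits the single templated Metal RoPE kernel (kernel_rope) into two dedicated kernels, kernel_rope_norm and kernel_rope_neox, drops the now-unused mode argument from the kernel interface (shifting the later constant-buffer indices down by one), and removes the restriction that frequency factors (src2) were only honored for the NeoX path. The C sketch below is not part of the commit; it only illustrates, under simplified assumptions (no YaRN correction, no freq_factors, a single contiguous row, illustrative names), the two pair layouts the new kernels rotate.

    #include <math.h>

    // "norm" RoPE (kernel_rope_norm): rotate adjacent pairs (x[i0], x[i0 + 1]).
    static void rope_norm_row(float * x, int n_dims, float theta_base, float freq_base) {
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const float theta = theta_base * powf(freq_base, -(float) i0/n_dims);
            const float c = cosf(theta);
            const float s = sinf(theta);
            const float x0 = x[i0 + 0];
            const float x1 = x[i0 + 1];
            x[i0 + 0] = x0*c - x1*s;
            x[i0 + 1] = x0*s + x1*c;
        }
    }

    // "neox" RoPE (kernel_rope_neox): same angle schedule, but the rotated pair is
    // split across the two halves of the head: (x[ic], x[ic + n_dims/2]), ic = i0/2.
    static void rope_neox_row(float * x, int n_dims, float theta_base, float freq_base) {
        for (int i0 = 0; i0 < n_dims; i0 += 2) {
            const int ic = i0/2;
            const float theta = theta_base * powf(freq_base, -(float) i0/n_dims);
            const float c = cosf(theta);
            const float s = sinf(theta);
            const float x0 = x[ic];
            const float x1 = x[ic + n_dims/2];
            x[ic]            = x0*c - x1*s;
            x[ic + n_dims/2] = x0*s + x1*c;
        }
    }

In the actual kernels, theta is additionally divided by an optional per-pair frequency factor and passed through rope_yarn for the YaRN corrections before the rotation is applied.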

ggml-metal.m

Lines changed: 28 additions & 21 deletions
@@ -172,8 +172,10 @@
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F32,
-    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -626,8 +628,10 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32, mul_mm_id_iq1_m_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
@@ -2303,17 +2307,21 @@ static enum ggml_status ggml_metal_graph_compute(

                 const bool is_neox = mode & 2;

-                if (!is_neox) {
-                    GGML_ASSERT(id_src2 == nil && "TODO: freq_factors not implemented for !is_neox");
-                }
-
                 id<MTLComputePipelineState> pipeline = nil;

-                switch (src0->type) {
-                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
-                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
-                    default: GGML_ASSERT(false);
-                };
+                if (!is_neox) {
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                        default: GGML_ASSERT(false);
+                    };
+                } else {
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        default: GGML_ASSERT(false);
+                    };
+                }

                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -2342,14 +2350,13 @@ static enum ggml_status ggml_metal_graph_compute(
                 [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19];
                 [encoder setBytes:&n_past length:sizeof( int) atIndex:20];
                 [encoder setBytes:&n_dims length:sizeof( int) atIndex:21];
-                [encoder setBytes:&mode length:sizeof( int) atIndex:22];
-                [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:23];
-                [encoder setBytes:&freq_base length:sizeof( float) atIndex:24];
-                [encoder setBytes:&freq_scale length:sizeof( float) atIndex:25];
-                [encoder setBytes:&ext_factor length:sizeof( float) atIndex:26];
-                [encoder setBytes:&attn_factor length:sizeof( float) atIndex:27];
-                [encoder setBytes:&beta_fast length:sizeof( float) atIndex:28];
-                [encoder setBytes:&beta_slow length:sizeof( float) atIndex:29];
+                [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22];
+                [encoder setBytes:&freq_base length:sizeof( float) atIndex:23];
+                [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24];
+                [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25];
+                [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26];
+                [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27];
+                [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28];

                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;

ggml-metal.metal

Lines changed: 75 additions & 51 deletions
@@ -1654,8 +1654,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
     float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
-    thread float * cos_theta, thread float * sin_theta
-) {
+    thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
     float theta = theta_interp;
@@ -1684,7 +1683,8 @@ static void rope_yarn_corr_dims(
     dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base)));
 }

-typedef void (rope_t)(
+template<typename T>
+kernel void kernel_rope_norm(
         device const void * src0,
         device const int32_t * src1,
         device const float * src2,
@@ -1707,7 +1707,6 @@ typedef void (rope_t)(
         constant uint64_t & nb3,
         constant int & n_past,
         constant int & n_dims,
-        constant int & mode,
         constant int & n_orig_ctx,
         constant float & freq_base,
         constant float & freq_scale,
@@ -1717,10 +1716,52 @@ typedef void (rope_t)(
         constant float & beta_slow,
         uint tiitg[[thread_index_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]],
-        uint3 tgpig[[threadgroup_position_in_grid]]);
+        uint3 tgpig[[threadgroup_position_in_grid]]) {
+    const int64_t i3 = tgpig[2];
+    const int64_t i2 = tgpig[1];
+    const int64_t i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
+
+    device const int32_t * pos = src1;
+
+    const float theta_base = (float) pos[i2];
+    const float inv_ndims = -1.f/n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;
+
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[1];
+
+            dst_data[0] = x0*cos_theta - x1*sin_theta;
+            dst_data[1] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}

 template<typename T>
-kernel void kernel_rope(
+kernel void kernel_rope_neox(
         device const void * src0,
         device const int32_t * src1,
         device const float * src2,
@@ -1743,7 +1784,6 @@ kernel void kernel_rope(
         constant uint64_t & nb3,
         constant int & n_past,
         constant int & n_dims,
-        constant int & mode,
         constant int & n_orig_ctx,
         constant float & freq_base,
         constant float & freq_scale,
@@ -1758,69 +1798,53 @@ kernel void kernel_rope(
     const int64_t i2 = tgpig[1];
     const int64_t i1 = tgpig[0];

-    const bool is_neox = mode & 2;
-
     float corr_dims[2];
     rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);

     device const int32_t * pos = src1;

-    const int64_t p = pos[i2];
-
-    const float theta_base = (float)p;
+    const float theta_base = (float) pos[i2];
     const float inv_ndims = -1.f/n_dims;

-    if (!is_neox) {
-        for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+    float cos_theta;
+    float sin_theta;

-            float cos_theta, sin_theta;
-            rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
+        if (i0 < n_dims) {
+            const int64_t ic = i0/2;

-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-            const T x0 = src[0];
-            const T x1 = src[1];
-
-            dst_data[0] = x0*cos_theta - x1*sin_theta;
-            dst_data[1] = x0*sin_theta + x1*cos_theta;
-        }
-    } else {
-        for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
-            if (ic < n_dims) {
-                const int64_t i0 = ic/2;
-
-                const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
-
-                const float theta = theta_base * pow(freq_base, inv_ndims*ic);
+            const float theta = theta_base * pow(freq_base, inv_ndims*i0);

-                float cos_theta, sin_theta;
-                rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);

-                const float x0 = src[0];
-                const float x1 = src[n_dims/2];
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);

-                dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-            } else {
-                const int64_t i0 = ic;
+            const float x0 = src[0];
+            const float x1 = src[n_dims/2];

-                device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            dst_data[0]        = x0*cos_theta - x1*sin_theta;
+            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);

-                dst_data[0] = src[0];
-                dst_data[1] = src[1];
-            }
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
         }
     }
 }

-template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
-template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
+typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
+typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+
+template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
+template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
+
+template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
+template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;

 typedef void (im2col_t)(
         device const float * x,

ggml.c

Lines changed: 8 additions & 16 deletions
@@ -14341,14 +14341,10 @@ static void ggml_compute_forward_rope_f32(
     const bool is_neox = mode & 2;

     const float * freq_factors = NULL;
-    if (is_neox) {
-        if (src2 != NULL) {
-            GGML_ASSERT(src2->type == GGML_TYPE_F32);
-            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
     }

     // backward process uses inverse rotation by cos and sin.
@@ -14474,14 +14470,10 @@ static void ggml_compute_forward_rope_f16(
     const bool is_neox = mode & 2;

     const float * freq_factors = NULL;
-    if (is_neox) {
-        if (src2 != NULL) {
-            GGML_ASSERT(src2->type == GGML_TYPE_F32);
-            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-            freq_factors = (const float *) src2->data;
-        }
-    } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
     }

     // backward process uses inverse rotation by cos and sin.
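
With the assertions above relaxed, freq_factors (src2) are now accepted for both RoPE variants on the CPU path as well. A minimal sketch, with illustrative names, of how one factor per rotated pair divides the base angle before the YaRN correction, mirroring rope_yarn(theta/freq_factor, ...) in the Metal kernels:

    #include <math.h>
    #include <stddef.h>

    // Illustration only: the angle for the pair starting at column i0, with an
    // optional per-pair frequency factor. freq_factors may be NULL; when present
    // it must hold at least n_dims/2 values (see the GGML_ASSERT above).
    static float rope_pair_theta(int i0, int n_dims, float theta_base, float freq_base,
                                 const float * freq_factors) {
        const int   ic          = i0/2;
        const float freq_factor = freq_factors != NULL ? freq_factors[ic] : 1.0f;
        const float theta       = theta_base * powf(freq_base, -(float) i0/n_dims);
        return theta/freq_factor;
    }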

tests/test-backend-ops.cpp

Lines changed: 8 additions & 8 deletions
@@ -2236,15 +2236,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     for (float ef : { 0.0f, 0.7465f }) {
     for (float af : { 1.0f, 1.4245f }) {
         for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-            // TODO: ff not supported yet for !neox
-            test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 7B
-            if (all) {
-                test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 13B
-                test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 30B
-                test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, false, v)); // llama 65B
-            }
-
             for (bool ff : {false, true}) { // freq_factors
+                test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 7B
+
+                if (all) {
+                    test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 13B
+                    test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 30B
+                    test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, ff, v)); // llama 65B
+                }
+
                 if (all) {
                     test_cases.emplace_back(new test_rope(type, { 64,  1, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
                     test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1},  64, 2, 512, fs, ef, af, ff, v)); // neox (falcon 7B)
