Skip to content

Commit df84919

Browse files
authored
ggml : add mrope kernel for metal (#13457)
1 parent 1449214 commit df84919

File tree

3 files changed

+192
-16
lines changed

3 files changed

+192
-16
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ typedef struct {
207207
float attn_factor;
208208
float beta_fast;
209209
float beta_slow;
210+
int32_t sect_0;
211+
int32_t sect_1;
212+
int32_t sect_2;
213+
int32_t sect_3;
210214
} ggml_metal_kargs_rope;
211215

212216
typedef struct {

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
332332
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
333333
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
334334
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
335+
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
336+
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
337+
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
338+
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
335339
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
336340
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
337341
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -1275,6 +1279,10 @@ @implementation GGMLMetalClass
12751279
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm);
12761280
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true);
12771281
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
1282+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
1283+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
1284+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
1285+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
12781286
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
12791287
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
12801288
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
@@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16371645
case GGML_OP_NORM:
16381646
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
16391647
case GGML_OP_ROPE:
1640-
{
1641-
const int mode = ((const int32_t *) op->op_params)[2];
1642-
if (mode & GGML_ROPE_TYPE_MROPE) {
1643-
return false;
1644-
}
1645-
if (mode & GGML_ROPE_TYPE_VISION) {
1646-
return false;
1647-
}
1648-
return true;
1649-
}
1648+
return true;
16501649
case GGML_OP_IM2COL:
16511650
return op->src[0]->type == GGML_TYPE_F16;
16521651
case GGML_OP_POOL_1D:
@@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node(
38263825
} break;
38273826
case GGML_OP_ROPE:
38283827
{
3828+
38293829
// make sure we have one or more position id(ne10) per token(ne02)
38303830
GGML_ASSERT(ne10 % ne02 == 0);
38313831
GGML_ASSERT(ne10 >= ne02);
@@ -3852,20 +3852,42 @@ static bool ggml_metal_encode_node(
38523852
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
38533853
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
38543854

3855-
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
3855+
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
3856+
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
3857+
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
3858+
3859+
// mrope
3860+
const int sect_0 = ((const int32_t *) dst->op_params)[11];
3861+
const int sect_1 = ((const int32_t *) dst->op_params)[12];
3862+
const int sect_2 = ((const int32_t *) dst->op_params)[13];
3863+
const int sect_3 = ((const int32_t *) dst->op_params)[14];
38563864

38573865
id<MTLComputePipelineState> pipeline = nil;
38583866

3859-
if (!is_neox) {
3867+
if (is_neox) {
38603868
switch (src0->type) {
3861-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
3862-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
3869+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
3870+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
3871+
default: GGML_ABORT("fatal error");
3872+
};
3873+
} else if (is_mrope && !is_vision) {
3874+
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
3875+
switch (src0->type) {
3876+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
3877+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
3878+
default: GGML_ABORT("fatal error");
3879+
};
3880+
} else if (is_vision) {
3881+
GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
3882+
switch (src0->type) {
3883+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
3884+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
38633885
default: GGML_ABORT("fatal error");
38643886
};
38653887
} else {
38663888
switch (src0->type) {
3867-
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
3868-
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
3889+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
3890+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
38693891
default: GGML_ABORT("fatal error");
38703892
};
38713893
}
@@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node(
38963918
/*.attn_factor =*/ attn_factor,
38973919
/*.beta_fast =*/ beta_fast,
38983920
/*.beta_slow =*/ beta_slow,
3921+
/* sect_0 =*/ sect_0,
3922+
/* sect_1 =*/ sect_1,
3923+
/* sect_2 =*/ sect_2,
3924+
/* sect_3 =*/ sect_3,
38993925
};
39003926

39013927
[encoder setComputePipelineState:pipeline];

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2713,15 +2713,161 @@ kernel void kernel_rope_neox(
27132713
}
27142714
}
27152715

2716+
// Multi-section rotary position embedding ("mrope") kernel, templated on the element type T.
//
// Indexing, as written below:
//   - tgpig[0], tgpig[1], tgpig[2] select i1, i2, i3; the threads of one threadgroup
//     walk i0 in steps of 2*tptg.x, starting from 2*tiitg.
//   - src1 is read as int32 positions. One of four values is used per i2, namely
//     pos[i2 + k*args.ne02] for k = 0..3; k is chosen by the section that ic = i0/2
//     falls into (section widths are args.sect_0 .. args.sect_3).
//   - src2 is read as per-ic float frequency factors, unless src2 == src0, in which
//     case a factor of 1.0 is used.
//     NOTE(review): presumably the host binds src0 in place of a missing src2 - confirm.
//   - for i0 < args.n_dims the pair (src[0], src[args.n_dims/2]) located at ic is rotated;
//     otherwise the two elements at i0 and i0 + 1 are copied through unchanged.
//
// rope_yarn_corr_dims() and rope_yarn() are defined outside this view; their exact
// semantics are not described here.
template<typename T>
2717+
kernel void kernel_rope_multi(
2718+
constant ggml_metal_kargs_rope & args,
2719+
device const char * src0,
2720+
device const char * src1,
2721+
device const char * src2,
2722+
device char * dst,
2723+
ushort tiitg[[thread_index_in_threadgroup]],
2724+
ushort3 tptg [[threads_per_threadgroup]],
2725+
uint3 tgpig[[threadgroup_position_in_grid]]) {
2726+
const int i3 = tgpig[2];
2727+
const int i2 = tgpig[1];
2728+
const int i1 = tgpig[0];
2729+
2730+
float corr_dims[2];
2731+
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
2732+
2733+
device const int32_t * pos = (device const int32_t *) src1;
2734+
2735+
// negative reciprocal of n_dims; scales the exponent applied to freq_base below
const float inv_ndims = -1.f/args.n_dims;
2736+
2737+
float cos_theta;
2738+
float sin_theta;
2739+
2740+
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
2741+
if (i0 < args.n_dims) {
2742+
const int ic = i0/2;
2743+
2744+
// mrope theta calculations
2745+
// note: the rest is the same as kernel_rope_neox
2746+
const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
2747+
const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
2748+
const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
2749+
// index of this pair within the repeating section layout
// NOTE(review): sect_dims == 0 would make this modulo undefined - confirm the host never passes all-zero sections
const int sector = ic % sect_dims;
2750+
2751+
// pick the position value for this sector: stream k is read at pos[i2 + k*args.ne02]
float theta_base;
2752+
if (sector < args.sect_0) {
2753+
theta_base = (float) pos[i2];
2754+
} else if (sector < sec_w01) {
2755+
theta_base = (float) pos[i2 + args.ne02];
2756+
} else if (sector < sec_w012) {
2757+
theta_base = (float) pos[i2 + args.ne02 * 2];
2758+
} else {
2759+
theta_base = (float) pos[i2 + args.ne02 * 3];
2760+
}
2761+
// end of mrope
2762+
2763+
const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
2764+
2765+
// src2 == src0 is treated as "no frequency factors supplied"
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
2766+
2767+
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
2768+
2769+
// both pointers are offset by ic (not i0): the rotated partner lives n_dims/2 elements further on
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
2770+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
2771+
2772+
// the rotation is evaluated in float and converted back to T on store
const float x0 = src[0];
2773+
const float x1 = src[args.n_dims/2];
2774+
2775+
dst_data[0] = x0*cos_theta - x1*sin_theta;
2776+
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
2777+
} else {
2778+
// outside the rotated range: copy the two elements at i0 and i0 + 1 unchanged
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
2779+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
2780+
2781+
dst_data[0] = src[0];
2782+
dst_data[1] = src[1];
2783+
}
2784+
}
2785+
}
2786+
2787+
// Two-section rotary position embedding kernel for the "vision" rope mode, templated on
// the element type T.
//
// Indexing, as written below:
//   - tgpig[0], tgpig[1], tgpig[2] select i1, i2, i3; the threads of one threadgroup
//     walk i0 in steps of 2*tptg.x, starting from 2*tiitg.
//   - src1 is read as int32 positions. Only two streams are used: pos[i2] for the first
//     section (sectors [0, sect_0)) and pos[i2 + args.ne02] for the second section
//     (sectors [sect_0, sect_0 + sect_1)).
//   - the frequency exponent depends on p, the index of the pair inside its own section,
//     not on i0.
//   - src2 is read as per-ic float frequency factors, unless src2 == src0, in which
//     case a factor of 1.0 is used.
//     NOTE(review): presumably the host binds src0 in place of a missing src2 - confirm.
//   - for i0 < 2*args.n_dims the pair (src[0], src[args.n_dims]) located at ic is rotated;
//     otherwise the two elements at i0 and i0 + 1 are copied through unchanged.
//
// rope_yarn_corr_dims() and rope_yarn() are defined outside this view; their exact
// semantics are not described here.
template<typename T>
2788+
kernel void kernel_rope_vision(
2789+
constant ggml_metal_kargs_rope & args,
2790+
device const char * src0,
2791+
device const char * src1,
2792+
device const char * src2,
2793+
device char * dst,
2794+
ushort tiitg[[thread_index_in_threadgroup]],
2795+
ushort3 tptg [[threads_per_threadgroup]],
2796+
uint3 tgpig[[threadgroup_position_in_grid]]) {
2797+
const int i3 = tgpig[2];
2798+
const int i2 = tgpig[1];
2799+
const int i1 = tgpig[0];
2800+
2801+
float corr_dims[2];
2802+
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
2803+
2804+
device const int32_t * pos = (device const int32_t *) src1;
2805+
2806+
// negative reciprocal of n_dims; scales the exponent applied to freq_base below
const float inv_ndims = -1.f/args.n_dims;
2807+
2808+
float cos_theta;
2809+
float sin_theta;
2810+
2811+
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
2812+
if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
2813+
const int ic = i0/2;
2814+
2815+
// mrope theta calculations (only support 2 dimensions)
2816+
const int sect_dims = args.sect_0 + args.sect_1;
2817+
// NOTE(review): sect_dims == 0 would make this modulo undefined - confirm the host never passes zero sections
const int sector = ic % sect_dims;
2818+
2819+
float p;
2820+
float theta_base;
2821+
// the first section spans sectors [0, sect_0); the boundary must be sect_0 so that it
// agrees with the "sector - args.sect_0" offset used for the second section below
// (testing against sect_1 is only equivalent when both sections have the same width)
if (sector < args.sect_0) {
2822+
p = (float) sector;
2823+
theta_base = (float) pos[i2];
2824+
} else {
2825+
p = (float) sector - args.sect_0;
2826+
theta_base = (float) pos[i2 + args.ne02];
2827+
}
2828+
2829+
const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
2830+
// end of mrope
2831+
2832+
// src2 == src0 is treated as "no frequency factors supplied"
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
2833+
2834+
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
2835+
2836+
// both pointers are offset by ic (not i0): the rotated partner lives n_dims elements further on
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
2837+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
2838+
2839+
// the rotation is evaluated in float and converted back to T on store
const float x0 = src[0];
2840+
const float x1 = src[args.n_dims]; // different from kernel_rope_multi
2841+
2842+
dst_data[0] = x0*cos_theta - x1*sin_theta;
2843+
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
2844+
} else {
2845+
// outside the rotated range: copy the two elements at i0 and i0 + 1 unchanged
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
2846+
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
2847+
2848+
dst_data[0] = src[0];
2849+
dst_data[1] = src[1];
2850+
}
2851+
}
2852+
}
2853+
27162854
// Function-type aliases for the rope kernel templates. Each alias is taken from the <float>
// specialization and then reused below for both the float and the half instantiation.
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
27172855
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
2856+
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
2857+
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
27182858

27192859
// Explicit instantiations. The host_name attribute assigns the name under which each
// instantiation is exported (one f32 and one f16 variant per rope kernel).
// NOTE(review): these names presumably have to match the host-side kernel registrations
// (e.g. rope_multi_f32 with a "kernel_" prefix) - confirm against the registration macro.
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
27202860
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
27212861

27222862
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
27232863
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
27242864

2865+
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
2866+
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
2867+
2868+
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
2869+
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
2870+
27252871
typedef void (im2col_t)(
27262872
device const float * x,
27272873
device char * dst,

0 commit comments

Comments
 (0)