Commit 21c4882 (1 parent: 7e4e98d)

ggml : add ssm_scan metal impl

ggml-ci

File tree: 4 files changed, +191 −2 lines

ggml/src/ggml-metal.m

Lines changed: 85 additions & 0 deletions
@@ -83,6 +83,7 @@
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
+    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@@ -540,6 +541,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,    group_norm,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,          norm,          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,  ssm_conv_f32,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,  ssm_scan_f32,  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
@@ -802,6 +804,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
+        case GGML_OP_SSM_SCAN:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -1568,6 +1571,88 @@ static enum ggml_status ggml_metal_graph_compute(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_SSM_SCAN:
+                {
+                    struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+                    struct ggml_tensor * src4 = gf->nodes[i]->src[4];
+                    struct ggml_tensor * src5 = gf->nodes[i]->src[5];
+
+                    GGML_ASSERT(src3);
+                    GGML_ASSERT(src4);
+                    GGML_ASSERT(src5);
+
+                    size_t offs_src3 = 0;
+                    size_t offs_src4 = 0;
+                    size_t offs_src5 = 0;
+
+                    id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+                    id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
+                    id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
+
+                    const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
+                    const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
+
+                    const uint64_t nb30 = src3->nb[0];
+                    const uint64_t nb31 = src3->nb[1];
+
+                    const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
+                    const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
+                    const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
+
+                    const uint64_t nb40 = src4->nb[0];
+                    const uint64_t nb41 = src4->nb[1];
+                    const uint64_t nb42 = src4->nb[2];
+
+                    const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
+                    const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
+                    const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
+
+                    const uint64_t nb50 = src5->nb[0];
+                    const uint64_t nb51 = src5->nb[1];
+                    const uint64_t nb52 = src5->nb[2];
+
+                    const int64_t d_state      = ne00;
+                    const int64_t d_inner      = ne01;
+                    const int64_t n_seq_tokens = ne11;
+                    const int64_t n_seqs       = ne02;
+
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+                    [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
+                    [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
+
+                    [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
+                    [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
+                    [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
+                    [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
+
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                    [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
+                    [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
+                    [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
+                    [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
+                    [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
+                    [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
+                    [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
+                    [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
+                    [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
+                    [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
+                    [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             case GGML_OP_MUL_MAT:
                 {
                     GGML_ASSERT(ne00 == ne10);
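For orientation, the dispatch above assigns one single-thread threadgroup to each (inner-dimension row, sequence) pair and leaves the token dimension sequential inside the kernel, since the state carries across tokens. A loop-form sketch of that work decomposition (illustrative C++ only, not code from the commit; the function name is ours):

    #include <cstdint>

    // Illustrative decomposition of the MTLSizeMake(d_inner, n_seqs, 1)
    // dispatch: every (ir, i3) pair is an independent threadgroup, so the GPU
    // may run them concurrently; the token recurrence stays inside the kernel.
    void ssm_scan_grid_sketch(int64_t d_inner, int64_t n_seqs) {
        for (int64_t i3 = 0; i3 < n_seqs; ++i3) {      // tgpig.y
            for (int64_t ir = 0; ir < d_inner; ++ir) { // tgpig.x
                // kernel_ssm_scan_f32 body runs here: a sequential loop over
                // n_seq_tokens carrying the d_state-sized state for (ir, i3).
                (void) ir; (void) i3;
            }
        }
    }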

ggml/src/ggml-metal.metal

Lines changed: 73 additions & 0 deletions
@@ -715,6 +715,79 @@ kernel void kernel_ssm_conv_f32(
     x[0] = sumf;
 }
 
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
+// TODO: optimize
+kernel void kernel_ssm_scan_f32(
+        device const void * src0,
+        device const void * src1,
+        device const void * src2,
+        device const void * src3,
+        device const void * src4,
+        device const void * src5,
+        device       float * dst,
+        constant  int64_t & d_state,
+        constant  int64_t & d_inner,
+        constant  int64_t & n_seq_tokens,
+        constant  int64_t & n_seqs,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb10,
+        constant uint64_t & nb11,
+        constant uint64_t & nb12,
+        constant uint64_t & nb13,
+        constant uint64_t & nb20,
+        constant uint64_t & nb21,
+        constant uint64_t & nb22,
+        constant uint64_t & nb30,
+        constant uint64_t & nb31,
+        constant uint64_t & nb40,
+        constant uint64_t & nb41,
+        constant uint64_t & nb42,
+        constant uint64_t & nb50,
+        constant uint64_t & nb51,
+        constant uint64_t & nb52,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i3 = tgpig.y;
+
+    const int64_t nc  = d_state;
+    const int64_t nr  = d_inner;
+    const int64_t n_t = n_seq_tokens;
+    const int64_t n_s = n_seqs;
+
+    for (int64_t i2 = 0; i2 < n_t; ++i2) {
+        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
+        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
+        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
+        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
+        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
+        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 + nb13);
+
+        if (i2 > 0) {
+            s0 = s;
+        }
+
+        // i1 == 0
+        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
+        float x_dt = x[0] * dt_soft_plus;
+        float sumf = 0.0f;
+
+        for (int64_t i0 = 0; i0 < nc; ++i0) {
+            int64_t i = i0;
+            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
+            sumf += state * C[i0];
+            s[i] = state;
+        }
+
+        y[0] = sumf;
+    }
+}
+
 kernel void kernel_norm(
         device const void * src0,
         device       float * dst,
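Reading the loop body off the kernel: per inner-dimension row and sequence, the update for token $t$ is the standard selective-scan recurrence,

$$
\Delta_t = \operatorname{softplus}(dt_t) = \log\bigl(1 + e^{dt_t}\bigr)
\quad (dt_t \text{ is used unchanged when } dt_t > 20),
$$
$$
s_t[i] = s_{t-1}[i]\, e^{\Delta_t A[i]} + B_t[i]\,\bigl(x_t\,\Delta_t\bigr),
\qquad
y_t = \sum_{i=0}^{d_{\mathrm{state}}-1} C_t[i]\, s_t[i],
$$

with the updated state $s_t$ written to the second half of dst so the next token reads it back as $s_{t-1}$.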

ggml/src/ggml.c

Lines changed: 2 additions & 2 deletions
@@ -15788,8 +15788,8 @@ static void ggml_compute_forward_ssm_scan_f32(
             const float * A  = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
             const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
             const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-                  float * y  = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-                  float * s  = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]);      // {d_state, d_inner, n_s}
+                  float * y  = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+                  float * s  = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]);      // {d_state, d_inner, n_s}
 
             // use the output as the source for the next token-wise iterations
             if (i2 > 0) { s0 = s; }

tests/test-backend-ops.cpp

Lines changed: 31 additions & 0 deletions
@@ -957,6 +957,35 @@ struct test_ssm_conv : public test_case {
     }
 };
 
+// GGML_OP_SSM_SCAN
+struct test_ssm_scan : public test_case {
+    const ggml_type type;
+
+    const int64_t d_state;
+    const int64_t d_inner;
+    const int64_t n_seq_tokens;
+    const int64_t n_seqs;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
+    }
+
+    test_ssm_scan(ggml_type type = GGML_TYPE_F32,
+            int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
+        : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * s   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
+        ggml_tensor * x   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
+        ggml_tensor * dt  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
+        ggml_tensor * A   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1, 1 }.data());
+        ggml_tensor * B   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
+        ggml_tensor * C   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
+        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
+        return out;
+    }
+};
+
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -2228,6 +2257,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
     test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
 
+    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
+
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
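For readers who want to exercise the op outside the test harness, a minimal standalone sketch might look like the following. This is not part of the commit: it assumes the public ggml API of this era (ggml_init, ggml_new_tensor_2d/3d, ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx), the memory sizing is a rough over-estimate, and the inputs are left unfilled.

    // sketch.cpp — build and run GGML_OP_SSM_SCAN with the shapes of the new
    // test case (d_state=16, d_inner=1024, n_seq_tokens=32, n_seqs=4).
    #include "ggml.h"

    int main() {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 64*1024*1024, // rough over-estimate
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t d_state = 16, d_inner = 1024, n_t = 32, n_s = 4;

        // same operand shapes as test_ssm_scan::build_graph above
        struct ggml_tensor * s   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_s);
        struct ggml_tensor * x   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
        struct ggml_tensor * dt  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_t, n_s);
        struct ggml_tensor * A   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
        struct ggml_tensor * B   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_t, n_s);
        struct ggml_tensor * C   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_t, n_s);

        struct ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, out);
        // fill the input tensors here before computing for real
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 4);

        ggml_free(ctx);
        return 0;
    }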
