|
82 | 82 | GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
83 | 83 | GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
84 | 84 | GGML_METAL_KERNEL_TYPE_NORM,
|
| 85 | + GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, |
85 | 86 | GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
86 | 87 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
87 | 88 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
@@ -542,6 +543,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
542 | 543 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
543 | 544 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
544 | 545 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
| 546 | + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); |
545 | 547 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
546 | 548 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
547 | 549 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -803,6 +805,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
803 | 805 | return false;
|
804 | 806 | }
|
805 | 807 | return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
| 808 | + case GGML_OP_SSM_CONV: |
| 809 | + return true; |
806 | 810 | case GGML_OP_MUL_MAT:
|
807 | 811 | case GGML_OP_MUL_MAT_ID:
|
808 | 812 | return ctx->support_simdgroup_reduction &&
|
@@ -1538,6 +1542,39 @@ static enum ggml_status ggml_metal_graph_compute(
|
1538 | 1542 | [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1539 | 1543 | }
|
1540 | 1544 | } break;
|
| 1545 | + case GGML_OP_SSM_CONV: |
| 1546 | + { |
| 1547 | + GGML_ASSERT(src0t == GGML_TYPE_F32); |
| 1548 | + GGML_ASSERT(src1t == GGML_TYPE_F32); |
| 1549 | + |
| 1550 | + GGML_ASSERT(ggml_is_contiguous(src0)); |
| 1551 | + GGML_ASSERT(ggml_is_contiguous(src1)); |
| 1552 | + |
| 1553 | + id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; |
| 1554 | + |
| 1555 | + [encoder setComputePipelineState:pipeline]; |
| 1556 | + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |
| 1557 | + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |
| 1558 | + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; |
| 1559 | + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; |
| 1560 | + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; |
| 1561 | + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; |
| 1562 | + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; |
| 1563 | + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; |
| 1564 | + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; |
| 1565 | + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; |
| 1566 | + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; |
| 1567 | + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; |
| 1568 | + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; |
| 1569 | + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; |
| 1570 | + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; |
| 1571 | + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; |
| 1572 | + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; |
| 1573 | + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; |
| 1574 | + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; |
| 1575 | + |
| 1576 | + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |
| 1577 | + } break; |
1541 | 1578 | case GGML_OP_MUL_MAT:
|
1542 | 1579 | {
|
1543 | 1580 | GGML_ASSERT(ne00 == ne10);
|
|
0 commit comments