
Commit 667d70d

PABannier authored and ggerganov committed

metal : add GGML_OP_CONV_TRANSPOSE_1D kernels (ggml/1026)

* wip
* wip implementation f32
* kernel conv transpose 1d f32 working
* initial commit

1 parent 3b4f2e3, commit 667d70d

File tree

2 files changed: +121 -0 lines changed


ggml/src/ggml-metal/ggml-metal.m

Lines changed: 48 additions & 0 deletions

@@ -306,6 +306,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
+    GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -870,6 +872,8 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
@@ -1069,6 +1073,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
+        case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
@@ -3138,6 +3143,49 @@ static void ggml_metal_encode_node(
                     [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
                 }
             } break;
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
+                GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(src1->type == GGML_TYPE_F32);
+                GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+                const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+
+                const int32_t IC = src1->ne[1];
+                const int32_t IL = src1->ne[0];
+
+                const int32_t K = src0->ne[0];
+
+                const int32_t OL = dst->ne[0];
+                const int32_t OC = dst->ne[1];
+
+                id<MTLComputePipelineState> pipeline;
+
+                switch (src0->type) {
+                    case GGML_TYPE_F32: {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
+                    } break;
+                    case GGML_TYPE_F16: {
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
+                    } break;
+                    default: GGML_ABORT("fatal error");
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
+                [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
+                [encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
+                [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
+                [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
+                [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
         case GGML_OP_UPSCALE:
             {
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
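
For orientation, here is a minimal sketch of how this new Metal path would be reached from the ggml C API. It is not part of the commit: the concrete sizes are invented, and the tensor layouts (kernel as [K, OC, IC], input as [IL, IC]) are inferred from the dimensions the encoder reads above, so treat them as assumptions.

#include "ggml.h"

// Hedged sketch, not from the commit: build a 1D transposed convolution so that
// the Metal backend ends up dispatching the new conv_transpose_1d kernel.
static struct ggml_tensor * build_conv_transpose_1d(struct ggml_context * ctx) {
    const int K  = 4;   // kernel width    -> src0->ne[0] in the encoder
    const int OC = 8;   // output channels
    const int IC = 16;  // input channels  -> src1->ne[1]
    const int IL = 32;  // input length    -> src1->ne[0]
    const int s0 = 2;   // stride, read back from dst->op_params[0] by the encoder

    struct ggml_tensor * kernel = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, K, OC, IC); // or GGML_TYPE_F16
    struct ggml_tensor * input  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, IL, IC);

    return ggml_conv_transpose_1d(ctx, kernel, input, s0, /*p0=*/0, /*d0=*/1);
}

When the resulting graph is evaluated on the Metal backend, the encoder above launches an (OL, OC, 1) grid of single-thread threadgroups, i.e. one threadgroup per output element.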

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 73 additions & 0 deletions

@@ -2671,6 +2671,79 @@ kernel void kernel_im2col_ext(
 template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
 template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;

+typedef void (conv_transpose_1d_t)(
+        device const float * src0,
+        device const float * src1,
+        device char * dst,
+        constant int32_t & IC,
+        constant int32_t & IL,
+        constant int32_t & K,
+        constant int32_t & s0,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_1d(
+        device const T * src0,
+        device const float * src1,
+        device char * dst,
+        constant int32_t & IC,
+        constant int32_t & IL,
+        constant int32_t & K,
+        constant int32_t & s0,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]]) {
+
+    float v = 0.0f;
+
+    for (int64_t c = 0; c < IC; c++) {
+        const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
+        const int32_t input_offset = c * IL;
+
+        for (int64_t i = 0; i < IL; i++) {
+            if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
+                v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+            }
+        }
+    }
+
+    device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+
+    dst_ptr[0] = v;
+}
+
+template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
+kernel void kernel_conv_transpose_1d<float>(
+        device const float * src0,
+        device const float * src1,
+        device char * dst,
+        constant int32_t & IC,
+        constant int32_t & IL,
+        constant int32_t & K,
+        constant int32_t & s0,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]]);
+
+template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
+kernel void kernel_conv_transpose_1d<half>(
+        device const half * src0,
+        device const float * src1,
+        device char * dst,
+        constant int32_t & IC,
+        constant int32_t & IL,
+        constant int32_t & K,
+        constant int32_t & s0,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tgpg[[threadgroups_per_grid]]);
+
 kernel void kernel_upscale_f32(
         device const char * src0,
         device char * dst,
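
To make the kernel's index arithmetic easier to follow, here is a hedged plain-C reference of what a single (tgpig[0], tgpig[1]) = (output position, output channel) threadgroup accumulates. The flat layouts in the comments are my reading of the offsets used above, assuming contiguous f32 tensors; they are not stated explicitly in the commit.

// Hedged CPU reference of kernel_conv_transpose_1d (contiguous f32 data assumed).
// Inferred layouts:
//   src0 (kernel): idx = c*OC*K + oc*K + k   (ne = [K, OC, IC])
//   src1 (input) : idx = c*IL + i            (ne = [IL, IC])
//   dst  (output): idx = oc*OL + ol          (ne = [OL, OC])
static void conv_transpose_1d_ref(const float * src0, const float * src1, float * dst,
                                  int IC, int IL, int K, int s0, int OL, int OC) {
    for (int oc = 0; oc < OC; oc++) {        // plays the role of tgpig[1]
        for (int ol = 0; ol < OL; ol++) {    // plays the role of tgpig[0]
            float v = 0.0f;
            for (int c = 0; c < IC; c++) {
                for (int i = 0; i < IL; i++) {
                    // input sample i contributes to output positions [i*s0, i*s0 + K)
                    if (ol >= i*s0 && ol < i*s0 + K) {
                        v += src0[c*OC*K + oc*K + (ol - i*s0)] * src1[c*IL + i];
                    }
                }
            }
            dst[oc*OL + ol] = v;
        }
    }
}

Each output element is computed by a single thread doing O(IC * IL) work; the loop scans every input position rather than only the few that can overlap a given output, which keeps this initial version simple.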
