@@ -306,6 +306,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
306
306
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
307
307
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
308
308
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
309
+ GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
310
+ GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
309
311
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
310
312
GGML_METAL_KERNEL_TYPE_PAD_F32,
311
313
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -870,6 +872,8 @@ @implementation GGMLMetalClass
870
872
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true );
871
873
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true );
872
874
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true );
875
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true );
876
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true );
873
877
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true );
874
878
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true );
875
879
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true );
@@ -1069,6 +1073,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1069
1073
case GGML_OP_REPEAT:
1070
1074
case GGML_OP_SCALE:
1071
1075
case GGML_OP_CLAMP:
1076
+ case GGML_OP_CONV_TRANSPOSE_1D:
1072
1077
return true ;
1073
1078
case GGML_OP_SQR:
1074
1079
case GGML_OP_SQRT:
@@ -3138,6 +3143,49 @@ static void ggml_metal_encode_node(
3138
3143
[encoder dispatchThreadgroups: MTLSizeMake (IC, OH, OW) threadsPerThreadgroup: MTLSizeMake (N, KH, KW)];
3139
3144
}
3140
3145
} break ;
3146
+ case GGML_OP_CONV_TRANSPOSE_1D:
3147
+ {
3148
+ GGML_ASSERT (ggml_is_contiguous (src0));
3149
+ GGML_ASSERT (ggml_is_contiguous (src1));
3150
+ GGML_ASSERT (src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
3151
+ GGML_ASSERT (src1->type == GGML_TYPE_F32);
3152
+ GGML_ASSERT ( dst->type == GGML_TYPE_F32);
3153
+
3154
+ const int32_t s0 = ((const int32_t *)(dst->op_params ))[0 ];
3155
+
3156
+ const int32_t IC = src1->ne [1 ];
3157
+ const int32_t IL = src1->ne [0 ];
3158
+
3159
+ const int32_t K = src0->ne [0 ];
3160
+
3161
+ const int32_t OL = dst->ne [0 ];
3162
+ const int32_t OC = dst->ne [1 ];
3163
+
3164
+ id <MTLComputePipelineState > pipeline;
3165
+
3166
+ switch (src0->type ) {
3167
+ case GGML_TYPE_F32: {
3168
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline ;
3169
+ } break ;
3170
+ case GGML_TYPE_F16: {
3171
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline ;
3172
+ } break ;
3173
+ default : GGML_ABORT (" fatal error" );
3174
+ };
3175
+
3176
+ [encoder setComputePipelineState: pipeline];
3177
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3178
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3179
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
3180
+ [encoder setBytes: &IC length: sizeof ( int32_t ) atIndex: 3 ];
3181
+ [encoder setBytes: &IL length: sizeof ( int32_t ) atIndex: 4 ];
3182
+ [encoder setBytes: &K length: sizeof ( int32_t ) atIndex: 5 ];
3183
+ [encoder setBytes: &s0 length: sizeof ( int32_t ) atIndex: 6 ];
3184
+ [encoder setBytes: &nb0 length: sizeof (uint64_t ) atIndex: 7 ];
3185
+ [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 8 ];
3186
+
3187
+ [encoder dispatchThreadgroups: MTLSizeMake (OL, OC, 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
3188
+ } break ;
3141
3189
case GGML_OP_UPSCALE:
3142
3190
{
3143
3191
GGML_ASSERT (src0->type == GGML_TYPE_F32);
0 commit comments