@@ -134,6 +134,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
134
134
GGML_METAL_KERNEL_TYPE_SCALE_4,
135
135
GGML_METAL_KERNEL_TYPE_CLAMP,
136
136
GGML_METAL_KERNEL_TYPE_TANH,
137
+ GGML_METAL_KERNEL_TYPE_EXP,
138
+ GGML_METAL_KERNEL_TYPE_NEG,
137
139
GGML_METAL_KERNEL_TYPE_RELU,
138
140
GGML_METAL_KERNEL_TYPE_SIGMOID,
139
141
GGML_METAL_KERNEL_TYPE_GELU,
@@ -741,6 +743,8 @@ @implementation GGMLMetalClass
741
743
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true );
742
744
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true );
743
745
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TANH, tanh, true );
746
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_EXP, exp, true );
747
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NEG, neg, true );
744
748
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RELU, relu, true );
745
749
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true );
746
750
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU, gelu, true );
@@ -1180,6 +1184,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1180
1184
case GGML_UNARY_OP_GELU_QUICK:
1181
1185
case GGML_UNARY_OP_SILU:
1182
1186
case GGML_UNARY_OP_ELU:
1187
+ case GGML_UNARY_OP_EXP:
1188
+ case GGML_UNARY_OP_NEG:
1183
1189
return ggml_is_contiguous (op->src [0 ]);
1184
1190
default :
1185
1191
return false ;
@@ -1747,6 +1753,30 @@ static void ggml_metal_encode_node(
1747
1753
1748
1754
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1749
1755
} break ;
1756
+ case GGML_UNARY_OP_EXP: // encode the EXP unary op; ggml_metal_supports_op only accepts it when op->src[0] is contiguous
1757
+ {
1758
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_EXP].pipeline ; // pipeline registered for the "exp" kernel
1759
+
1760
+ [encoder setComputePipelineState: pipeline];
1761
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ]; // src0 buffer bound at kernel buffer index 0
1762
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ]; // dst buffer bound at kernel buffer index 1
1763
+
1764
+ const int64_t n = ggml_nelements (dst); // presumably the total element count of dst - confirm against ggml_nelements
1765
+
1766
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )]; // n threadgroups of one thread each, same dispatch shape as the preceding case
1767
+ } break ;
1768
+ case GGML_UNARY_OP_NEG: // encode the NEG unary op; ggml_metal_supports_op only accepts it when op->src[0] is contiguous
1769
+ {
1770
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_NEG].pipeline ; // pipeline registered for the "neg" kernel
1771
+
1772
+ [encoder setComputePipelineState: pipeline];
1773
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ]; // src0 buffer bound at kernel buffer index 0
1774
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ]; // dst buffer bound at kernel buffer index 1
1775
+
1776
+ const int64_t n = ggml_nelements (dst); // presumably the total element count of dst - confirm against ggml_nelements
1777
+
1778
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )]; // n threadgroups of one thread each, same dispatch shape as the EXP case
1779
+ } break ;
1750
1780
case GGML_UNARY_OP_RELU:
1751
1781
{
1752
1782
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RELU].pipeline ;
0 commit comments