@@ -127,6 +127,8 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
127
127
GGML_METAL_KERNEL_TYPE_SCALE_4,
128
128
GGML_METAL_KERNEL_TYPE_CLAMP,
129
129
GGML_METAL_KERNEL_TYPE_TANH,
130
+ GGML_METAL_KERNEL_TYPE_EXP,
131
+ GGML_METAL_KERNEL_TYPE_NEG,
130
132
GGML_METAL_KERNEL_TYPE_RELU,
131
133
GGML_METAL_KERNEL_TYPE_SIGMOID,
132
134
GGML_METAL_KERNEL_TYPE_GELU,
@@ -734,6 +736,8 @@ @implementation GGMLMetalClass
734
736
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true );
735
737
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_CLAMP, clamp, true );
736
738
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_TANH, tanh, true );
739
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_EXP, exp, true );
740
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_NEG, neg, true );
737
741
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RELU, relu, true );
738
742
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true );
739
743
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GELU, gelu, true );
@@ -1173,6 +1177,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1173
1177
case GGML_UNARY_OP_GELU_QUICK:
1174
1178
case GGML_UNARY_OP_SILU:
1175
1179
case GGML_UNARY_OP_ELU:
1180
+ case GGML_UNARY_OP_EXP:
1181
+ case GGML_UNARY_OP_NEG:
1176
1182
return ggml_is_contiguous (op->src [0 ]);
1177
1183
default :
1178
1184
return false ;
@@ -1739,6 +1745,30 @@ static void ggml_metal_encode_node(
1739
1745
1740
1746
[encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1741
1747
} break ;
1748
+ case GGML_UNARY_OP_EXP:
1749
+ {
1750
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_EXP].pipeline ;
1751
+
1752
+ [encoder setComputePipelineState: pipeline];
1753
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1754
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1755
+
1756
+ const int64_t n = ggml_nelements (dst);
1757
+
1758
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1759
+ } break ;
1760
+ case GGML_UNARY_OP_NEG:
1761
+ {
1762
+ id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_NEG].pipeline ;
1763
+
1764
+ [encoder setComputePipelineState: pipeline];
1765
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1766
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1767
+
1768
+ const int64_t n = ggml_nelements (dst);
1769
+
1770
+ [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
1771
+ } break ;
1742
1772
case GGML_UNARY_OP_RELU:
1743
1773
{
1744
1774
id <MTLComputePipelineState > pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_RELU].pipeline ;
0 commit comments