@@ -2712,9 +2712,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_ATTN",
     "FLASH_FF",
+
+    "MAP_UNARY",
+    "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2757,9 +2760,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_attn(x)",
     "flash_ff(x)",
+
+    "f(x)",
+    "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -4907,6 +4913,90 @@ struct ggml_tensor * ggml_flash_ff(
     return result;
 }
 
+// ggml_map_unary
+
+struct ggml_tensor * ggml_map_unary_impl_f32(
+        struct ggml_context        * ctx,
+        struct ggml_tensor         * a,
+        const  ggml_unary_op_f32_t fun,
+        bool   inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MAP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_unary_f32(
+        struct ggml_context        * ctx,
+        struct ggml_tensor         * a,
+        const  ggml_unary_op_f32_t fun) {
+    return ggml_map_unary_impl_f32(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace_f32(
+        struct ggml_context        * ctx,
+        struct ggml_tensor         * a,
+        const  ggml_unary_op_f32_t fun) {
+    return ggml_map_unary_impl_f32(ctx, a, fun, true);
+}
+
+// ggml_map_binary
+
+struct ggml_tensor * ggml_map_binary_impl_f32(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b,
+        const  ggml_binary_op_f32_t fun,
+        bool   inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void (**)(void))addr_tensor->data) = (void (*)(void))fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MAP_BINARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_binary_f32(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b,
+        const  ggml_binary_op_f32_t fun) {
+    return ggml_map_binary_impl_f32(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_binary_inplace_f32(
+        struct ggml_context         * ctx,
+        struct ggml_tensor          * a,
+        struct ggml_tensor          * b,
+        const  ggml_binary_op_f32_t fun) {
+    return ggml_map_binary_impl_f32(ctx, a, b, fun, true);
+}
+
 ////////////////////////////////////////////////////////////////////////////////
 
 void ggml_set_param(
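These constructors are the whole public surface of the feature: user code passes a plain C function and gets back a graph node, with the function pointer stashed in a small GGML_TYPE_I32 tensor hung off opt[0]. A minimal usage sketch follows; the callback name double_row is invented here, and the ggml_init/ggml_build_forward/ggml_graph_compute signatures are assumed to match the ggml API of the same vintage as this diff:

#include "ggml.h"

// Hypothetical callback -- anything matching ggml_unary_op_f32_t works.
// The forward pass (further down) calls it once per row: fun(nc, dst_row, src_row).
static void double_row(const int n, float * dst, const float * src) {
    for (int i = 0; i < n; i++) {
        dst[i] = 2.0f*src[i];
    }
}

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(a, 3.0f);

    // lift the plain function into the graph, then run the forward pass
    struct ggml_tensor * y = ggml_map_unary_f32(ctx, a, double_row);

    struct ggml_cgraph gf = ggml_build_forward(y);
    ggml_graph_compute(ctx, &gf);
    // y should now hold 6.0f in every element

    ggml_free(ctx);
    return 0;
}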
@@ -8875,6 +8965,111 @@ static void ggml_compute_forward_flash_ff(
     }
 }
 
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *)  dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        const ggml_unary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *)  dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        const ggml_binary_op_f32_t fun) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 /////////////////////////////////
 
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
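Both kernels above iterate over rows, not elements: the stored callback is invoked once per row with the row length nc and raw float pointers, and the nb[0] asserts pin rows to contiguous f32 data, so a conforming callback only ever sees flat arrays. A sketch of a binary callback that would slot into ggml_compute_forward_map_binary_f32 (the name row_max is illustrative, not part of ggml):

// Matches ggml_binary_op_f32_t as called above: fun(nc, dst_row, src0_row, src1_row).
static void row_max(const int n, float * dst, const float * a, const float * b) {
    for (int i = 0; i < n; i++) {
        dst[i] = a[i] > b[i] ? a[i] : b[i];
    }
}

// usage: struct ggml_tensor * m = ggml_map_binary_f32(ctx, a, b, row_max);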
@@ -9024,6 +9219,18 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_MAP_UNARY:
+            {
+                const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_BINARY:
+            {
+                const ggml_binary_op_f32_t fun = *((ggml_binary_op_f32_t *)tensor->opt[0]->data);
+                ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
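The dispatch above recovers the callback by re-reading the bytes stashed in opt[0] at construction time; this works because the I32 tensor was sized as sizeof(void *) / sizeof(int32_t) elements, exactly one function pointer wide. A standalone illustration of the round trip (plain C, no ggml dependencies; the cast mirrors the diff, though memcpy would be the strictly portable spelling):

#include <assert.h>
#include <stdint.h>

typedef void (*fn_t)(int);

static void noop(int x) { (void) x; }

int main(void) {
    // same layout as the opt[0] tensor's data: int32_t storage wide enough
    // to hold one function pointer
    int32_t buf[sizeof(void (*)(void)) / sizeof(int32_t)];

    *((fn_t *) buf) = noop;    // store, as in ggml_map_unary_impl_f32
    fn_t f = *((fn_t *) buf);  // load, as in the GGML_OP_MAP_UNARY case above
    assert(f == noop);
    f(0);
    return 0;
}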
@@ -9283,6 +9490,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_MAP_UNARY:
+        case GGML_OP_MAP_BINARY:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -9775,6 +9987,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 
                         work_size = MAX(work_size, cur);
                     } break;
+                case GGML_OP_MAP_UNARY:
+                case GGML_OP_MAP_BINARY:
+                    {
+                        node->n_tasks = 1;
+                    } break;
                 case GGML_OP_NONE:
                     {
                         node->n_tasks = 1;