@@ -2712,9 +2712,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "FLASH_ATTN",
     "FLASH_FF",
+
+    "MAP_UNARY",
+    "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2757,9 +2760,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "flash_attn(x)",
     "flash_ff(x)",
+
+    "f(x)",
+    "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 36, "GGML_OP_COUNT != 36");
+static_assert(GGML_OP_COUNT == 38, "GGML_OP_COUNT != 38");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3671,6 +3677,92 @@ struct ggml_tensor * ggml_dup_inplace(
     return ggml_dup_impl(ctx, a, true);
 }
 
+
+// ggml_map_binary
+
+struct ggml_tensor * ggml_map_binary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *),
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MAP_BINARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_binary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_binary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, true);
+}
+
+// ggml_map_unary
+
+struct ggml_tensor * ggml_map_unary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *),
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MAP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, true);
+}
+
+
 // ggml_add
 
 struct ggml_tensor * ggml_add_impl(
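Note: the constructors above only record the callback in opt[0]; nothing runs until graph compute. A minimal usage sketch (not part of this commit; the callback name, buffer size, and fill value are hypothetical, and it assumes the ggml API of this era, with ggml_build_forward returning the graph by value):

    #include <stdio.h>
    #include "ggml.h"

    // hypothetical row callback matching void (*)(int, float *, float *):
    // writes dst[i] = 2*src[i] for one row of n floats
    static void double_row(int n, float * dst, float * src) {
        for (int i = 0; i < n; i++) {
            dst[i] = 2.0f*src[i];
        }
    }

    int main(void) {
        struct ggml_init_params ip = { .mem_size = 16*1024*1024, .mem_buffer = NULL };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        ggml_set_f32(a, 3.0f);

        // y = f(a); the function pointer travels through y->opt[0]
        struct ggml_tensor * y = ggml_map_unary(ctx, a, double_row);

        struct ggml_cgraph gf = ggml_build_forward(y);
        ggml_graph_compute(ctx, &gf);

        printf("y[0] = %f\n", ggml_get_f32_1d(y, 0)); // expect 6.0

        ggml_free(ctx);
        return 0;
    }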
@@ -5329,6 +5421,111 @@ static void ggml_compute_forward_dup(
     }
 }
 
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_add
 
 static void ggml_compute_forward_add_f32(
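Note: the loops above define the callback contract. fun is invoked once per row with nc = src0->ne[0] elements, rows must be contiguous in the innermost dimension (nb[0] == sizeof(float)), and only F32 tensors are accepted; ggml_compute_forward_map_binary_f32 additionally asserts params->ith == 0, i.e. it expects to be scheduled on a single thread. A conforming binary callback might look like this (a hypothetical sketch, not part of the commit):

    // element-wise addition over one row of n floats, matching
    // void (*)(int, float *, float *, float *)
    static void add_rows(int n, float * dst, float * src0, float * src1) {
        for (int i = 0; i < n; i++) {
            dst[i] = src0[i] + src1[i];
        }
    }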
@@ -8877,7 +9074,19 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_dup(params, tensor->src0, tensor);
             } break;
-        case GGML_OP_ADD:
+        case GGML_OP_MAP_UNARY:
+            {
+                void (*const fun)(int, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_BINARY:
+            {
+                void (*const fun)(int, float *, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+            }
+            break;
+        case GGML_OP_ADD:
             {
                 ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
             } break;
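Design note on the dispatch above: the user callback round-trips through tensor memory. ggml_map_*_impl writes the function pointer into the data buffer of a small GGML_TYPE_I32 tensor and parks that tensor in opt[0]; the switch here reads it back with the mirror-image cast. Standard C does not guarantee that a function pointer survives conversion through void * (POSIX does), so a guard one might add (a sketch, not in the commit):

    // assumption: function pointers and object pointers share a size and
    // representation, as the opt[0] storage trick requires
    static_assert(sizeof(void (*)(int, float *, float *)) == sizeof(void *),
            "map ops store function pointers in tensor data");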