@@ -1951,7 +1951,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
 
-    /* Prepare the constants we will need during execution */
+    /* Prepare the constants we will need during execution */
     const __m256i lowMask = _mm256_set1_epi8( 0xF );
     const __m256i offset_8 = _mm256_set1_epi16( 8 );
 
@@ -1962,60 +1962,60 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     // Main loop
     for (int i = 0; i < nb; i += UNROLL_COUNT) {
 
-        // This loop will be unrolled by the compiler
+        // This loop will be unrolled by the compiler
         for (int u = 0; u < UNROLL_COUNT; u++) {
-            /* Compute combined scale for the block */
-            const __m256 scale = _mm256_mul_ps(
-                    _mm256_broadcast_ss( &x[i+u].d ),
-                    _mm256_broadcast_ss( &y[i+u].d ) );
-
-            /* get input from x
-               Input: 32 Nibbles (16 bytes) at *x[i+u]
-               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_x = _mm_loadu_si128( (const __m128i *) x[i+u].qs );
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_x = _mm256_cvtepu8_epi16( tmp_x );
+            /* Compute combined scale for the block */
+            const __m256 scale = _mm256_mul_ps(
+                    _mm256_broadcast_ss( &x[i+u].d ),
+                    _mm256_broadcast_ss( &y[i+u].d ) );
+
+            /* get input from x
+               Input: 32 Nibbles (16 bytes) at *x[i+u]
+               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
+
+            /* Load 16 bytes from memory */
+            const __m128i tmp_x = _mm_loadu_si128( (const __m128i *) x[i+u].qs );
+            /* Expand bytes into uint16_t values */
+            const __m256i bytes_x = _mm256_cvtepu8_epi16( tmp_x );
             /* Unpack values into individual bytes */
             __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
             const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
-            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
+            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
             /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
-            x_low_q  = _mm256_sub_epi16( x_low_q,  offset_8 );
+            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
+            x_low_q  = _mm256_sub_epi16( x_low_q,  offset_8 );
 
-            /* get input from y
-               Input: 32 Nibbles (16 bytes) at *y[i+u]
-               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
+            /* get input from y
+               Input: 32 Nibbles (16 bytes) at *y[i+u]
+               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
 
-            /* Load 16 bytes from memory */
-            const __m128i tmp_y = _mm_loadu_si128( (const __m128i *) y[i+u].qs );
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_y = _mm256_cvtepu8_epi16( tmp_y );
+            /* Load 16 bytes from memory */
+            const __m128i tmp_y = _mm_loadu_si128( (const __m128i *) y[i+u].qs );
+            /* Expand bytes into uint16_t values */
+            const __m256i bytes_y = _mm256_cvtepu8_epi16( tmp_y );
             /* Unpack values into individual bytes */
-            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
-            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
-            __m256i y_low_q  = _mm256_and_si256( lowMask, bytes_y );
+            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
+            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
+            __m256i y_low_q  = _mm256_and_si256( lowMask, bytes_y );
             /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
-            y_low_q  = _mm256_sub_epi16( y_low_q,  offset_8 );
+            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
+            y_low_q  = _mm256_sub_epi16( y_low_q,  offset_8 );
 
-            /* Compute products of int16_t integers, add pairwise, store as int32_t */
-            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
-            __m256i xy_low_q  = _mm256_madd_epi16( x_low_q,  y_low_q );
+            /* Compute products of int16_t integers, add pairwise, store as int32_t */
+            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
+            __m256i xy_low_q  = _mm256_madd_epi16( x_low_q,  y_low_q );
 
-            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
-            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
+            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
+            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
 
-            /* Convert to vectore of 8 int32_t to 8 floats */
-            __m256 q = _mm256_cvtepi32_ps( xy_q );
+            /* Convert to vectore of 8 int32_t to 8 floats */
+            __m256 q = _mm256_cvtepi32_ps( xy_q );
 
-            /* Multiply q with scale and accumulate */
-            acc = _mm256_fmadd_ps( scale, q, acc );
+            /* Multiply q with scale and accumulate */
+            acc = _mm256_fmadd_ps( scale, q, acc );
         }
-
-    }
+
+    }
 
     // Return horizontal sum of the acc vector
     __m128 res = _mm256_extractf128_ps( acc, 1 );
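The two hunks above appear to touch only whitespace and comments; the computation itself is unchanged. For reference, here is a scalar sketch of what one iteration of the unrolled loop computes for a pair of q4_0 blocks. It is illustrative only (not part of the patch) and assumes the block layout implied by the code: a float scale d plus 16 bytes qs packing 32 4-bit quants.

// Illustrative sketch, not part of the patch: scalar equivalent of one
// unrolled iteration above. Unpack each byte into its low and high nibble,
// shift both from [0..15] into [-8..+7], accumulate the integer dot product,
// then scale by the product of the two block scales.
#include <stdint.h>

static float q4_0_block_dot_ref(float dx, const uint8_t * qsx,
                                float dy, const uint8_t * qsy) {
    int sum = 0;
    for (int j = 0; j < 16; j++) {               // 16 bytes -> 32 nibbles
        const int x0 = (qsx[j] & 0x0F) - 8;      // low  nibble of x
        const int x1 = (qsx[j] >>   4) - 8;      // high nibble of x
        const int y0 = (qsy[j] & 0x0F) - 8;      // low  nibble of y
        const int y1 = (qsy[j] >>   4) - 8;      // high nibble of y
        sum += x0*y0 + x1*y1;
    }
    return dx*dy*(float)sum;                     // combined block scale
}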
@@ -2631,9 +2631,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
     "FLASH_ATTN",
     "FLASH_FF",
+
+    "MAP_UNARY",
+    "MAP_BINARY",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 37, "GGML_OP_COUNT != 37");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2675,9 +2678,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "flash_attn(x)",
     "flash_ff(x)",
+
+    "f(x)",
+    "f(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 37, "GGML_OP_COUNT != 37");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
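The label and symbol tables must stay in sync with the ggml_op enum, whose new members are referenced later in this diff but added outside the hunks shown here (presumably in ggml.h). A sketch of what that header change would look like, not the verbatim change:

// Assumed ggml.h change (not shown in this diff): the new ops slot in after
// the flash ops so the tables above and the GGML_OP_COUNT == 37 asserts line up.
enum ggml_op {
    // ... existing ops up to ...
    GGML_OP_FLASH_ATTN,
    GGML_OP_FLASH_FF,

    GGML_OP_MAP_UNARY,
    GGML_OP_MAP_BINARY,

    GGML_OP_COUNT,
};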
@@ -3589,6 +3595,92 @@ struct ggml_tensor * ggml_dup_inplace(
     return ggml_dup_impl(ctx, a, true);
 }
 
+
+// ggml_map_binary
+
+struct ggml_tensor * ggml_map_binary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *),
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MAP_BINARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_binary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_binary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, true);
+}
+
+// ggml_map_unary
+
+struct ggml_tensor * ggml_map_unary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *),
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_MAP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, true);
+}
+
+
 // ggml_add
 
 struct ggml_tensor * ggml_add_impl(
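Taken together, these constructors let a caller apply an arbitrary row-wise float function without defining a new built-in op. A minimal usage sketch follows; it is not from the patch, it assumes the ggml API of this era (ggml_init, ggml_build_forward, ggml_graph_compute, the n_threads field on ggml_cgraph), and the callback name scale_by_two is made up for the example.

// Minimal sketch (assumptions noted above), not part of the patch.
#include <stdio.h>
#include "ggml.h"

static void scale_by_two(int n, float * dst, float * src) {
    // invoked once per row with n = row length (ne[0])
    for (int i = 0; i < n; i++) {
        dst[i] = 2.0f*src[i];
    }
}

int main(void) {
    struct ggml_init_params params = { .mem_size = 16*1024*1024, .mem_buffer = NULL };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    float * ad = (float *) a->data;
    for (int i = 0; i < 4; i++) ad[i] = (float) i;

    struct ggml_tensor * out = ggml_map_unary(ctx, a, scale_by_two);

    struct ggml_cgraph gf = ggml_build_forward(out);
    gf.n_threads = 1;
    ggml_graph_compute(ctx, &gf);

    for (int i = 0; i < 4; i++) printf("%f\n", ((float *) out->data)[i]);

    ggml_free(ctx);
    return 0;
}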
@@ -5034,6 +5126,111 @@ static void ggml_compute_forward_dup(
     }
 }
 
+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_add
 
 static void ggml_compute_forward_add_f32(
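As the loops above show, the user callback is invoked once per row with nc = ne[0] elements and raw float pointers into dst, src0, and src1; only F32 tensors are accepted, and every other type hits GGML_ASSERT(false). A callback suitable for ggml_map_binary therefore looks like the sketch below (illustrative only; the name is made up).

// Illustrative callback for ggml_map_binary, not part of the patch:
// called once per row, nc elements per row.
static void elementwise_mul(int nc, float * dst, float * src0, float * src1) {
    for (int i = 0; i < nc; i++) {
        dst[i] = src0[i]*src1[i];
    }
}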
@@ -8567,7 +8764,19 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_dup(params, tensor->src0, tensor);
             } break;
-        case GGML_OP_ADD:
+        case GGML_OP_MAP_UNARY:
+            {
+                void (*const fun)(int, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_BINARY:
+            {
+                void (*const fun)(int, float *, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+            }
+            break;
+        case GGML_OP_ADD:
             {
                 ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
             } break;
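The dispatch above recovers the callback that ggml_map_*_impl stashed away: the pointer is written into the data of a small GGML_TYPE_I32 tensor (opt[0]) sized sizeof(void *)/sizeof(int32_t) elements, then read back here at compute time. The standalone sketch below reproduces that round trip outside ggml; note that converting between function and object pointers is not strictly portable ISO C, although it works on the compilers and platforms ggml targets.

// Illustrative only: the function-pointer round trip used by the patch,
// with a plain int32_t buffer standing in for addr_tensor->data.
#include <assert.h>
#include <stdint.h>

typedef void (*map_unary_fn)(int, float *, float *);

static void noop(int n, float * dst, float * src) { (void)n; (void)dst; (void)src; }

int main(void) {
    _Alignas(void *) int32_t storage[sizeof(void *) / sizeof(int32_t)];

    *((void **) storage) = (void *) noop;                     // what ggml_map_unary_impl does
    map_unary_fn fun = (map_unary_fn) *((void **) storage);   // what ggml_compute_forward does

    assert(fun == noop);
    fun(0, 0, 0);  // safe: noop ignores its arguments
    return 0;
}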