
Commit 01ba38d

GGML map ops proof of concept.
1 parent 180b693 commit 01ba38d

File tree

2 files changed: +267 / -44 lines changed


ggml.c

Lines changed: 253 additions & 44 deletions
(The two hunks below are a whitespace-only re-indentation of the AVX2 Q4_0 dot-product block; leading whitespace was not preserved in this view, so the removed and re-added lines read identically.)

@@ -1951,7 +1951,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();

-    /* Prepare the constants we will need during execution */
+    /* Prepare the constants we will need during execution */
     const __m256i lowMask = _mm256_set1_epi8( 0xF );
     const __m256i offset_8 = _mm256_set1_epi16( 8 );

@@ -1962,60 +1962,60 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     // Main loop
     for (int i = 0; i < nb; i+=UNROLL_COUNT) {

-        // This loop will be unrolled by the compiler
+        // This loop will be unrolled by the compiler
        for (int u=0;u<UNROLL_COUNT;u++) {
-            /* Compute combined scale for the block */
-            const __m256 scale = _mm256_mul_ps(
-                _mm256_broadcast_ss( &x[i+u].d ),
-                _mm256_broadcast_ss( &y[i+u].d ) );
-
-            /* get input from x
-               Input: 32 Nibbles (16 bytes) at *x[i+u]
-               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
+            /* Compute combined scale for the block */
+            const __m256 scale = _mm256_mul_ps(
+                _mm256_broadcast_ss( &x[i+u].d ),
+                _mm256_broadcast_ss( &y[i+u].d ) );
+
+            /* get input from x
+               Input: 32 Nibbles (16 bytes) at *x[i+u]
+               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
+
+            /* Load 16 bytes from memory */
+            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
+            /* Expand bytes into uint16_t values */
+            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
             /* Unpack values into individual bytes */
             __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
             const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
-            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
+            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
             /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
-            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
+            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
+            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );

-            /* get input from y
-               Input: 32 Nibbles (16 bytes) at *y[i+u]
-               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
+            /* get input from y
+               Input: 32 Nibbles (16 bytes) at *y[i+u]
+               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */

-            /* Load 16 bytes from memory */
-            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
+            /* Load 16 bytes from memory */
+            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
+            /* Expand bytes into uint16_t values */
+            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
             /* Unpack values into individual bytes */
-            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
-            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
-            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
+            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
+            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
+            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
             /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
-            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
-            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
+            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
+            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );

-            /* Compute products of int16_t integers, add pairwise, store as int32_t */
-            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
-            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
+            /* Compute products of int16_t integers, add pairwise, store as int32_t */
+            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
+            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );

-            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
-            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
+            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
+            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );

-            /* Convert to vectore of 8 int32_t to 8 floats */
-            __m256 q = _mm256_cvtepi32_ps( xy_q );
+            /* Convert to vectore of 8 int32_t to 8 floats */
+            __m256 q = _mm256_cvtepi32_ps( xy_q );

-            /* Multiply q with scale and accumulate */
-            acc = _mm256_fmadd_ps( scale, q, acc );
+            /* Multiply q with scale and accumulate */
+            acc = _mm256_fmadd_ps( scale, q, acc );
         }
-
-    }
+
+    }

     // Return horizontal sum of the acc vector
     __m128 res = _mm256_extractf128_ps( acc, 1 );
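Although the hunk above only re-indents, it is the heart of the Q4_0 dot product: each byte holds two 4-bit quants that are masked, shifted, and re-centered around zero (the lowMask / srli / offset_8 sequence). A minimal scalar sketch of that unpack-and-offset step, for reference only; the helper name is hypothetical and not part of this diff:

#include <stdint.h>

// Scalar equivalent of the AVX2 unpack step above: split one byte into its
// two 4-bit quants and shift them from [0..15] into [-8..7].
static inline void unpack_q4_byte(uint8_t b, int8_t * lo, int8_t * hi) {
    *lo = (int8_t)(b & 0x0F) - 8; // low nibble, re-centered
    *hi = (int8_t)(b >> 4)   - 8; // high nibble, re-centered
}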
@@ -2631,9 +2631,12 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {

     "FLASH_ATTN",
     "FLASH_FF",
+
+    "MAP_UNARY",
+    "MAP_BINARY",
 };

-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 37, "GGML_OP_COUNT != 37");

 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2675,9 +2678,12 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {

     "flash_attn(x)",
     "flash_ff(x)",
+
+    "f(x)",
+    "f(x,y)",
 };

-static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
+static_assert(GGML_OP_COUNT == 37, "GGML_OP_COUNT != 37");

 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@@ -3589,6 +3595,92 @@ struct ggml_tensor * ggml_dup_inplace(
     return ggml_dup_impl(ctx, a, true);
 }

+
+// ggml_map_binary
+
+struct ggml_tensor * ggml_map_binary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *),
+        bool inplace) {
+    GGML_ASSERT(ggml_are_same_shape(a, b));
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad || b->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MAP_BINARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->src1 = b;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_binary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, false);
+}
+
+struct ggml_tensor * ggml_map_binary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        struct ggml_tensor * b,
+        void (*const fun)(int, float *, float *, float *)) {
+    return ggml_map_binary_impl(ctx, a, b, fun, true);
+}
+
+// ggml_map_unary
+
+struct ggml_tensor * ggml_map_unary_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *),
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && a->grad) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * addr_tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, sizeof(void *) / sizeof(int32_t));
+    *((void **)addr_tensor->data) = fun;
+    struct ggml_tensor *result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op = GGML_OP_MAP_UNARY;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = a;
+    result->opt[0] = addr_tensor;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_map_unary(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, false);
+}
+
+struct ggml_tensor * ggml_map_unary_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor * a,
+        void (*const fun)(int, float *, float *)) {
+    return ggml_map_unary_impl(ctx, a, fun, true);
+}
+
+
 // ggml_add

 struct ggml_tensor * ggml_add_impl(
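A note on the design in the hunk above: a graph node can only reference other tensors, so the callback's address is stashed in a small GGML_TYPE_I32 tensor sized to hold a void *, attached as result->opt[0], and read back in ggml_compute_forward (see the dispatch hunk further down). The callback itself is an ordinary C function matching the declared pointer type; a minimal sketch for the unary case, with a hypothetical name that is not part of this commit:

// Hypothetical unary callback for ggml_map_unary: receives one row of n
// contiguous floats at a time (see the compute kernels below).
static void example_scale_f32(int n, float * dst, float * src) {
    for (int i = 0; i < n; i++) {
        dst[i] = 0.5f * src[i];
    }
}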
@@ -5034,6 +5126,111 @@ static void ggml_compute_forward_dup(
     }
 }

+// ggml_compute_forward_map_unary
+
+static void ggml_compute_forward_map_unary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_unary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_unary_f32(params, src0, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
+// ggml_compute_forward_map_binary
+
+static void ggml_compute_forward_map_binary_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    const int n = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert( dst->nb[0] == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+    assert(src1->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        fun(nc,
+                (float *) ((char *) dst->data + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])),
+                (float *) ((char *) src1->data + i*(src1->nb[1])));
+    }
+}
+
+
+static void ggml_compute_forward_map_binary(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * src0,
+        const struct ggml_tensor * src1,
+        struct ggml_tensor * dst,
+        void (*const fun)(int, float *, float *, float *)) {
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_map_binary_f32(params, src0, src1, dst, fun);
+            } break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_I8:
+        case GGML_TYPE_I16:
+        case GGML_TYPE_I32:
+        case GGML_TYPE_F16:
+        case GGML_TYPE_COUNT:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_add

 static void ggml_compute_forward_add_f32(
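The forward kernels above hand the callback one row at a time: n = ggml_nrows(src0) rows of nc = src0->ne[0] contiguous floats each, with the destination and source pointers advanced by their nb[1] row strides. Any plain C function with the matching shape works; a sketch for the binary case, with a hypothetical name that is not part of the diff:

// Hypothetical binary callback for ggml_map_binary:
// element-wise product of two rows of n floats.
static void example_mul_f32(int n, float * dst, float * src0, float * src1) {
    for (int i = 0; i < n; i++) {
        dst[i] = src0[i] * src1[i];
    }
}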
@@ -8567,7 +8764,19 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_dup(params, tensor->src0, tensor);
             } break;
-        case GGML_OP_ADD:
+        case GGML_OP_MAP_UNARY:
+            {
+                void (*const fun)(int, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_unary(params, tensor->src0, tensor, fun);
+            }
+            break;
+        case GGML_OP_MAP_BINARY:
+            {
+                void (*const fun)(int, float *, float *, float *) = *((void **)tensor->opt[0]->data);
+                ggml_compute_forward_map_binary(params, tensor->src0, tensor->src1, tensor, fun);
+            }
+            break;
+        case GGML_OP_ADD:
             {
                 ggml_compute_forward_add(params, tensor->src0, tensor->src1, tensor);
             } break;

ggml.h

Lines changed: 14 additions & 0 deletions
@@ -250,6 +250,9 @@ enum ggml_op {
     GGML_OP_FLASH_ATTN,
     GGML_OP_FLASH_FF,

+    GGML_OP_MAP_UNARY,
+    GGML_OP_MAP_BINARY,
+
     GGML_OP_COUNT,
 };

@@ -416,6 +419,17 @@ struct ggml_tensor * ggml_dup(
         struct ggml_context * ctx,
         struct ggml_tensor * a);

+struct ggml_tensor *ggml_map_unary(
+        struct ggml_context *ctx,
+        struct ggml_tensor *a,
+        void (*const fun)(int, float *, float *));
+
+struct ggml_tensor *ggml_map_binary(
+        struct ggml_context *ctx,
+        struct ggml_tensor *a,
+        struct ggml_tensor *b,
+        void (*const fun)(int, float *, float *, float *));
+
 struct ggml_tensor * ggml_add(
         struct ggml_context * ctx,
         struct ggml_tensor * a,
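With these declarations, plugging a custom element-wise op into a graph looks roughly like the sketch below. It assumes the ggml_init / ggml_build_forward / ggml_graph_compute flow ggml used around the time of this commit; the buffer size, tensor shape, and callback name are illustrative and not part of the diff.

#include "ggml.h"

// Hypothetical user callback: squares each element of a row.
static void my_square(int n, float * dst, float * src) {
    for (int i = 0; i < n; i++) {
        dst[i] = src[i] * src[i];
    }
}

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    ggml_set_f32(a, 3.0f);

    // Build a node that applies my_square element-wise to a.
    struct ggml_tensor * y = ggml_map_unary(ctx, a, my_square);

    struct ggml_cgraph gf = ggml_build_forward(y);
    ggml_graph_compute(ctx, &gf);
    // each of y's 8 elements should now be 9.0f

    ggml_free(ctx);
    return 0;
}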
