POC: combined scale + diagonal mask infinity + soft max op #3121
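This POC fuses the scale → diag_mask_inf → soft_max sequence from the attention blocks of llm_build_llama and llm_build_falcon into a single op, GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX, backed by one Metal kernel. Per row i01 it computes dst = soft_max(diag_mask_inf(scale * src0, n_past)) in place. Only the Metal path is implemented; the CPU forward, backward and graph-plan paths assert.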

Closed · wants to merge 1 commit
21 changes: 21 additions & 0 deletions ggml-metal.m
@@ -66,6 +66,7 @@
GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
GGML_METAL_DECL_KERNEL(scale_diag_inf_soft_max);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -224,6 +225,7 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
GGML_METAL_ADD_KERNEL(scale_diag_inf_soft_max);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -294,6 +296,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
GGML_METAL_DEL_KERNEL(scale_diag_inf_soft_max);
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -817,6 +821,23 @@ void ggml_metal_graph_compute(
GGML_ASSERT(false);
}
} break;
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
{
const float scale = ((float *)(dst->op_params))[0];
const int n_past = ((int32_t *)(dst->op_params))[1];
const int nth = 32;
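// one 32-thread threadgroup per row: the kernel's simd_max/simd_sum reduce across a
// single SIMD group, so the reduction covers the whole threadgroup only while nth == 32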

[encoder setComputePipelineState:ctx->pipeline_scale_diag_inf_soft_max];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&scale length:sizeof(float) atIndex:5];
[encoder setBytes:&n_past length:sizeof(int) atIndex:6];

[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_SOFT_MAX:
{
const int nth = 32;
42 changes: 42 additions & 0 deletions ggml-metal.metal
@@ -182,6 +182,48 @@ kernel void kernel_soft_max_4(
}
}

kernel void kernel_scale_diag_inf_soft_max(
device const float * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant float & scale,
constant int & n_past,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i03 = tgpig[2];
const int64_t i02 = tgpig[1];
const int64_t i01 = tgpig[0];

device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

// parallel max
float lmax = psrc0[tpitg[0]];
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
lmax = MAX(lmax, psrc0[i00]);
}
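// the row max is taken over the unscaled values and scaled once, so the exponent
// below becomes scale*x - scale*max; this relies on scale > 0 (it is 1/sqrt(d))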
const float max = simd_max(lmax) * scale;

// parallel sum
float lsum = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
const float exp_psrc0 = i00 > n_past + i01 ? 0.f : exp(scale*psrc0[i00] - max);
lsum += exp_psrc0;
// Remember the result of exp here: exp is expensive, so we really do not
// wish to compute it twice.
pdst[i00] = exp_psrc0;
}

const float sum = simd_sum(lsum);

for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
pdst[i00] /= sum;
}
}

kernel void kernel_diag_mask_inf(
device const float * src0,
device float * dst,
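For reference, a minimal C sketch of the per-row semantics the Metal kernel above implements. This helper is not part of the PR (the CPU path in ggml.c simply asserts) and the name is illustrative only:

#include <math.h>

// One row of the fused op: dst = softmax(diag_mask_inf(scale*src, n_past)),
// where i01 is the row index and ne00 the row length.
static void scale_diag_mask_inf_softmax_row(const float * src, float * dst,
                                            int ne00, int i01, float scale, int n_past) {
    float vmax = -INFINITY;
    for (int i00 = 0; i00 < ne00; ++i00) {
        vmax = vmax > src[i00] ? vmax : src[i00];
    }
    vmax *= scale; // same trick as the kernel: scale the max once instead of every element

    float sum = 0.0f;
    for (int i00 = 0; i00 < ne00; ++i00) {
        // entries above the diagonal (i00 > n_past + i01) are -inf after masking,
        // so their exp contribution is exactly 0
        const float e = i00 > n_past + i01 ? 0.0f : expf(scale*src[i00] - vmax);
        dst[i00] = e;
        sum += e;
    }
    for (int i00 = 0; i00 < ne00; ++i00) {
        dst[i00] /= sum;
    }
}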
45 changes: 43 additions & 2 deletions ggml.c
@@ -4001,7 +4001,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};

static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -4083,7 +4083,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};

static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
static_assert(GGML_OP_COUNT == 69, "GGML_OP_COUNT != 69");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

@@ -6952,6 +6952,32 @@ struct ggml_tensor * ggml_soft_max_back_inplace(
return ggml_soft_max_back_impl(ctx, a, b, true);
}

struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
struct ggml_context * ctx,
float scale,
int n_past,
struct ggml_tensor * a) {
//bool is_node = false;

//if (a->grad) {
// is_node = true;
//}

struct ggml_tensor * result = ggml_view_tensor(ctx, a);

int32_t params[2];
memcpy(&params[0], &scale, sizeof(scale));
params[1] = n_past;
ggml_set_op_params(result, params, sizeof(params));

result->op = GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX;
//result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
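// no backward pass: ggml_compute_backward() asserts on GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX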
result->grad = NULL;
result->src[0] = a;

return result;
}

// ggml_rope

static struct ggml_tensor * ggml_rope_impl(
@@ -15993,6 +16019,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
// nop
} break;
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
{
fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
GGML_ASSERT(false);
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
@@ -16861,6 +16892,11 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{
// nop
} break;
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
{
fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
GGML_ASSERT(false);
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
@@ -17698,6 +17734,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
{
n_tasks = 1;
} break;
case GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX:
{
fprintf(stderr, "GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX not implemented\n");
GGML_ASSERT(false);
} break;
case GGML_OP_COUNT:
{
GGML_ASSERT(false);
8 changes: 8 additions & 0 deletions ggml.h
@@ -415,6 +415,8 @@ extern "C" {
GGML_OP_CROSS_ENTROPY_LOSS,
GGML_OP_CROSS_ENTROPY_LOSS_BACK,

GGML_OP_SCALE_DIAG_MASK_INF_SOFTMAX,

GGML_OP_COUNT,
};

@@ -1209,6 +1211,12 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);

GGML_API struct ggml_tensor * ggml_scale_diag_mask_inf_softmax_inplace(
struct ggml_context * ctx,
float scale,
int n_past,
struct ggml_tensor * a);

// rotary position embedding
// if mode & 1 == 1, skip n_past elements
// if mode & 2 == 1, GPT-NeoX style
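In the two graph builders below, the new call collapses the previous three-op chain; schematically (names as in the llama.cpp diff):

// before: three graph nodes
struct ggml_tensor * KQ_scaled   = ggml_scale_inplace(ctx0, KQ, KQ_scale);
struct ggml_tensor * KQ_masked   = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);

// after: one fused node
struct ggml_tensor * KQ_soft_max = ggml_scale_diag_mask_inf_softmax_inplace(ctx0, kq_scale, n_past, KQ);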
56 changes: 34 additions & 22 deletions llama.cpp
@@ -2316,6 +2316,8 @@ static struct ggml_cgraph * llm_build_llama(
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");

const float kq_scale = 1.0f/sqrtf(float(n_embd)/n_head);

for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il);

@@ -2405,22 +2407,26 @@
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");

// KQ_scaled = KQ / sqrt(n_embd_head)
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");

// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");

// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
struct ggml_tensor * KQ_soft_max = ggml_scale_diag_mask_inf_softmax_inplace(ctx0, kq_scale, n_past, KQ);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");

//// KQ_scaled = KQ / sqrt(n_embd_head)
//// KQ_scaled shape [n_past + N, N, n_head, 1]
//struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
//offload_func_kq(KQ_scaled);
//ggml_set_name(KQ_scaled, "KQ_scaled");

//// KQ_masked = mask_past(KQ_scaled)
//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
//offload_func_kq(KQ_masked);
//ggml_set_name(KQ_masked, "KQ_masked");

//// KQ = soft_max(KQ_masked)
//struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
//offload_func_v(KQ_soft_max);
//ggml_set_name(KQ_soft_max, "KQ_soft_max");

// split cached V into n_head heads
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
@@ -2647,6 +2653,8 @@ static struct ggml_cgraph * llm_build_falcon(
}
ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");

const float kq_scale = 1.0f/sqrtf(float(n_embd)/n_head);

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * attn_norm;

@@ -2764,18 +2772,22 @@
offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");

struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");

struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");

struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
struct ggml_tensor * KQ_soft_max = ggml_scale_diag_mask_inf_softmax_inplace(ctx0, kq_scale, n_past, KQ);
offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");

//struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
//offload_func_kq(KQ_scaled);
//ggml_set_name(KQ_scaled, "KQ_scaled");

//struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
//offload_func_kq(KQ_masked);
//ggml_set_name(KQ_masked, "KQ_masked");

//struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
//offload_func_v(KQ_soft_max);
//ggml_set_name(KQ_soft_max, "KQ_soft_max");

struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd_head, n_head_kv,