Skip to content

Commit 6bd6c2f

Browse files
li-plus and yusiwen
authored and committed
metal : alibi for arbitrary number of heads (ggml-org#3426)
1 parent 6730941 commit 6bd6c2f

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

ggml-metal.m

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1213,12 +1213,9 @@ void ggml_metal_graph_compute(
12131213
float max_bias;
12141214
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
12151215

1216-
if (__builtin_popcount(n_head) != 1) {
1217-
GGML_ASSERT(false && "only power-of-two n_head implemented");
1218-
}
1219-
12201216
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
12211217
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
1218+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
12221219

12231220
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
12241221
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1239,7 +1236,9 @@ void ggml_metal_graph_compute(
12391236
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
12401237
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
12411238
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
1242-
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1239+
[encoder setBytes:&m0 length:sizeof( float) atIndex:18];
1240+
[encoder setBytes:&m1 length:sizeof( float) atIndex:19];
1241+
[encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
12431242

12441243
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
12451244
} break;

ggml-metal.metal

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -830,7 +830,9 @@ kernel void kernel_alibi_f32(
830830
constant uint64_t & nb1,
831831
constant uint64_t & nb2,
832832
constant uint64_t & nb3,
833-
constant float & m0,
833+
constant float & m0,
834+
constant float & m1,
835+
constant int & n_heads_log2_floor,
834836
uint3 tgpig[[threadgroup_position_in_grid]],
835837
uint3 tpitg[[thread_position_in_threadgroup]],
836838
uint3 ntg[[threads_per_threadgroup]]) {
@@ -846,7 +848,12 @@ kernel void kernel_alibi_f32(
846848
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
847849

848850
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
849-
float m_k = pow(m0, i2 + 1);
851+
float m_k;
852+
if (i2 < n_heads_log2_floor) {
853+
m_k = pow(m0, i2 + 1);
854+
} else {
855+
m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
856+
}
850857
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
851858
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
852859
dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);

0 commit comments

Comments
 (0)