Skip to content

Commit acf8fbb

Browse files
authored
sync : ggml (HBM + Metal + style) (ggml-org#1264)
1 parent 8fa08e4 commit acf8fbb

File tree

3 files changed

+42
-39
lines changed

3 files changed

+42
-39
lines changed

ggml-metal.m

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,7 @@ void ggml_metal_graph_compute(
11411141
[encoder setBytes:&freq_base length:sizeof(float) atIndex:21];
11421142
[encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];
11431143

1144-
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
1144+
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
11451145
} break;
11461146
case GGML_OP_DUP:
11471147
case GGML_OP_CPY:

ggml-metal.metal

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -220,27 +220,17 @@ kernel void kernel_norm(
220220
}
221221
threadgroup_barrier(mem_flags::mem_threadgroup);
222222
}
223-
//// broadcast
224-
//if (tpitg == 0) {
225-
// sum[0] /= ne00;
226-
//}
227-
//threadgroup_barrier(mem_flags::mem_threadgroup);
228-
const float mean = sum[0];
223+
const float mean = sum[0] / ne00;
229224

230225
// recenter around the mean and accumulate the squared deviations (for the variance)
226+
threadgroup_barrier(mem_flags::mem_threadgroup);
231227
device float * y = dst + tgpig*ne00;
232228
sum[tpitg] = 0.0f;
233229
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
234230
y[i00] = x[i00] - mean;
235231
sum[tpitg] += y[i00] * y[i00];
236232
}
237233

238-
//// VARIANCE
239-
//// parallel sum
240-
//sum[tpitg] = 0.0f;
241-
//for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
242-
// sum[tpitg] += y[i00] * y[i00];
243-
//}
244234
// reduce
245235
threadgroup_barrier(mem_flags::mem_threadgroup);
246236
for (uint i = ntg/2; i > 0; i /= 2) {
@@ -249,20 +239,14 @@ kernel void kernel_norm(
249239
}
250240
threadgroup_barrier(mem_flags::mem_threadgroup);
251241
}
252-
//// broadcast
253-
//if (tpitg == 0) {
254-
// sum[0] /= ne00;
255-
//}
256-
//threadgroup_barrier(mem_flags::mem_threadgroup);
257-
const float variance = sum[0];
242+
const float variance = sum[0] / ne00;
258243

259244
const float scale = 1.0f/sqrt(variance + eps);
260245
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
261246
y[i00] = y[i00] * scale;
262247
}
263248
}
264249

265-
266250
kernel void kernel_rms_norm(
267251
device const void * src0,
268252
device float * dst,
@@ -630,7 +614,6 @@ kernel void kernel_mul_mat_f16_f32(
630614
}
631615
}
632616
}
633-
634617
}
635618

636619
kernel void kernel_alibi_f32(
@@ -699,25 +682,27 @@ kernel void kernel_rope(
699682
constant int & mode,
700683
constant float & freq_base,
701684
constant float & freq_scale,
702-
uint3 tpig[[thread_position_in_grid]]) {
703-
const int64_t i3 = tpig[2];
704-
const int64_t i2 = tpig[1];
705-
const int64_t i1 = tpig[0];
685+
uint tiitg[[thread_index_in_threadgroup]],
686+
uint3 tptg[[threads_per_threadgroup]],
687+
uint3 tgpig[[threadgroup_position_in_grid]]) {
688+
const int64_t i3 = tgpig[2];
689+
const int64_t i2 = tgpig[1];
690+
const int64_t i1 = tgpig[0];
706691

707692
const bool is_neox = mode & 2;
708-
const float theta_scale = pow(freq_base, -2.0f/n_dims);
709693

710694
const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);
711695

712-
float theta = freq_scale * (float)p;
696+
const float theta_0 = freq_scale * (float)p;
697+
const float inv_ndims = -1.f/n_dims;
713698

714699
if (!is_neox) {
715-
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
700+
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
701+
702+
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
716703
const float cos_theta = cos(theta);
717704
const float sin_theta = sin(theta);
718705

719-
theta *= theta_scale;
720-
721706
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
722707
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
723708

@@ -729,12 +714,12 @@ kernel void kernel_rope(
729714
}
730715
} else {
731716
for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
732-
for (int64_t ic = 0; ic < n_dims; ic += 2) {
717+
for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
718+
719+
const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib);
733720
const float cos_theta = cos(theta);
734721
const float sin_theta = sin(theta);
735722

736-
theta *= theta_scale;
737-
738723
const int64_t i0 = ib*n_dims + ic/2;
739724

740725
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);

ggml.c

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ typedef void * thread_ret_t;
106106
#include <sys/stat.h>
107107
#include <unistd.h>
108108

109+
#endif
110+
#ifdef GGML_USE_CPU_HBM
111+
#include <hbwmalloc.h>
109112
#endif
110113

111114
// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
@@ -195,8 +198,14 @@ typedef void * thread_ret_t;
195198
#define GGML_ALIGNED_FREE(ptr) _aligned_free(ptr)
196199
#else
197200
inline static void * ggml_aligned_malloc(size_t size) {
201+
if (size == 0) {
202+
GGML_PRINT("WARNING: Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
203+
return NULL;
204+
}
198205
void * aligned_memory = NULL;
199-
#ifdef GGML_USE_METAL
206+
#ifdef GGML_USE_CPU_HBM
207+
int result = hbw_posix_memalign(&aligned_memory, 16, size);
208+
#elif GGML_USE_METAL
200209
int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size);
201210
#else
202211
int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
@@ -218,8 +227,12 @@ inline static void * ggml_aligned_malloc(size_t size) {
218227
return aligned_memory;
219228
}
220229
#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
230+
#ifdef GGML_USE_CPU_HBM
231+
#define GGML_ALIGNED_FREE(ptr) if(NULL != ptr) hbw_free(ptr)
232+
#else
221233
#define GGML_ALIGNED_FREE(ptr) free(ptr)
222234
#endif
235+
#endif
223236

224237
#define UNUSED GGML_UNUSED
225238
#define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0)
@@ -4571,6 +4584,11 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
45714584
return NULL;
45724585
}
45734586

4587+
// allow calling ggml_init with a mem_size of 0
4588+
if (params.mem_size == 0) {
4589+
params.mem_size = GGML_MEM_ALIGN;
4590+
}
4591+
45744592
const size_t mem_size = params.mem_buffer ? params.mem_size : GGML_PAD(params.mem_size, GGML_MEM_ALIGN);
45754593

45764594
*ctx = (struct ggml_context) {
@@ -4773,7 +4791,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
47734791

47744792
size_t obj_alloc_size = 0;
47754793

4776-
if (view_src == NULL && ctx->no_alloc == false) {
4794+
if (view_src == NULL && !ctx->no_alloc) {
47774795
if (ctx->scratch.data != NULL) {
47784796
// allocate tensor data in the scratch buffer
47794797
if (ctx->scratch.offs + data_size > ctx->scratch.size) {
@@ -5474,7 +5492,7 @@ static struct ggml_tensor * ggml_mul_impl(
54745492
}
54755493

54765494
if (inplace) {
5477-
GGML_ASSERT(is_node == false);
5495+
GGML_ASSERT(!is_node);
54785496
}
54795497

54805498
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -5517,7 +5535,7 @@ static struct ggml_tensor * ggml_div_impl(
55175535
}
55185536

55195537
if (inplace) {
5520-
GGML_ASSERT(is_node == false);
5538+
GGML_ASSERT(!is_node);
55215539
}
55225540

55235541
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
@@ -19961,7 +19979,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
1996119979

1996219980
struct ggml_tensor * data = NULL;
1996319981

19964-
if (params.no_alloc == false) {
19982+
if (!params.no_alloc) {
1996519983
data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size);
1996619984

1996719985
ok = ok && data != NULL;
@@ -20002,7 +20020,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
2000220020
}
2000320021

2000420022
// point the data member to the appropriate location in the binary blob using the tensor infos
20005-
if (params.no_alloc == false) {
20023+
if (!params.no_alloc) {
2000620024
//cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
2000720025
cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data
2000820026
}

0 commit comments

Comments
 (0)