Skip to content

Commit 4cf08d9

Browse files
committed
metal : handle buffers larger than device's maxBufferLength
1 parent e4caa8d commit 4cf08d9

File tree

1 file changed

+57
-26
lines changed

1 file changed

+57
-26
lines changed

ggml-metal.m

Lines changed: 57 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,13 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
185185
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
186186
//fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
187187

188+
const int64_t tsize = ggml_nbytes(t);
189+
190+
// find the view that contains the tensor fully
188191
for (int i = 0; i < ctx->n_buffers; ++i) {
189192
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
190193

191-
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
194+
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
192195
*offs = (size_t) ioffs;
193196

194197
//fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
@@ -214,39 +217,67 @@ bool ggml_metal_add_buffer(
214217

215218
if (data) {
216219
// verify that the buffer does not overlap with any of the existing buffers
217-
for (int i = 0; i < ctx->n_buffers; ++i) {
218-
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
220+
//for (int i = 0; i < ctx->n_buffers; ++i) {
221+
// const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
219222

220-
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
221-
fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
222-
return false;
223-
}
224-
}
223+
// if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
224+
// fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
225+
// return false;
226+
// }
227+
//}
225228

226-
size_t page_size = getpagesize();
227-
size_t aligned_size = size;
228-
if ((aligned_size % page_size) != 0) {
229-
aligned_size += (page_size - (aligned_size % page_size));
229+
const size_t size_page = getpagesize();
230+
231+
size_t size_aligned = size;
232+
if ((size_aligned % size_page) != 0) {
233+
size_aligned += (size_page - (size_aligned % size_page));
230234
}
231235

232-
ctx->buffers[ctx->n_buffers].name = name;
233-
ctx->buffers[ctx->n_buffers].data = data;
234-
ctx->buffers[ctx->n_buffers].size = size;
236+
// the buffer fits into the max buffer size allowed by the device
237+
if (size_aligned <= ctx->device.maxBufferLength) {
238+
ctx->buffers[ctx->n_buffers].name = name;
239+
ctx->buffers[ctx->n_buffers].data = data;
240+
ctx->buffers[ctx->n_buffers].size = size;
235241

236-
if (ctx->device.maxBufferLength < aligned_size) {
237-
fprintf(stderr, "%s: buffer '%s' size %zu is larger than buffer maximum of %zu\n", __func__, name, aligned_size, ctx->device.maxBufferLength);
238-
return false;
239-
}
240-
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:aligned_size options:MTLResourceStorageModeShared deallocator:nil];
242+
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
243+
244+
if (ctx->buffers[ctx->n_buffers].metal == nil) {
245+
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
246+
return false;
247+
}
241248

242-
if (ctx->buffers[ctx->n_buffers].metal == nil) {
243-
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
244-
return false;
249+
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
250+
251+
++ctx->n_buffers;
245252
} else {
246-
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB\n", __func__, name, aligned_size / 1024.0 / 1024.0);
247-
}
253+
// Example: say you want to map a 16GB buffer. Create 3 views, each 8GB in size:
254+
//
255+
// view 0 has offset 0, i.e. range [0GB, 8GB]
256+
// view 1 has offset 4GB, i.e. range [4GB, 12GB]
257+
// view 2 has offset 8GB, i.e. range [8GB, 16GB]
258+
//
259+
const size_t size_step = ctx->device.maxBufferLength/2;
260+
const size_t size_view = ctx->device.maxBufferLength;
261+
262+
for (size_t i = 0; i < size; i += size_step) {
263+
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
264+
265+
ctx->buffers[ctx->n_buffers].name = name;
266+
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
267+
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
268+
269+
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
248270

249-
++ctx->n_buffers;
271+
if (ctx->buffers[ctx->n_buffers].metal == nil) {
272+
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
273+
return false;
274+
}
275+
276+
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld\n", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
277+
278+
++ctx->n_buffers;
279+
}
280+
}
250281
}
251282

252283
return true;

0 commit comments

Comments
 (0)