@@ -193,10 +193,13 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
193
193
static id <MTLBuffer > ggml_metal_get_buffer (struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
194
194
// fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
195
195
196
+ const int64_t tsize = ggml_nbytes (t);
197
+
198
+ // find the view that contains the tensor fully
196
199
for (int i = 0 ; i < ctx->n_buffers ; ++i) {
197
200
const int64_t ioffs = (int64_t ) t->data - (int64_t ) ctx->buffers [i].data ;
198
201
199
- if (ioffs >= 0 && ioffs < (int64_t ) ctx->buffers [i].size ) {
202
+ if (ioffs >= 0 && ioffs + tsize <= (int64_t ) ctx->buffers [i].size ) {
200
203
*offs = (size_t ) ioffs;
201
204
202
205
// fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
@@ -222,39 +225,67 @@ bool ggml_metal_add_buffer(
222
225
223
226
if (data) {
224
227
// verify that the buffer does not overlap with any of the existing buffers
225
- for (int i = 0 ; i < ctx->n_buffers ; ++i) {
226
- const int64_t ioffs = (int64_t ) data - (int64_t ) ctx->buffers [i].data ;
228
+ // for (int i = 0; i < ctx->n_buffers; ++i) {
229
+ // const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
227
230
228
- if (ioffs >= 0 && ioffs < (int64_t ) ctx->buffers [i].size ) {
229
- fprintf (stderr, " %s : error: buffer '%s ' overlaps with '%s '\n " , __func__, name, ctx->buffers [i].name );
230
- return false ;
231
- }
232
- }
231
+ // if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
232
+ // fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
233
+ // return false;
234
+ // }
235
+ // }
233
236
234
- size_t page_size = getpagesize ();
235
- size_t aligned_size = size;
236
- if ((aligned_size % page_size) != 0 ) {
237
- aligned_size += (page_size - (aligned_size % page_size));
237
+ const size_t size_page = getpagesize ();
238
+
239
+ size_t size_aligned = size;
240
+ if ((size_aligned % size_page) != 0 ) {
241
+ size_aligned += (size_page - (size_aligned % size_page));
238
242
}
239
243
240
- ctx->buffers [ctx->n_buffers].name = name;
241
- ctx->buffers [ctx->n_buffers].data = data;
242
- ctx->buffers [ctx->n_buffers].size = size;
244
+ // the buffer fits into the max buffer size allowed by the device
245
+ if (size_aligned <= ctx->device .maxBufferLength ) {
246
+ ctx->buffers [ctx->n_buffers].name = name;
247
+ ctx->buffers [ctx->n_buffers].data = data;
248
+ ctx->buffers [ctx->n_buffers].size = size;
243
249
244
- if (ctx->device .maxBufferLength < aligned_size) {
245
- fprintf (stderr, " %s : buffer '%s ' size %zu is larger than buffer maximum of %zu \n " , __func__, name, aligned_size, ctx->device .maxBufferLength );
246
- return false ;
247
- }
248
- ctx->buffers [ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy: data length: aligned_size options: MTLResourceStorageModeShared deallocator: nil ];
250
+ ctx->buffers [ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy: data length: size_aligned options: MTLResourceStorageModeShared deallocator: nil ];
249
251
250
- if (ctx->buffers [ctx->n_buffers].metal == nil ) {
251
- fprintf (stderr, " %s : failed to allocate '%-16s ' buffer, size = %8.2f MB\n " , __func__, name, aligned_size / 1024.0 / 1024.0 );
252
- return false ;
252
+ if (ctx->buffers [ctx->n_buffers].metal == nil ) {
253
+ fprintf (stderr, " %s : failed to allocate '%-16s ' buffer, size = %8.2f MB\n " , __func__, name, size_aligned / 1024.0 / 1024.0 );
254
+ return false ;
255
+ }
256
+
257
+ fprintf (stderr, " %s : allocated '%-16s ' buffer, size = %8.2f MB\n " , __func__, name, size_aligned / 1024.0 / 1024.0 );
258
+
259
+ ++ctx->n_buffers ;
253
260
} else {
254
- fprintf (stderr, " %s : allocated '%-16s ' buffer, size = %8.2f MB\n " , __func__, name, aligned_size / 1024.0 / 1024.0 );
255
- }
261
+ // Example, say you want to map 16GB buffer. Create 3 views, each 8GB of size:
262
+ //
263
+ // view 0 has offset 0, i.e. range [0GB, 8GB]
264
+ // view 1 has offset 4GB, i.e range [4GB, 8GB]
265
+ // view 2 has offset 8GB, i.e. range [8GB, 16GB]
266
+ //
267
+ const size_t size_step = ctx->device .maxBufferLength /2 ;
268
+ const size_t size_view = ctx->device .maxBufferLength ;
269
+
270
+ for (size_t i = 0 ; i < size; i += size_step) {
271
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
272
+
273
+ ctx->buffers [ctx->n_buffers].name = name;
274
+ ctx->buffers [ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
275
+ ctx->buffers [ctx->n_buffers].size = size_step_aligned;
276
+
277
+ ctx->buffers [ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy: (void *) ((uint8_t *) data + i) length: size_step_aligned options: MTLResourceStorageModeShared deallocator: nil ];
278
+
279
+ if (ctx->buffers [ctx->n_buffers].metal == nil ) {
280
+ fprintf (stderr, " %s : failed to allocate '%-16s ' buffer, size = %8.2f MB\n " , __func__, name, size_step_aligned / 1024.0 / 1024.0 );
281
+ return false ;
282
+ }
256
283
257
- ++ctx->n_buffers ;
284
+ fprintf (stderr, " %s : allocated '%-16s ' buffer, size = %8.2f MB, offs = %12ld \n " , __func__, name, size_step_aligned / 1024.0 / 1024.0 , i);
285
+
286
+ ++ctx->n_buffers ;
287
+ }
288
+ }
258
289
}
259
290
260
291
return true ;
0 commit comments