@@ -185,10 +185,13 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
185
185
static id <MTLBuffer > ggml_metal_get_buffer (struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
186
186
// fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
187
187
188
+ const int64_t tsize = ggml_nbytes (t);
189
+
190
+ // find the view that contains the tensor fully
188
191
for (int i = 0 ; i < ctx->n_buffers ; ++i) {
189
192
const int64_t ioffs = (int64_t ) t->data - (int64_t ) ctx->buffers [i].data ;
190
193
191
- if (ioffs >= 0 && ioffs < (int64_t ) ctx->buffers [i].size ) {
194
+ if (ioffs >= 0 && ioffs + tsize <= (int64_t ) ctx->buffers [i].size ) {
192
195
*offs = (size_t ) ioffs;
193
196
194
197
// fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
@@ -214,39 +217,67 @@ bool ggml_metal_add_buffer(
214
217
215
218
if (data) {
216
219
// verify that the buffer does not overlap with any of the existing buffers
217
- for (int i = 0 ; i < ctx->n_buffers ; ++i) {
218
- const int64_t ioffs = (int64_t ) data - (int64_t ) ctx->buffers [i].data ;
220
+ // for (int i = 0; i < ctx->n_buffers; ++i) {
221
+ // const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
219
222
220
- if (ioffs >= 0 && ioffs < (int64_t ) ctx->buffers [i].size ) {
221
- fprintf (stderr, " %s : error: buffer '%s ' overlaps with '%s '\n " , __func__, name, ctx->buffers [i].name );
222
- return false ;
223
- }
224
- }
223
+ // if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
224
+ // fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
225
+ // return false;
226
+ // }
227
+ // }
225
228
226
- size_t page_size = getpagesize ();
227
- size_t aligned_size = size;
228
- if ((aligned_size % page_size) != 0 ) {
229
- aligned_size += (page_size - (aligned_size % page_size));
229
+ const size_t size_page = getpagesize ();
230
+
231
+ size_t size_aligned = size;
232
+ if ((size_aligned % size_page) != 0 ) {
233
+ size_aligned += (size_page - (size_aligned % size_page));
230
234
}
231
235
232
- ctx->buffers [ctx->n_buffers].name = name;
233
- ctx->buffers [ctx->n_buffers].data = data;
234
- ctx->buffers [ctx->n_buffers].size = size;
236
+ // the buffer fits into the max buffer size allowed by the device
237
+ if (size_aligned <= ctx->device .maxBufferLength ) {
238
+ ctx->buffers [ctx->n_buffers].name = name;
239
+ ctx->buffers [ctx->n_buffers].data = data;
240
+ ctx->buffers [ctx->n_buffers].size = size;
235
241
236
- if (ctx->device .maxBufferLength < aligned_size) {
237
- fprintf (stderr, " %s : buffer '%s ' size %zu is larger than buffer maximum of %zu \n " , __func__, name, aligned_size, ctx->device .maxBufferLength );
238
- return false ;
239
- }
240
- ctx->buffers [ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy: data length: aligned_size options: MTLResourceStorageModeShared deallocator: nil ];
242
+ ctx->buffers [ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy: data length: size_aligned options: MTLResourceStorageModeShared deallocator: nil ];
243
+
244
+ if (ctx->buffers [ctx->n_buffers].metal == nil ) {
245
+ fprintf (stderr, " %s : failed to allocate '%-16s ' buffer, size = %8.2f MB\n " , __func__, name, size_aligned / 1024.0 / 1024.0 );
246
+ return false ;
247
+ }
241
248
242
- if (ctx-> buffers [ctx->n_buffers]. metal == nil ) {
243
- fprintf (stderr, " %s : failed to allocate ' %-16s ' buffer, size = %8.2f MB \n " , __func__, name, aligned_size / 1024.0 / 1024.0 );
244
- return false ;
249
+ fprintf (stderr, " %s : allocated ' %-16s ' buffer, size = %8.2f MB \n " , __func__, name, size_aligned / 1024.0 / 1024.0 );
250
+
251
+ ++ctx-> n_buffers ;
245
252
} else {
246
- fprintf (stderr, " %s : allocated '%-16s ' buffer, size = %8.2f MB\n " , __func__, name, aligned_size / 1024.0 / 1024.0 );
247
- }
253
+ // Example, say you want to map 16GB buffer. Create 3 views, each 8GB of size:
254
+ //
255
+ // view 0 has offset 0, i.e. range [0GB, 8GB]
256
+ // view 1 has offset 4GB, i.e range [4GB, 8GB]
257
+ // view 2 has offset 8GB, i.e. range [8GB, 16GB]
258
+ //
259
+ const size_t size_step = ctx->device .maxBufferLength /2 ;
260
+ const size_t size_view = ctx->device .maxBufferLength ;
261
+
262
+ for (size_t i = 0 ; i < size; i += size_step) {
263
+ const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
264
+
265
+ ctx->buffers [ctx->n_buffers].name = name;
266
+ ctx->buffers [ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
267
+ ctx->buffers [ctx->n_buffers].size = size_step_aligned;
268
+
269
+ ctx->buffers [ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy: (void *) ((uint8_t *) data + i) length: size_step_aligned options: MTLResourceStorageModeShared deallocator: nil ];
248
270
249
- ++ctx->n_buffers ;
271
+ if (ctx->buffers [ctx->n_buffers].metal == nil ) {
272
+ fprintf (stderr, " %s : failed to allocate '%-16s ' buffer, size = %8.2f MB\n " , __func__, name, size_step_aligned / 1024.0 / 1024.0 );
273
+ return false ;
274
+ }
275
+
276
+ fprintf (stderr, " %s : allocated '%-16s ' buffer, size = %8.2f MB, offs = %12ld \n " , __func__, name, size_step_aligned / 1024.0 / 1024.0 , i);
277
+
278
+ ++ctx->n_buffers ;
279
+ }
280
+ }
250
281
}
251
282
252
283
return true ;
0 commit comments