#include <cstdlib>
#include <climits>
#include <cstring>

#include <memory>
#include <string>
#include <vector>
#include <stdexcept>
|
@@ -413,89 +414,80 @@ struct llama_mlock {
|
413 | 414 |
|
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
struct llama_buffer {
#ifdef GGML_USE_METAL
    // Metal host buffers are page-aligned via posix_memalign, so they must be
    // released with free(), not delete[].
    struct metal_deleter { void operator()(uint8_t * addr) const { free(addr); } };
    std::unique_ptr<uint8_t, metal_deleter> addr;
#else
    std::unique_ptr<uint8_t[]> addr;
#endif
    size_t size = 0; // number of valid bytes at addr (0 when empty or allocation failed)

    llama_buffer() = default;

    // Deep-copy semantics: copying a buffer clones its contents.
    llama_buffer(const llama_buffer & rhs) { *this = rhs; }
    llama_buffer & operator=(const llama_buffer & rhs) {
        if (this == &rhs) {
            // self-assignment guard: resize() below discards the current
            // contents, which would destroy the data we are about to copy
            return *this;
        }
        resize(rhs.size);
        if (size > 0) { // memcpy on null pointers is UB even with length 0
            memcpy(addr.get(), rhs.addr.get(), size);
        }
        return *this;
    }

    // Reallocate to exactly `len` bytes; previous contents are discarded.
    // On the Metal path a failed posix_memalign leaves the buffer empty
    // (addr == nullptr, size == 0) rather than throwing.
    void resize(size_t len) {
        addr.reset();
#ifdef GGML_USE_METAL
        size = 0;
        uint8_t * ptr;
        int result = posix_memalign((void **) &ptr, getpagesize(), len);
        if (result == 0) {
            memset(ptr, 0, len);
            addr.reset(ptr);
            size = len;
        }
#else
        size = 0; // keep size consistent with addr if the allocation throws
        addr.reset(new uint8_t[len]);
        size = len;
#endif
    }
};
|
453 | 450 |
|
454 | 451 | #ifdef GGML_USE_CUBLAS
|
455 | 452 | #include "ggml-cuda.h"
|
456 | 453 | struct llama_ctx_buffer {
|
457 |
| - uint8_t * addr = NULL; |
458 |
| - bool is_cuda; |
| 454 | + struct cuda_deleter { |
| 455 | + bool is_cuda; |
| 456 | + void operator()(uint8_t* addr) const { |
| 457 | + if (addr) { |
| 458 | + if (is_cuda) { |
| 459 | + ggml_cuda_host_free(addr); |
| 460 | + } else { |
| 461 | + delete[] addr; |
| 462 | + } |
| 463 | + } |
| 464 | + } |
| 465 | + }; |
| 466 | + using Addr = std::unique_ptr<uint8_t, cuda_deleter>; |
| 467 | + Addr addr; |
459 | 468 | size_t size = 0;
|
460 | 469 |
|
461 | 470 | llama_ctx_buffer() = default;
|
| 471 | + llama_ctx_buffer(const llama_ctx_buffer& rhs) { *this = rhs; } |
| 472 | + llama_ctx_buffer& operator=(const llama_ctx_buffer& rhs) { |
| 473 | + resize(rhs.size); |
| 474 | + memcpy(addr.get(), rhs.addr.get(), size); |
| 475 | + return *this; |
| 476 | + } |
462 | 477 |
|
463 |
| - void resize(size_t size) { |
464 |
| - free(); |
| 478 | + void resize(size_t len) { |
| 479 | + addr.reset(); |
465 | 480 |
|
466 |
| - addr = (uint8_t *) ggml_cuda_host_malloc(size); |
467 |
| - if (addr) { |
468 |
| - is_cuda = true; |
469 |
| - } |
470 |
| - else { |
| 481 | + bool is_cuda = true; |
| 482 | + auto* ptr = (uint8_t*) ggml_cuda_host_malloc(len); |
| 483 | + if (!ptr) { |
471 | 484 | // fall back to pageable memory
|
472 |
| - addr = new uint8_t[size]; |
| 485 | + ptr = new uint8_t[len]; |
473 | 486 | is_cuda = false;
|
474 | 487 | }
|
475 |
| - this->size = size; |
476 |
| - } |
477 |
| - |
478 |
| - void free() { |
479 |
| - if (addr) { |
480 |
| - if (is_cuda) { |
481 |
| - ggml_cuda_host_free(addr); |
482 |
| - } |
483 |
| - else { |
484 |
| - delete[] addr; |
485 |
| - } |
486 |
| - } |
487 |
| - addr = NULL; |
488 |
| - } |
489 |
| - |
490 |
| - ~llama_ctx_buffer() { |
491 |
| - free(); |
| 488 | + addr = Addr(ptr, {is_cuda}); |
| 489 | + size = len; |
492 | 490 | }
|
493 |
| - |
494 |
| - // disable copy and move |
495 |
| - llama_ctx_buffer(const llama_ctx_buffer&) = delete; |
496 |
| - llama_ctx_buffer(llama_ctx_buffer&&) = delete; |
497 |
| - llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete; |
498 |
| - llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete; |
499 | 491 | };
|
500 | 492 | #else
|
501 | 493 | typedef llama_buffer llama_ctx_buffer;
|
|
0 commit comments