Skip to content

Commit 6a2ac4f

Browse files
committed
Make llama_buffer and llama_ctx_buffer copyable+moveable.
1 parent ac77583 commit 6a2ac4f

File tree

1 file changed

+55
-45
lines changed

1 file changed

+55
-45
lines changed

llama-util.h

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -417,39 +417,45 @@ struct llama_buffer {
417417
size_t size = 0;
418418

419419
llama_buffer() = default;
420+
~llama_buffer() { resize(0); }
421+
llama_buffer(const llama_buffer& rhs) { *this = rhs; }
422+
llama_buffer& operator=(const llama_buffer& rhs) {
423+
resize(rhs.size);
424+
memcpy(addr, rhs.addr, size);
425+
return *this;
426+
}
427+
llama_buffer(llama_buffer&& rhs):addr(rhs.addr), size(rhs.size) {
428+
new (&rhs) llama_buffer();
429+
}
430+
llama_buffer& operator=(llama_buffer&& rhs) {
431+
this->~llama_buffer();
432+
addr = rhs.addr;
433+
size = rhs.size;
434+
new (&rhs) llama_buffer();
435+
return *this;
436+
}
420437

421438
void resize(size_t len) {
439+
size = 0;
422440
#ifdef GGML_USE_METAL
423441
free(addr);
424-
int result = posix_memalign((void **) &addr, getpagesize(), len);
425-
if (result == 0) {
426-
memset(addr, 0, len);
427-
}
428-
else {
429-
addr = NULL;
430-
len = 0;
442+
if (len) {
443+
int result = posix_memalign((void **) &addr, getpagesize(), len);
444+
if (result == 0) {
445+
memset(addr, 0, len);
446+
size = len;
447+
} else {
448+
addr = NULL;
449+
}
431450
}
432451
#else
433452
delete[] addr;
434-
addr = new uint8_t[len];
435-
#endif
436-
size = len;
437-
}
438-
439-
~llama_buffer() {
440-
#ifdef GGML_USE_METAL
441-
free(addr);
442-
#else
443-
delete[] addr;
453+
if (len) {
454+
addr = new uint8_t[len];
455+
size = len;
456+
}
444457
#endif
445-
addr = NULL;
446458
}
447-
448-
// disable copy and move
449-
llama_buffer(const llama_buffer&) = delete;
450-
llama_buffer(llama_buffer&&) = delete;
451-
llama_buffer& operator=(const llama_buffer&) = delete;
452-
llama_buffer& operator=(llama_buffer&&) = delete;
453459
};
454460

455461
#ifdef GGML_USE_CUBLAS
@@ -459,44 +465,48 @@ struct llama_ctx_buffer {
459465
bool is_cuda;
460466
size_t size = 0;
461467

462-
llama_ctx_buffer() = default;
463-
464-
void resize(size_t size) {
468+
void resize(size_t len) {
465469
free();
466-
467-
addr = (uint8_t *) ggml_cuda_host_malloc(size);
468-
if (addr) {
469-
is_cuda = true;
470-
}
471-
else {
470+
addr = (uint8_t *) ggml_cuda_host_malloc(len);
471+
is_cuda = static_cast<bool>(addr);
472+
if (!is_cuda) {
472473
// fall back to pageable memory
473474
addr = new uint8_t[size];
474-
is_cuda = false;
475475
}
476-
this->size = size;
476+
size = len;
477477
}
478478

479479
void free() {
480480
if (addr) {
481481
if (is_cuda) {
482482
ggml_cuda_host_free(addr);
483-
}
484-
else {
483+
} else {
485484
delete[] addr;
486485
}
487486
}
488-
addr = NULL;
487+
new (this) llama_ctx_buffer();
489488
}
490489

491-
~llama_ctx_buffer() {
492-
free();
490+
llama_ctx_buffer() = default;
491+
~llama_ctx_buffer() { free(); }
492+
llama_ctx_buffer(const llama_ctx_buffer& rhs) { *this = rhs; }
493+
llama_ctx_buffer& operator=(const llama_ctx_buffer& rhs) {
494+
resize(rhs.size);
495+
memcpy(addr, rhs.addr, size); // cuda memcpy if is_cuda?
496+
return *this;
497+
}
498+
llama_ctx_buffer(llama_ctx_buffer&& rhs):addr(rhs.addr), is_cuda(rhs.is_cuda), size(rhs.size) {
499+
new (&rhs) llama_ctx_buffer();
500+
}
501+
llama_ctx_buffer& operator=(llama_ctx_buffer&& rhs) {
502+
this->~llama_ctx_buffer();
503+
addr = rhs.addr;
504+
is_cuda = rhs.is_cuda;
505+
size = rhs.size;
506+
new (&rhs) llama_ctx_buffer();
507+
return *this;
493508
}
494509

495-
// disable copy and move
496-
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
497-
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
498-
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
499-
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
500510
};
501511
#else
502512
typedef llama_buffer llama_ctx_buffer;

0 commit comments

Comments
 (0)