Skip to content

Commit a99cc90

Browse files
committed
Revert making llama_context, llama_buffer, and llama_ctx_buffer copyable. Beams now share one llama_context via a pointer, and pass their own growing tokens vector to llama_eval() on each iteration.
1 parent b000f76 commit a99cc90

File tree

2 files changed

+100
-115
lines changed

2 files changed

+100
-115
lines changed

llama-util.h

Lines changed: 58 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,6 @@
1212
#include <cstdlib>
1313
#include <climits>
1414

15-
#include <memory>
1615
#include <string>
1716
#include <vector>
1817
#include <stdexcept>
@@ -414,80 +413,90 @@ struct llama_mlock {
414413

415414
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
416415
struct llama_buffer {
417-
#ifdef GGML_USE_METAL
418-
struct metal_deleter { void operator()(uint8_t* addr) const { free(addr); } };
419-
std::unique_ptr<uint8_t, metal_deleter> addr;
420-
#else
421-
std::unique_ptr<uint8_t[]> addr;
422-
#endif
416+
uint8_t * addr = NULL;
423417
size_t size = 0;
424418

425419
llama_buffer() = default;
426-
llama_buffer(const llama_buffer& rhs) { *this = rhs; }
427-
llama_buffer& operator=(const llama_buffer& rhs) {
428-
resize(rhs.size);
429-
memcpy(addr.get(), rhs.addr.get(), size);
430-
return *this;
431-
}
432420

433421
void resize(size_t len) {
434-
addr.reset();
435422
#ifdef GGML_USE_METAL
436-
size = 0;
437-
uint8_t* ptr;
438-
int result = posix_memalign((void **) &ptr, getpagesize(), len);
423+
free(addr);
424+
int result = posix_memalign((void **) &addr, getpagesize(), len);
439425
if (result == 0) {
440-
memset(ptr, 0, len);
441-
addr.reset(ptr);
442-
size = len;
426+
memset(addr, 0, len);
427+
}
428+
else {
429+
addr = NULL;
430+
len = 0;
443431
}
444432
#else
445-
addr.reset(new uint8_t[len]);
433+
delete[] addr;
434+
addr = new uint8_t[len];
435+
#endif
446436
size = len;
437+
}
438+
439+
~llama_buffer() {
440+
#ifdef GGML_USE_METAL
441+
free(addr);
442+
#else
443+
delete[] addr;
447444
#endif
445+
addr = NULL;
448446
}
447+
448+
// disable copy and move
449+
llama_buffer(const llama_buffer&) = delete;
450+
llama_buffer(llama_buffer&&) = delete;
451+
llama_buffer& operator=(const llama_buffer&) = delete;
452+
llama_buffer& operator=(llama_buffer&&) = delete;
449453
};
450454

451455
#ifdef GGML_USE_CUBLAS
452456
#include "ggml-cuda.h"
453457
struct llama_ctx_buffer {
454-
struct cuda_deleter {
455-
bool is_cuda;
456-
void operator()(uint8_t* addr) const {
457-
if (addr) {
458-
if (is_cuda) {
459-
ggml_cuda_host_free(addr);
460-
} else {
461-
delete[] addr;
462-
}
463-
}
464-
}
465-
};
466-
using Addr = std::unique_ptr<uint8_t, cuda_deleter>;
467-
Addr addr;
458+
uint8_t * addr = NULL;
459+
bool is_cuda;
468460
size_t size = 0;
469461

470462
llama_ctx_buffer() = default;
471-
llama_ctx_buffer(const llama_ctx_buffer& rhs) { *this = rhs; }
472-
llama_ctx_buffer& operator=(const llama_ctx_buffer& rhs) {
473-
resize(rhs.size);
474-
memcpy(addr.get(), rhs.addr.get(), size);
475-
return *this;
476-
}
477463

478-
void resize(size_t len) {
479-
addr.reset();
464+
void resize(size_t size) {
465+
free();
480466

481-
bool is_cuda = true;
482-
auto* ptr = (uint8_t*) ggml_cuda_host_malloc(len);
483-
if (!ptr) {
467+
addr = (uint8_t *) ggml_cuda_host_malloc(size);
468+
if (addr) {
469+
is_cuda = true;
470+
}
471+
else {
484472
// fall back to pageable memory
485-
ptr = new uint8_t[len];
473+
addr = new uint8_t[size];
486474
is_cuda = false;
487475
}
488-
addr = Addr(ptr, {is_cuda});
489-
size = len;
476+
this->size = size;
490477
}
478+
479+
void free() {
480+
if (addr) {
481+
if (is_cuda) {
482+
ggml_cuda_host_free(addr);
483+
}
484+
else {
485+
delete[] addr;
486+
}
487+
}
488+
addr = NULL;
489+
}
490+
491+
~llama_ctx_buffer() {
492+
free();
493+
}
494+
495+
// disable copy and move
496+
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
497+
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
498+
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
499+
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
491500
};
492501
#else
493502
typedef llama_buffer llama_ctx_buffer;

0 commit comments

Comments
 (0)