Skip to content

Commit 9bda0d8

Browse files
committed
Add llama_beam_search().
1 parent 785829d commit 9bda0d8

File tree

3 files changed

+276
-71
lines changed

3 files changed

+276
-71
lines changed

llama-util.h

Lines changed: 49 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <cstdlib>
1313
#include <climits>
1414

15+
#include <memory>
1516
#include <string>
1617
#include <vector>
1718
#include <stdexcept>
@@ -413,89 +414,80 @@ struct llama_mlock {
413414

414415
// Replacement for std::vector<uint8_t> that doesn't require zero-initialization.
415416
// Byte buffer used in place of std::vector<uint8_t>.
// This span is a rendered diff, not a single compilable definition: each code row is
// preceded by a gutter line. Gutters ending in '-' appear to mark rows removed by the
// commit, gutters ending in '+' rows added by it, and two-number gutters unchanged context.
// Net effect of the added rows: the raw owning pointer, the hand-written destructor and the
// deleted copy/move members are replaced by a std::unique_ptr member plus a deep-copying
// copy constructor and copy assignment operator.
struct llama_buffer {
416-
uint8_t * addr = NULL;
417+
#ifdef GGML_USE_METAL
418+
// Deleter for the Metal build: the memory comes from posix_memalign() in resize(), so it is released with free().
struct metal_deleter { void operator()(uint8_t* addr) const { free(addr); } };
419+
std::unique_ptr<uint8_t, metal_deleter> addr;
420+
#else
421+
// Non-Metal build: the memory comes from new uint8_t[len] in resize(), released with delete[] by unique_ptr<uint8_t[]>.
std::unique_ptr<uint8_t[]> addr;
422+
#endif
417423
// Number of bytes currently allocated behind addr; 0 for a default-constructed buffer.
size_t size = 0;
418424

419425
llama_buffer() = default;
426+
// Copy construction delegates to the copy assignment operator below.
// NOTE(review): because copy operations are user-declared and no move operations are,
// the compiler generates no implicit moves; moving a llama_buffer therefore performs a deep copy.
llama_buffer(const llama_buffer& rhs) { *this = rhs; }
427+
// Deep copy: reallocates this buffer to rhs.size bytes, then copies rhs's bytes into it.
// NOTE(review): not self-assignment safe. resize() resets addr before memcpy reads
// rhs.addr, so `b = b` discards the old contents and copies the fresh allocation onto itself.
// NOTE(review): in the GGML_USE_METAL path a failed posix_memalign leaves addr null and
// size == 0, so memcpy is called with null pointers and length 0 - confirm this is acceptable.
// NOTE(review): memcpy needs <cstring>; the include lines visible in this diff do not show it - confirm.
llama_buffer& operator=(const llama_buffer& rhs) {
428+
resize(rhs.size);
429+
memcpy(addr.get(), rhs.addr.get(), size);
430+
return *this;
431+
}
420432

421433
// Releases the current allocation and allocates len bytes; previous contents are not preserved.
void resize(size_t len) {
434+
// Free the old block first (the deleter does this), replacing the explicit free()/delete[] of the removed rows.
addr.reset();
422435
#ifdef GGML_USE_METAL
423-
free(addr);
424-
int result = posix_memalign((void **) &addr, getpagesize(), len);
436+
// size stays 0 unless the aligned allocation below succeeds.
size = 0;
437+
uint8_t* ptr;
438+
// Page-aligned allocation; the new rows allocate into a local so addr only takes ownership on success.
int result = posix_memalign((void **) &ptr, getpagesize(), len);
425439
if (result == 0) {
426-
memset(addr, 0, len);
427-
}
428-
else {
429-
addr = NULL;
440+
// On success the block is zero-filled, handed to addr, and size is updated.
memset(ptr, 0, len);
441+
addr.reset(ptr);
442+
size = len;
430443
}
431444
#else
432-
delete[] addr;
433-
addr = new uint8_t[len];
434-
#endif
445+
// Plain new[] leaves the bytes uninitialized (no zero-fill in this branch).
addr.reset(new uint8_t[len]);
435446
size = len;
436-
}
437-
438-
~llama_buffer() {
439-
#ifdef GGML_USE_METAL
440-
free(addr);
441-
#else
442-
delete[] addr;
443447
#endif
444-
addr = NULL;
445448
}
446-
447-
// disable copy and move
448-
llama_buffer(const llama_buffer&) = delete;
449-
llama_buffer(llama_buffer&&) = delete;
450-
llama_buffer& operator=(const llama_buffer&) = delete;
451-
llama_buffer& operator=(llama_buffer&&) = delete;
452449
};
453450

454451
#ifdef GGML_USE_CUBLAS
455452
#include "ggml-cuda.h"
456453
// Byte buffer for the GGML_USE_CUBLAS build that tries a CUDA host allocation first and
// falls back to new[] when that fails.
// This span is a rendered diff, not a single compilable definition: each code row is
// preceded by a gutter line. Gutters ending in '-' appear to mark rows removed by the
// commit, gutters ending in '+' rows added by it, and two-number gutters unchanged context.
// Net effect of the added rows: the raw pointer + is_cuda flag + free()/destructor are
// replaced by a std::unique_ptr whose deleter carries the is_cuda flag, and the deleted
// copy/move members are replaced by deep-copying copy operations.
struct llama_ctx_buffer {
457-
uint8_t * addr = NULL;
458-
bool is_cuda;
454+
// Deleter that records which allocator produced the pointer so the matching release is used.
struct cuda_deleter {
455+
// true: release with ggml_cuda_host_free(); false: release with delete[].
bool is_cuda;
456+
void operator()(uint8_t* addr) const {
457+
if (addr) {
458+
if (is_cuda) {
459+
ggml_cuda_host_free(addr);
460+
} else {
461+
delete[] addr;
462+
}
463+
}
464+
}
465+
};
466+
using Addr = std::unique_ptr<uint8_t, cuda_deleter>;
467+
Addr addr;
459468
// Number of bytes currently allocated behind addr; 0 for a default-constructed buffer.
size_t size = 0;
460469

461470
llama_ctx_buffer() = default;
471+
// Copy construction delegates to the copy assignment operator below.
// NOTE(review): because copy operations are user-declared and no move operations are,
// the compiler generates no implicit moves; moving a llama_ctx_buffer therefore performs a deep copy.
llama_ctx_buffer(const llama_ctx_buffer& rhs) { *this = rhs; }
472+
// Deep copy: reallocates this buffer to rhs.size bytes, then copies rhs's bytes into it.
// NOTE(review): not self-assignment safe. resize() resets addr before memcpy reads
// rhs.addr, so `b = b` discards the old contents and copies the fresh allocation onto itself.
llama_ctx_buffer& operator=(const llama_ctx_buffer& rhs) {
473+
resize(rhs.size);
474+
memcpy(addr.get(), rhs.addr.get(), size);
475+
return *this;
476+
}
462477

463-
void resize(size_t size) {
464-
free();
478+
// Releases the current allocation and allocates len bytes; previous contents are not preserved.
void resize(size_t len) {
479+
// Free the old block first via the stored deleter (replaces the removed free() call).
addr.reset();
465480

466-
addr = (uint8_t *) ggml_cuda_host_malloc(size);
467-
if (addr) {
468-
is_cuda = true;
469-
}
470-
else {
481+
// Assume the CUDA host allocation succeeds; flipped to false only on the fallback path.
bool is_cuda = true;
482+
auto* ptr = (uint8_t*) ggml_cuda_host_malloc(len);
483+
if (!ptr) {
471484
// fall back to pageable memory
472-
addr = new uint8_t[size];
485+
ptr = new uint8_t[len];
473486
is_cuda = false;
474487
}
475-
this->size = size;
476-
}
477-
478-
void free() {
479-
if (addr) {
480-
if (is_cuda) {
481-
ggml_cuda_host_free(addr);
482-
}
483-
else {
484-
delete[] addr;
485-
}
486-
}
487-
addr = NULL;
488-
}
489-
490-
~llama_ctx_buffer() {
491-
free();
488+
// Take ownership together with a deleter that remembers which allocator was used.
addr = Addr(ptr, {is_cuda});
489+
size = len;
492490
}
493-
494-
// disable copy and move
495-
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
496-
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
497-
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
498-
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
499491
};
500492
#else
501493
typedef llama_buffer llama_ctx_buffer;

0 commit comments

Comments
 (0)