Commit 96c20e4

slaren authored and arthw committed
llama : improve output buffer type selection (ggml-org#10098)
1 parent 948dfbf commit 96c20e4

File tree: 1 file changed (+4, −12)

src/llama.cpp

Lines changed: 4 additions & 12 deletions
@@ -17170,18 +17170,10 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
 
         auto * buft = ggml_backend_cpu_buffer_type();
         // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
-        ggml_tensor * output_tensor = lctx.model.output;
-        if (!output_tensor) {
-            // bert models don't have an output tensor, use the last layer
-            output_tensor = lctx.model.layers.back().layer_out_norm;
-        }
-        if (output_tensor) {
-            auto * output_buft = ggml_backend_buffer_get_type(output_tensor->buffer);
-            auto * output_dev = ggml_backend_buft_get_device(output_buft);
-            auto * output_dev_host_buft = ggml_backend_dev_host_buffer_type(output_dev);
-            if (output_dev_host_buft) {
-                buft = output_dev_host_buft;
-            }
+        auto * output_dev = lctx.model.dev_output.dev;
+        auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
+        if (output_dev_host_buft) {
+            buft = output_dev_host_buft;
         }
         lctx.buf_output = ggml_backend_buft_alloc_buffer(buft, new_size);
         if (lctx.buf_output == nullptr) {
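
Note: before this commit, the output device was derived from the buffer of the model's output tensor, with a special case for BERT-style models that have no output tensor; after it, the device recorded in lctx.model.dev_output.dev is used directly. The sketch below is a minimal illustration of the resulting selection logic, not the verbatim llama.cpp code; the helper name pick_output_buft is hypothetical, while the ggml_backend_* calls are the ones appearing in the diff.

#include "ggml-backend.h"

// Hypothetical helper illustrating the post-change selection: default to the
// CPU buffer type, and prefer the host buffer type of the output device when
// one is available (faster transfer of results to system memory).
static ggml_backend_buffer_type_t pick_output_buft(ggml_backend_dev_t output_dev) {
    ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();

    // output_dev may be null; only query the host buffer type when a device is set
    ggml_backend_buffer_type_t host_buft =
        output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
    if (host_buft) {
        buft = host_buft;
    }
    return buft;
}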
