1 file changed: +4 -12 lines changed
Original file line number | Diff line number | Diff line change
@@ -17171,18 +17171,10 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
17171
17171
17172
17172
auto * buft = ggml_backend_cpu_buffer_type();
17173
17173
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
17174
- ggml_tensor * output_tensor = lctx.model.output;
17175
- if (!output_tensor) {
17176
- // bert models don't have an output tensor, use the last layer
17177
- output_tensor = lctx.model.layers.back().layer_out_norm;
17178
- }
17179
- if (output_tensor) {
17180
- auto * output_buft = ggml_backend_buffer_get_type(output_tensor->buffer);
17181
- auto * output_dev = ggml_backend_buft_get_device(output_buft);
17182
- auto * output_dev_host_buft = ggml_backend_dev_host_buffer_type(output_dev);
17183
- if (output_dev_host_buft) {
17184
- buft = output_dev_host_buft;
17185
- }
17174
+ auto * output_dev = lctx.model.dev_output.dev;
17175
+ auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
17176
+ if (output_dev_host_buft) {
17177
+ buft = output_dev_host_buft;
17186
17178
}
17187
17179
lctx.buf_output = ggml_backend_buft_alloc_buffer(buft, new_size);
17188
17180
if (lctx.buf_output == nullptr) {
You can’t perform that action at this time.
0 commit comments