Commit 6bc7411

Merge remote-tracking branch 'upstream' into cancel-model-load
2 parents bdfe4ba + afefa31

27 files changed (+1199, -1034 lines)

.github/workflows/docker.yml

Lines changed: 19 additions & 2 deletions
@@ -52,14 +52,31 @@ jobs:
           username: ${{ github.repository_owner }}
           password: ${{ secrets.GITHUB_TOKEN }}
 
+      # https://github.com/jlumbroso/free-disk-space/tree/54081f138730dfa15788a46383842cd2f914a1be#example
+      - name: Free Disk Space (Ubuntu)
+        uses: jlumbroso/free-disk-space@main
+        with:
+          # this might remove tools that are actually needed,
+          # if set to "true" but frees about 6 GB
+          tool-cache: false
+
+          # all of these default to true, but feel free to set to
+          # "false" if necessary for your workflow
+          android: true
+          dotnet: true
+          haskell: true
+          large-packages: true
+          docker-images: true
+          swap-storage: true
+
       - name: Build and push Docker image (versioned)
         if: github.event_name == 'push'
         uses: docker/build-push-action@v4
         with:
           context: .
           push: true
           platforms: ${{ matrix.config.platforms }}
-          tags: "ghcr.io/ggerganov/llama.cpp:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
+          tags: "ghcr.io/${{ github.repository_owner }}/llama.cpp:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}"
           file: ${{ matrix.config.dockerfile }}
 
       - name: Build and push Docker image (tagged)
@@ -68,5 +85,5 @@ jobs:
           context: .
           push: ${{ github.event_name == 'push' }}
           platforms: ${{ matrix.config.platforms }}
-          tags: "ghcr.io/ggerganov/llama.cpp:${{ matrix.config.tag }}"
+          tags: "ghcr.io/${{ github.repository_owner }}/llama.cpp:${{ matrix.config.tag }}"
           file: ${{ matrix.config.dockerfile }}

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -91,6 +91,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
 set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                              "llama: max. batch size for using peer access")
 option(LLAMA_HIPBLAS                         "llama: use hipBLAS"                          OFF)
+option(LLAMA_HIP_UMA                         "llama: use HIP unified memory architecture"  OFF)
 option(LLAMA_CLBLAST                         "llama: use CLBlast"                          OFF)
 option(LLAMA_METAL                           "llama: use Metal"                            ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG                    "llama: disable Metal debugging"              OFF)
@@ -377,6 +378,9 @@ if (LLAMA_HIPBLAS)
     if (${hipblas_FOUND} AND ${hip_FOUND})
         message(STATUS "HIP and hipBLAS found")
         add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        if (LLAMA_HIP_UMA)
+            add_compile_definitions(GGML_HIP_UMA)
+        endif()
         add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
         if (BUILD_SHARED_LIBS)
             set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)

Makefile

Lines changed: 2 additions & 2 deletions
@@ -67,7 +67,7 @@ test: $(TEST_TARGETS)
 			./$$test_target; \
 		fi; \
 		if [ $$? -ne 0 ]; then \
-			printf 'Test $$test_target FAILED!\n\n' $$test_target; \
+			printf 'Test %s FAILED!\n\n' $$test_target; \
 			failures=$$(( failures + 1 )); \
 		else \
 			printf 'Test %s passed.\n\n' $$test_target; \
@@ -608,7 +608,7 @@ save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(C
 server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp examples/server/index.html.hpp examples/server/index.js.hpp examples/server/completion.js.hpp examples/llava/clip.cpp examples/llava/clip.h common/stb_image.h ggml.o llama.o $(COMMON_DEPS) grammar-parser.o $(OBJS)
 	$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) -Wno-cast-qual
 
-gguf: examples/gguf/gguf.cpp ggml.o llama.o $(OBJS)
+gguf: examples/gguf/gguf.cpp ggml.o $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 
 train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp ggml.o llama.o $(COMMON_DEPS) train.o $(OBJS)

README.md

Lines changed: 11 additions & 7 deletions
@@ -432,14 +432,15 @@ Building the program with BLAS support may lead to some performance improvements
   ```bash
   make LLAMA_HIPBLAS=1
   ```
-- Using `CMake` for Linux:
+- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
   ```bash
-  mkdir build
-  cd build
-  CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON
-  cmake --build .
+  CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ \
+      cmake -H. -Bbuild -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
+      && cmake --build build -- -j 16
   ```
-- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS):
+  On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DLLAMA_HIP_UMA=ON`.
+  However, this hurts performance for non-integrated GPUs.
+- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
   ```bash
   set PATH=%HIP_PATH%\bin;%PATH%
   mkdir build
@@ -448,10 +449,11 @@ Building the program with BLAS support may lead to some performance improvements
   cmake --build .
   ```
   Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
+  Find your GPU version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.
 
 
   The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
-  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3.
+  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
   The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above):
 
   | Option | Legal values | Default | Description |
@@ -982,6 +984,8 @@ docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /
 - There are no strict rules for the code style, but try to follow the patterns in the code (indentation, spaces, etc.). Vertical alignment makes things more readable and easier to batch edit
 - Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a`
 - See [good first issues](https://github.com/ggerganov/llama.cpp/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) for tasks suitable for first contributions
+- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices
+- Matrix multiplication is unconventional: [`z = ggml_mul_mat(ctx, x, y)`](https://github.com/ggerganov/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means `zT = x @ yT`
 
 ### Docs
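The new contributing note about `ggml_mul_mat` is easy to misread, so here is a minimal, hypothetical C sketch of the shape convention it describes (dimension 0 = columns, dimension 1 = rows). The tensor sizes and names are illustrative, not part of the commit; only the standard ggml C API is assumed:

```c
#include <stdio.h>

#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // x: K columns x M rows, y: K columns x N rows (the shared inner dim K sits in ne[0])
    struct ggml_tensor * x = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8 /*K*/, 4 /*M*/);
    struct ggml_tensor * y = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8 /*K*/, 2 /*N*/);

    // z = ggml_mul_mat(ctx, x, y) has ne = [M, N], i.e. zT = x @ yT
    struct ggml_tensor * z = ggml_mul_mat(ctx, x, y);
    printf("z: %lld columns x %lld rows\n", (long long) z->ne[0], (long long) z->ne[1]); // 4 x 2

    ggml_free(ctx);
    return 0;
}
```

In conventional row-major terms, x is M×K and y is N×K, so the result tensor stores the transpose of x @ yT, exactly as the new README bullet states.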

common/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -920,7 +920,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -m FNAME, --model FNAME\n");
     printf("                        model path (default: %s)\n", params.model.c_str());
     printf("  -md FNAME, --model-draft FNAME\n");
-    printf("                        draft model for speculative decoding (default: %s)\n", params.model.c_str());
+    printf("                        draft model for speculative decoding\n");
     printf("  -ld LOGDIR, --logdir LOGDIR\n");
     printf("                        path under which to save YAML logs (no logging if unset)\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");

examples/baby-llama/baby-llama.cpp

Lines changed: 3 additions & 12 deletions
@@ -575,10 +575,7 @@ static struct ggml_tensor * forward(
 
         // KQ_scaled = KQ / sqrt(n_embd/n_head)
         // KQ_scaled shape [n_past + N, N, n_head, 1]
-        struct ggml_tensor * KQ_scaled =
-            ggml_scale(ctx0,
-                    KQ,
-                    ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+        struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
 
         // KQ_masked = mask_past(KQ_scaled)
         // KQ_masked shape [n_past + N, N, n_head, 1]
@@ -844,10 +841,7 @@ static struct ggml_tensor * forward_batch(
 
         // KQ_scaled = KQ / sqrt(n_embd/n_head)
         // KQ_scaled shape [n_past + N, N, n_head, n_batch]
-        struct ggml_tensor * KQ_scaled =
-            ggml_scale(ctx0,
-                    KQ,
-                    ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+        struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
         assert_shape_4d(KQ_scaled, n_past + N, N, n_head, n_batch);
 
         // KQ_masked = mask_past(KQ_scaled)
@@ -1131,10 +1125,7 @@ static struct ggml_tensor * forward_lora(
 
         // KQ_scaled = KQ / sqrt(n_embd/n_head)
         // KQ_scaled shape [n_past + N, N, n_head, 1]
-        struct ggml_tensor * KQ_scaled =
-            ggml_scale(ctx0,
-                    KQ,
-                    ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
+        struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, 1.0f/sqrtf(float(n_embd)/n_head));
 
         // KQ_masked = mask_past(KQ_scaled)
         // KQ_masked shape [n_past + N, N, n_head, 1]
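These changes (and the similar ones in the files below) all track a ggml API change visible in the diff: `ggml_scale` now takes a plain `float` scale factor instead of a one-element tensor built with `ggml_new_f32`. A minimal sketch of the new call pattern, with illustrative names:

```c
#include <math.h>

#include "ggml.h"

// Scale attention scores by 1/sqrt(n_embd/n_head) using the float-based
// ggml_scale signature adopted in this commit. The old form passed a
// 1-element tensor: ggml_scale(ctx, KQ, ggml_new_f32(ctx, scale)).
static struct ggml_tensor * scale_kq(struct ggml_context * ctx,
                                     struct ggml_tensor  * KQ,
                                     int n_embd, int n_head) {
    const float scale = 1.0f / sqrtf((float) n_embd / n_head);
    return ggml_scale(ctx, KQ, scale);
}
```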

examples/export-lora/export-lora.cpp

Lines changed: 1 addition & 1 deletion
@@ -309,7 +309,7 @@ static struct ggml_cgraph * build_graph_lora(
 ) {
     struct ggml_tensor * ab = ggml_mul_mat(ctx, lora_a, lora_b);
     if (scaling != 1.0f) {
-        ab = ggml_scale(ctx, ab, ggml_new_f32(ctx, scaling));
+        ab = ggml_scale(ctx, ab, scaling);
     }
     struct ggml_tensor * res = ggml_add_inplace(ctx, tensor, ab);

examples/finetune/finetune.cpp

Lines changed: 20 additions & 22 deletions
@@ -269,7 +269,7 @@ static void load_model_hparams_gguf(struct gguf_context * ctx, struct my_llama_h
     float rope_freq_scale = 1.0f;
     GGUF_GET_KEY(ctx, hparams->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
     GGUF_GET_KEY(ctx, hparams->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
-    GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    GGUF_GET_KEY(ctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
     if (rope_freq_scale != 1.0f) {
         hparams->rope_freq_scale = 1.0f / rope_freq_scale;
     }
@@ -612,6 +612,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     const int n_rot = hparams.n_embd_head();
     const int n_embd_head = hparams.n_embd_head();
     const int n_embd_gqa = hparams.n_embd_gqa();
+
     const float rms_norm_eps = hparams.f_norm_rms_eps;
     const float rope_freq_base = hparams.rope_freq_base;
     const float rope_freq_scale = hparams.rope_freq_scale;
@@ -680,10 +681,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
         checkpoints.push_back(t01);
     }
 
-    struct ggml_tensor * kv_scale = NULL;
-    if (!enable_flash_attn) {
-        kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
-    }
+    const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
 
     for (int il = 0; il < n_layer; ++il) {
         struct my_llama_layer & layer = model->layers[il];
@@ -781,32 +779,32 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
     int n_leafs_before = gb->n_leafs;
     int n_nodes_before = gb->n_nodes;
-    struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
+
     // output tensors
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
     // input gradient
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
     GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
     ggml_allocr_alloc(alloc, t36->grad);
     // KQ_pos
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
 
     // make sure base model tensors data cannot be used in viewable operations
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, one));
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, one));
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->tok_embeddings, 1.0f));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->norm, 1.0f));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, model->output, 1.0f));
     for (int il = 0; il < n_layer; ++il) {
         struct my_llama_layer & layer = model->layers[il];
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, one));
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.attention_norm, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.ffn_norm, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wq, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wk, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wv, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.wo, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w1, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w2, 1.0f));
+        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, 1.0f));
     }
 
     // allocating checkpoints in one block to reduce memory fragmentation
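The `ggml_scale_inplace(ctx, t, 1.0f)` calls above are not arithmetic: scaling by 1.0f is a numerical no-op whose only purpose is to add a graph node that depends on the tensor, so the allocator will not reuse or reallocate its memory while the rest of the graph is built. A small sketch of the pattern, with hypothetical names:

```c
#include "ggml.h"

// Keep-alive trick used in the finetune/train graphs: append a no-op
// dependency on `t` so ggml-alloc keeps its buffer in place.
static void keep_alive(struct ggml_context * ctx,
                       struct ggml_cgraph  * gb,
                       struct ggml_tensor  * t) {
    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t, 1.0f));
}
```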

examples/gguf/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 set(TARGET gguf)
 add_executable(${TARGET} gguf.cpp)
 install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
+target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/gguf/gguf.cpp

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 #include "ggml.h"
-#include "llama.h"
 
 #include <cstdio>
 #include <cinttypes>

examples/llava/clip.cpp

Lines changed: 1 addition & 7 deletions
@@ -330,12 +330,6 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
                            ggml_repeat(ctx0, model.pre_ln_b, embeddings));
     }
 
-    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
-    ggml_allocr_alloc(ctx->alloc, KQ_scale);
-    if (!ggml_allocr_is_measure(ctx->alloc)) {
-        ggml_set_f32(KQ_scale, 1.0f / sqrt((float)d_head));
-    }
-
     // loop over layers
     for (int il = 0; il < n_layer - 1; il++) {
         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
@@ -356,7 +350,7 @@ static ggml_cgraph * clip_image_build_graph(const clip_ctx * ctx, const clip_ima
         struct ggml_tensor * Q =
             ggml_add(ctx0, ggml_repeat(ctx0, model.layers[il].q_b, cur), ggml_mul_mat(ctx0, model.layers[il].q_w, cur));
 
-        Q = ggml_scale_inplace(ctx0, Q, KQ_scale);
+        Q = ggml_scale_inplace(ctx0, Q, 1.0f / sqrt((float)d_head));
         Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
         Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
         Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);

examples/train-text-from-scratch/train-text-from-scratch.cpp

Lines changed: 5 additions & 9 deletions
@@ -369,10 +369,7 @@ static struct ggml_tensor * llama_build_train_graphs(
     checkpoints.push_back(t00);
     checkpoints.push_back(t01);
 
-    struct ggml_tensor * kv_scale = NULL;
-    if (!enable_flash_attn) {
-        kv_scale = ggml_new_f32(ctx, 1.0f/sqrtf(float(n_embd)/n_head));
-    }
+    const float kv_scale = 1.0f/sqrtf(float(n_embd)/n_head);
 
     for (int il = 0; il < n_layer; ++il) {
         struct my_llama_layer & layer = model->layers[il];
@@ -444,14 +441,13 @@ static struct ggml_tensor * llama_build_train_graphs(
     // make sure some tensors are not reallocated by inserting new temporary nodes depending on them
     int n_leafs_before = gb->n_leafs;
     int n_nodes_before = gb->n_nodes;
-    struct ggml_tensor * one = ggml_new_f32(ctx, 1.0f);
     // output tensors
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, one));
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t35, 1.0f));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, 1.0f));
     // input gradient
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, 1.0f));
     // KQ_pos
-    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
+    ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, 1.0f));
     GGML_ASSERT(t36->grad->data == NULL && t36->grad->view_src == NULL);
 
     ggml_allocr_alloc(alloc, t36->grad);

ggml-alloc.c

Lines changed: 12 additions & 4 deletions
@@ -449,11 +449,10 @@ static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool upd
     if (update_backend) {
         view->backend = view->view_src->backend;
     }
-    view->buffer = view->view_src->buffer;
+    // views are initialized in the alloc buffer rather than the view_src buffer
+    view->buffer = alloc->buffer;
     view->data   = (char *)view->view_src->data + view->view_offs;
 
-    // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
-    // due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
     assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->buft == alloc->buffer->buft);
 
     if (!alloc->measure) {
@@ -736,6 +735,10 @@ void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
 }
 
 void ggml_allocr_free(ggml_allocr_t alloc) {
+    if (alloc == NULL) {
+        return;
+    }
+
     ggml_gallocr_free(alloc->galloc);
     ggml_tallocr_free(alloc->talloc);
     free(alloc);
@@ -775,7 +778,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
     }
 
     if (nbytes == 0) {
-        fprintf(stderr, "%s: no tensors to allocate\n", __func__);
+        // all the tensors in the context are already allocated
        return NULL;
     }
 
@@ -789,6 +792,11 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
             } else {
                 ggml_backend_view_init(buffer, t);
             }
+        } else {
+            if (t->view_src != NULL) {
+                // view of a pre-allocated tensor
+                ggml_backend_view_init(buffer, t);
+            }
         }
     }
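With the NULL check added to `ggml_allocr_free`, callers can free unconditionally, matching the usual free(NULL)-is-a-no-op convention. A hypothetical caller-side sketch:

```c
#include "ggml-alloc.h"

// Teardown path that no longer needs its own NULL guard:
// ggml_allocr_free(NULL) is now a safe no-op.
static void release_alloc(ggml_allocr_t alloc) {
    ggml_allocr_free(alloc);   // fine even if the allocator was never created
}
```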
