Commit f858db8

Merge remote-tracking branch 'upstream/master'

2 parents 638ff1a + 6e08281

File tree: 24 files changed (+4221, -4185 lines)

.github/ISSUE_TEMPLATE/bug.md

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 ---
 name: Bug template
 about: Used to report bugs in llama.cpp
-labels: ["bug"]
+labels: ["bug-unconfirmed"]
 assignees: ''

 ---

CMakeLists.txt

Lines changed: 4 additions & 8 deletions

@@ -94,7 +94,6 @@ option(LLAMA_CLBLAST "llama: use CLBlast"
 option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
 option(LLAMA_MPI "llama: use MPI" OFF)
-option(LLAMA_K_QUANTS "llama: use k-quants" ON)
 option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)

 option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
@@ -278,13 +277,8 @@ if (LLAMA_BLAS)
     endif()
 endif()

-if (LLAMA_K_QUANTS)
-    set(GGML_HEADERS_EXTRA k_quants.h)
-    set(GGML_SOURCES_EXTRA k_quants.c)
-    add_compile_definitions(GGML_USE_K_QUANTS)
-    if (LLAMA_QKK_64)
-        add_compile_definitions(GGML_QKK_64)
-    endif()
+if (LLAMA_QKK_64)
+    add_compile_definitions(GGML_QKK_64)
 endif()

 if (LLAMA_CUBLAS)
@@ -673,6 +667,8 @@ add_library(ggml OBJECT
             ggml-alloc.h
             ggml-backend.c
             ggml-backend.h
+            ggml-quants.c
+            ggml-quants.h
             ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA}
             ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
             ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}

Makefile

Lines changed: 9 additions & 15 deletions

@@ -342,13 +342,9 @@ else
     MK_CXXFLAGS += -march=rv64gcv -mabi=lp64d
 endif

-ifndef LLAMA_NO_K_QUANTS
-    MK_CPPFLAGS += -DGGML_USE_K_QUANTS
-    OBJS += k_quants.o
 ifdef LLAMA_QKK_64
     MK_CPPFLAGS += -DGGML_QKK_64
 endif
-endif

 ifndef LLAMA_NO_ACCELERATE
     # Mac OS - include Accelerate framework.
@@ -365,7 +361,7 @@ ifdef LLAMA_MPI
     MK_CPPFLAGS += -DGGML_USE_MPI
     MK_CFLAGS   += -Wno-cast-qual
     MK_CXXFLAGS += -Wno-cast-qual
-    OBJS        += ggml-mpi.o
+    OBJS        += ggml-mpi.o
 endif # LLAMA_MPI

 ifdef LLAMA_OPENBLAS
@@ -382,7 +378,7 @@ endif # LLAMA_BLIS
 ifdef LLAMA_CUBLAS
     MK_CPPFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
     MK_LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
-    OBJS         += ggml-cuda.o
+    OBJS         += ggml-cuda.o
     NVCCFLAGS = --forward-unknown-to-host-compiler -use_fast_math
 ifdef LLAMA_CUDA_NVCC
     NVCC = $(LLAMA_CUDA_NVCC)
@@ -497,11 +493,6 @@ ggml-mpi.o: ggml-mpi.c ggml-mpi.h
     $(CC) $(CFLAGS) -c $< -o $@
 endif # LLAMA_MPI

-ifndef LLAMA_NO_K_QUANTS
-k_quants.o: k_quants.c k_quants.h
-    $(CC) $(CFLAGS) -c $< -o $@
-endif # LLAMA_NO_K_QUANTS
-
 # combine build flags with cmdline overrides
 override CFLAGS   := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CFLAGS) $(CFLAGS)
 override CXXFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CXXFLAGS) $(CXXFLAGS)
@@ -542,15 +533,18 @@ ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h
 ggml-backend.o: ggml-backend.c ggml.h ggml-backend.h
     $(CC) $(CFLAGS) -c $< -o $@

-OBJS += ggml-alloc.o ggml-backend.o
+ggml-quants.o: ggml-quants.c ggml.h ggml-quants.h
+    $(CC) $(CFLAGS) -c $< -o $@
+
+OBJS += ggml-alloc.o ggml-backend.o ggml-quants.o

 llama.o: llama.cpp ggml.h ggml-alloc.h ggml-backend.h ggml-cuda.h ggml-metal.h llama.h
     $(CXX) $(CXXFLAGS) -c $< -o $@

-COMMON_H_DEPS = common/common.h common/sampling.h build-info.h common/log.h
-COMMON_DEPS   = $(COMMON_H_DEPS) common.o sampling.o grammar-parser.o
+COMMON_H_DEPS = common/common.h common/sampling.h common/log.h
+COMMON_DEPS   = common.o sampling.o grammar-parser.o

-common.o: common/common.cpp $(COMMON_H_DEPS)
+common.o: common/common.cpp build-info.h $(COMMON_H_DEPS)
     $(CXX) $(CXXFLAGS) -c $< -o $@

 sampling.o: common/sampling.cpp $(COMMON_H_DEPS)

Package.swift

Lines changed: 1 addition & 2 deletions

@@ -42,13 +42,12 @@ let package = Package(
                 "llama.cpp",
                 "ggml-alloc.c",
                 "ggml-backend.c",
-                "k_quants.c",
+                "ggml-quants.c",
             ] + additionalSources,
             resources: resources,
             publicHeadersPath: "spm-headers",
             cSettings: [
                 .unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]),
-                .define("GGML_USE_K_QUANTS"),
                 .define("GGML_USE_ACCELERATE")
                 // NOTE: NEW_LAPACK will required iOS version 16.4+
                 // We should consider add this in the future when we drop support for iOS 14

build.zig

Lines changed: 8 additions & 13 deletions

@@ -116,15 +116,10 @@ pub fn build(b: *std.build.Builder) !void {
     var make = try Maker.init(b);
     make.enable_lto = b.option(bool, "lto", "Enable LTO optimization, (default: false)") orelse false;

-    if (b.option(bool, "k-quants", "Enable K-quants, (default: true)") orelse true) {
-        try make.addFlag("-DGGML_USE_K_QUANTS");
-        const k_quants = make.obj("k_quants", "k_quants.c");
-        try make.objs.append(k_quants);
-    }
-
     const ggml = make.obj("ggml", "ggml.c");
     const ggml_alloc = make.obj("ggml-alloc", "ggml-alloc.c");
     const ggml_backend = make.obj("ggml-backend", "ggml-backend.c");
+    const ggml_quants = make.obj("ggml-quants", "ggml-quants.c");
     const llama = make.obj("llama", "llama.cpp");
     const common = make.obj("common", "common/common.cpp");
     const console = make.obj("console", "common/console.cpp");
@@ -133,14 +128,14 @@ pub fn build(b: *std.build.Builder) !void {
     const train = make.obj("train", "common/train.cpp");
     const clip = make.obj("clip", "examples/llava/clip.cpp");

-    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, console, grammar_parser });
-    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
-    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
-    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common });
-    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
-    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, train });
+    _ = make.exe("main", "examples/main/main.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, sampling, console, grammar_parser });
+    _ = make.exe("quantize", "examples/quantize/quantize.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common });
+    _ = make.exe("perplexity", "examples/perplexity/perplexity.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common });
+    _ = make.exe("embedding", "examples/embedding/embedding.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common });
+    _ = make.exe("finetune", "examples/finetune/finetune.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, train });
+    _ = make.exe("train-text-from-scratch", "examples/train-text-from-scratch/train-text-from-scratch.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, train });

-    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, llama, common, sampling, grammar_parser, clip });
+    const server = make.exe("server", "examples/server/server.cpp", &.{ ggml, ggml_alloc, ggml_backend, ggml_quants, llama, common, sampling, grammar_parser, clip });
     if (server.target.isWindows()) {
         server.linkSystemLibrary("ws2_32");
     }

common/common.cpp

Lines changed: 3 additions & 2 deletions

@@ -224,6 +224,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             sparams.temp = std::stof(argv[i]);
+            sparams.temp = std::max(sparams.temp, 0.0f);
         } else if (arg == "--tfs") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -743,7 +744,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 #endif // GGML_USE_CUBLAS
 #endif
     printf("  --verbose-prompt      print prompt before generation\n");
-    fprintf(stderr, "  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
+    printf("  --simple-io           use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     printf("  --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
     printf("  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
@@ -888,7 +889,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par

         std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
-        llama_kv_cache_tokens_rm(lctx, -1, -1);
+        llama_kv_cache_clear(lctx);
         llama_reset_timings(lctx);
     }

common/sampling.cpp

Lines changed: 6 additions & 2 deletions

@@ -167,8 +167,12 @@ llama_token llama_sampling_sample(
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }

-    if (temp <= 0) {
-        // greedy sampling
+    if (temp < 0.0) {
+        // greedy sampling, with probs
+        llama_sample_softmax(ctx_main, &cur_p);
+        id = cur_p.data[0].id;
+    } else if (temp == 0.0) {
+        // greedy sampling, no probs
        id = llama_sample_token_greedy(ctx_main, &cur_p);
     } else {
         if (mirostat == 1) {
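
Note: the temperature convention in llama_sampling_sample now distinguishes three cases: a negative temperature selects the most likely token but still runs a softmax so the candidate probabilities are populated, exactly zero stays on the cheaper llama_sample_token_greedy path, and positive values sample as before. The sketch below, a rough illustration only, shows how a caller could exercise the two greedy paths directly against the llama.cpp C API; the helper name greedy_pick and the surrounding setup are assumptions, not part of this commit.

// Minimal sketch of the two greedy branches added in common/sampling.cpp.
// Assumes a valid llama_model * and llama_context *, and that llama_decode()
// has just produced logits for the last evaluated token.
#include "llama.h"

#include <vector>

static llama_token greedy_pick(llama_context * ctx, const llama_model * model, float temp) {
    const int     n_vocab = llama_n_vocab(model);
    const float * logits  = llama_get_logits(ctx);

    std::vector<llama_token_data> cur;
    cur.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        cur.push_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

    if (temp < 0.0f) {
        // greedy, but keep normalized probabilities (what the speculative example relies on)
        llama_sample_softmax(ctx, &cur_p); // sorts by logit and fills cur_p.data[i].p
        return cur_p.data[0].id;
    }
    // temp == 0.0f: plain greedy, probabilities are not needed
    return llama_sample_token_greedy(ctx, &cur_p);
}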

convert.py

Lines changed: 12 additions & 9 deletions

@@ -366,16 +366,19 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> No
             added_tokens = {}

         vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
-        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
-        actual_ids = sorted(added_tokens.values())
-        if expected_ids != actual_ids:
-            raise Exception(f"Expected added token IDs to be sequential and start at {vocab_size}; got {actual_ids}")

-        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
-        self.added_tokens_list = [text for (text, idx) in items]
-        self.vocab_size_base: int = vocab_size
-        self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
-        self.fname_tokenizer = fname_tokenizer
+        new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
+        expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
+        actual_new_ids = sorted(new_tokens.keys())
+
+        if expected_new_ids != actual_new_ids:
+            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
+
+        # Token pieces that were added to the base vocabulary.
+        self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
+        self.vocab_size_base = vocab_size
+        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
+        self.fname_tokenizer = fname_tokenizer
         self.fname_added_tokens = fname_added_tokens

     def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:

examples/batched-bench/batched-bench.cpp

Lines changed: 1 addition & 1 deletion

@@ -185,7 +185,7 @@ int main(int argc, char ** argv) {

     const auto t_pp_start = ggml_time_us();

-    llama_kv_cache_tokens_rm(ctx, -1, -1);
+    llama_kv_cache_clear(ctx);

     if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);

examples/llama-bench/llama-bench.cpp

Lines changed: 2 additions & 2 deletions

@@ -1037,7 +1037,7 @@ int main(int argc, char ** argv) {

         test t(inst, lmodel, ctx);

-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);

         // warmup run
         if (t.n_prompt > 0) {
@@ -1048,7 +1048,7 @@
         }

         for (int i = 0; i < params.reps; i++) {
-            llama_kv_cache_tokens_rm(ctx, -1, -1);
+            llama_kv_cache_clear(ctx);

             uint64_t t_start = get_time_ns();
             if (t.n_prompt > 0) {

examples/main/main.cpp

Lines changed: 1 addition & 1 deletion

@@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
         }

         // remove any "future" tokens that we might have inherited from the previous session
-        llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1);
+        llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
     }

     LOGLN(
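
Note: the removed llama_kv_cache_tokens_rm is replaced by two different calls across this commit: llama_kv_cache_clear wipes the entire cache, while llama_kv_cache_seq_rm removes a position range, with -1 acting as a wildcard for the sequence id and for the open end of the range. A small sketch of both calls as they are used in the diffs; the wrapper function name and its arguments are illustrative assumptions, not code from this commit.

// Sketch: the two KV-cache calls this commit migrates to.
// Assumes `ctx` is a valid llama_context and `n_matching_session_tokens`
// counts the prompt prefix reused from a saved session (as in main.cpp).
#include "llama.h"

void reset_cache_examples(llama_context * ctx, int n_matching_session_tokens) {
    // drop everything (used by perplexity, llama-bench, batched-bench, server)
    llama_kv_cache_clear(ctx);

    // drop only the "future" part: seq_id = -1 matches all sequences,
    // positions from n_matching_session_tokens to -1 means "to the end"
    llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
}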

examples/perplexity/perplexity.cpp

Lines changed: 3 additions & 3 deletions

@@ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
         const auto t_start = std::chrono::high_resolution_clock::now();

         // clear the KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);

         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         const auto t_start = std::chrono::high_resolution_clock::now();

         // clear the KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);

         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         }

         // clear the KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);

         auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
         if (logits.empty()) {

examples/quantize/quantize.cpp

Lines changed: 5 additions & 4 deletions

@@ -18,7 +18,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 3.90G, +0.1585 ppl @ LLaMA-v1-7B", },
     { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
     { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
-#ifdef GGML_USE_K_QUANTS
     { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
     { "Q3_K", LLAMA_FTYPE_MOSTLY_Q3_K_M, "alias for Q3_K_M" },
     { "Q3_K_S", LLAMA_FTYPE_MOSTLY_Q3_K_S, " 2.75G, +0.5551 ppl @ LLaMA-v1-7B", },
@@ -31,7 +30,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q5_K_S", LLAMA_FTYPE_MOSTLY_Q5_K_S, " 4.33G, +0.0400 ppl @ LLaMA-v1-7B", },
     { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
     { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 5.15G, -0.0008 ppl @ LLaMA-v1-7B", },
-#endif
     { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
     { "F16", LLAMA_FTYPE_MOSTLY_F16, "13.00G @ 7B", },
     { "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },
@@ -70,13 +68,14 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp
 }

 // usage:
-//  ./quantize [--allow-requantize] [--leave-output-tensor] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
+//  ./quantize [--allow-requantize] [--leave-output-tensor] [--pure] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads]
 //
 [[noreturn]]
 static void usage(const char * executable) {
-    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
+    printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable);
     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n");
     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n");
+    printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n");
     printf("\nAllowed quantization types:\n");
     for (auto & it : QUANT_OPTIONS) {
         if (it.name != "COPY") {
@@ -103,6 +102,8 @@ int main(int argc, char ** argv) {
         params.quantize_output_tensor = false;
     } else if (strcmp(argv[arg_idx], "--allow-requantize") == 0) {
         params.allow_requantize = true;
+    } else if (strcmp(argv[arg_idx], "--pure") == 0) {
+        params.pure = true;
     } else {
         usage(argv[0]);
     }
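
Note: the new --pure flag sets the pure field of the quantization parameters, which per the help text disables the k-quant mixtures and quantizes every tensor to the same type. A hedged sketch of driving the same setting through the C API instead of the CLI; the file names are placeholders and the exact return-value handling is an assumption, not part of this diff.

// Sketch only: rough programmatic equivalent of `./quantize --pure ... Q4_K_M`.
// Assumes linking against llama.cpp; input/output paths are hypothetical.
#include "llama.h"

#include <cstdio>

int main() {
    llama_model_quantize_params qparams = llama_model_quantize_default_params();
    qparams.ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M;
    qparams.pure  = true; // no k-quant mixtures: all tensors get the same type

    const int rc = llama_model_quantize("ggml-model-f16.gguf",    // placeholder input
                                        "ggml-model-q4_k_m.gguf", // placeholder output
                                        &qparams);
    if (rc != 0) {
        fprintf(stderr, "quantization failed (%d)\n", rc);
        return 1;
    }
    return 0;
}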

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion

@@ -857,7 +857,7 @@ struct llama_server_context

     void kv_cache_clear() {
         // clear the entire KV cache
-        llama_kv_cache_tokens_rm(ctx, -1, -1);
+        llama_kv_cache_clear(ctx);
         clean_kv_cache = false;
     }

examples/speculative/speculative.cpp

Lines changed: 1 addition & 1 deletion

@@ -148,7 +148,7 @@ int main(int argc, char ** argv) {
     std::vector<seq_draft> drafts(n_seq_dft);

     params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    params.sparams.temp = std::max(0.01f, params.sparams.temp);
+    params.sparams.temp = -1.0f;    // force greedy sampling with probs for the draft model

     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
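
Note: this ties into the common/sampling.cpp change above: a negative temperature now means "pick greedily but keep the probabilities", which is what the draft model needs for the accept/reject step of speculative decoding. A minimal sketch under that assumption; the helper name make_draft_sampler is illustrative and not from this commit.

// Sketch: a draft sampler that picks tokens greedily while still exposing
// candidate probabilities, via the common/sampling.h API shown in this diff.
#include "common/sampling.h"

llama_sampling_context * make_draft_sampler(llama_sampling_params sparams) {
    sparams.temp = -1.0f; // < 0: greedy pick, but llama_sampling_sample runs softmax first
    return llama_sampling_init(sparams);
}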

flake.lock

Lines changed: 3 additions & 3 deletions (generated file; diff not rendered)

flake.nix

Lines changed: 7 additions & 0 deletions

@@ -51,6 +51,9 @@
         };
         llama-python =
           pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]);
+        # TODO(Green-Sky): find a better way to opt-into the heavy ml python runtime
+        llama-python-extra =
+          pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece torchWithoutCuda transformers ]);
         postPatch = ''
           substituteInPlace ./ggml-metal.m \
             --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";"
@@ -126,5 +129,9 @@
           buildInputs = [ llama-python ];
           packages = nativeBuildInputs ++ osSpecific;
         };
+        devShells.extra = pkgs.mkShell {
+          buildInputs = [ llama-python-extra ];
+          packages = nativeBuildInputs ++ osSpecific;
+        };
       });
 }
