feat(speculative-sampling): Add speculative sampling #200

Merged (1 commit, Sep 4, 2023)

317 changes: 211 additions & 106 deletions binding.cpp

Large diffs are not rendered by default.
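
The C++ side of the feature lives in this file, so the actual loop is not visible above. For orientation only, the sketch below shows the draft-and-verify scheme that speculative sampling generally follows (mirroring llama.cpp's speculative example, not the hidden binding.cpp code); every name is illustrative and greedy decoding stands in for the real samplers.

package main

import "fmt"

// nextToken is a hypothetical greedy decoder for one model; the real binding
// works on llama.cpp contexts and full sampling parameters instead.
type nextToken func(prefix []int) int

// speculativeSample: the draft model proposes up to nDraft tokens, the target
// model re-decodes the same positions, and only the agreeing prefix survives.
// Output quality is the target's alone; the draft only saves target forward
// passes when it guesses well.
func speculativeSample(target, draft nextToken, prompt []int, nDraft, maxNew int) []int {
	out := append([]int(nil), prompt...)
	for len(out)-len(prompt) < maxNew { // may overshoot by up to nDraft tokens
		// Draft phase: the cheap model guesses a short continuation.
		proposal := make([]int, 0, nDraft)
		ctx := append([]int(nil), out...)
		for i := 0; i < nDraft; i++ {
			t := draft(ctx)
			proposal = append(proposal, t)
			ctx = append(ctx, t)
		}
		// Verify phase: the target checks the guesses. The real implementation
		// evaluates the whole proposal in a single batch, which is where the
		// speed-up comes from; here it is token by token for clarity.
		for _, t := range proposal {
			keep := target(out)
			out = append(out, keep) // the target's choice is always kept
			if keep != t {
				break // first mismatch: discard the rest of the draft
			}
		}
	}
	return out
}

func main() {
	// Toy models over integer "tokens": the draft counts up by one, the target
	// agrees except around multiples of five, so most drafted tokens land.
	draft := func(p []int) int { return p[len(p)-1] + 1 }
	target := func(p []int) int {
		n := p[len(p)-1] + 1
		if n%5 == 0 {
			return n + 1
		}
		return n
	}
	fmt.Println(speculativeSample(target, draft, []int{0}, 4, 12))
}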

8 changes: 5 additions & 3 deletions binding.h
@@ -29,7 +29,7 @@ void* load_model(const char *fname,
bool numa,
float rope_freq_base,
float rope_freq_scale,
bool mul_mat_q, const char *lora, const char *lora_base
bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity
);

int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings);
@@ -41,8 +41,10 @@ void* llama_allocate_params(const char *prompt, int seed, int threads, int token
int repeat_last_n, bool ignore_eos, bool memory_f16,
int n_batch, int n_keep, const char** antiprompt, int antiprompt_count,
float tfs_z, float typical_p, float frequency_penalty, float presence_penalty, int mirostat, float mirostat_eta, float mirostat_tau, bool penalize_nl, const char *logit_bias, const char *session_file, bool prompt_cache_all, bool mlock, bool mmap, const char *maingpu, const char *tensorsplit ,
bool prompt_cache_ro, const char *grammar, float rope_freq_base, float rope_freq_scale, float negative_prompt_scale, const char* negative_prompt
);
bool prompt_cache_ro, const char *grammar, float rope_freq_base, float rope_freq_scale, float negative_prompt_scale, const char* negative_prompt,
int n_draft);

int speculative_sampling(void* params_ptr, void* target_model, void* draft_model, char* result, bool debug);

void llama_free_params(void* params_ptr);

2 changes: 1 addition & 1 deletion llama.cpp
65 changes: 64 additions & 1 deletion llama.go
@@ -41,7 +41,7 @@ func New(model string, opts ...ModelOption) (*LLama, error) {
C.bool(mo.F16Memory), C.bool(mo.MLock), C.bool(mo.Embeddings), C.bool(mo.MMap), C.bool(mo.LowVRAM),
C.int(mo.NGPULayers), C.int(mo.NBatch), C.CString(mo.MainGPU), C.CString(mo.TensorSplit), C.bool(mo.NUMA),
C.float(mo.FreqRopeBase), C.float(mo.FreqRopeScale),
C.bool(MulMatQ), loraAdapter, loraBase,
C.bool(MulMatQ), loraAdapter, loraBase, C.bool(mo.Perplexity),
)

if result == nil {
@@ -123,6 +123,7 @@ func (l *LLama) TokenEmbeddings(tokens []int, opts ...PredictOption) ([]float32,
C.bool(po.PromptCacheRO),
C.CString(po.Grammar),
C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
C.int(po.NDraft),
)
ret := C.get_token_embeddings(params, l.state, myArray, C.int(len(tokens)), (*C.float)(&floats[0]))
if ret != 0 {
@@ -164,6 +165,7 @@ func (l *LLama) Embeddings(text string, opts ...PredictOption) ([]float32, error
C.bool(po.PromptCacheRO),
C.CString(po.Grammar),
C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
C.int(po.NDraft),
)

ret := C.get_embeddings(params, l.state, (*C.float)(&floats[0]))
@@ -202,6 +204,7 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error {
C.bool(po.PromptCacheRO),
C.CString(po.Grammar),
C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
C.int(po.NDraft),
)
ret := C.eval(params, l.state, input)
if ret != 0 {
@@ -213,6 +216,64 @@ func (l *LLama) Eval(text string, opts ...PredictOption) error {
return nil
}

func (l *LLama) SpeculativeSampling(ll *LLama, text string, opts ...PredictOption) (string, error) {
po := NewPredictOptions(opts...)

if po.TokenCallback != nil {
setCallback(l.state, po.TokenCallback)
}

input := C.CString(text)
if po.Tokens == 0 {
po.Tokens = 99999999
}
out := make([]byte, po.Tokens)

reverseCount := len(po.StopPrompts)
reversePrompt := make([]*C.char, reverseCount)
var pass **C.char
for i, s := range po.StopPrompts {
cs := C.CString(s)
reversePrompt[i] = cs
pass = &reversePrompt[0]
}

params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
C.bool(po.IgnoreEOS), C.bool(po.F16KV),
C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount),
C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
C.CString(po.MainGPU), C.CString(po.TensorSplit),
C.bool(po.PromptCacheRO),
C.CString(po.Grammar),
C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
C.int(po.NDraft),
)
ret := C.speculative_sampling(params, l.state, ll.state, (*C.char)(unsafe.Pointer(&out[0])), C.bool(po.DebugMode))
if ret != 0 {
return "", fmt.Errorf("inference failed")
}
res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))

res = strings.TrimPrefix(res, " ")
res = strings.TrimPrefix(res, text)
res = strings.TrimPrefix(res, "\n")

for _, s := range po.StopPrompts {
res = strings.TrimRight(res, s)
}

C.llama_free_params(params)

if po.TokenCallback != nil {
setCallback(l.state, nil)
}

return res, nil
}

func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
po := NewPredictOptions(opts...)

@@ -246,6 +307,7 @@ func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
C.bool(po.PromptCacheRO),
C.CString(po.Grammar),
C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
C.int(po.NDraft),
)
ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.bool(po.DebugMode))
if ret != 0 {
@@ -294,6 +356,7 @@ func (l *LLama) TokenizeString(text string, opts ...PredictOption) (int32, []int
C.bool(po.PromptCacheRO),
C.CString(po.Grammar),
C.float(po.RopeFreqBase), C.float(po.RopeFreqScale), C.float(po.NegativePromptScale), C.CString(po.NegativePrompt),
C.int(po.NDraft),
)

tokRet := C.llama_tokenize_string(params, l.state, (*C.int)(unsafe.Pointer(&out[0]))) //, C.int(po.Tokens), true)
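
End to end, the new API is meant to be driven with two loaded models: the receiver is the target and the argument is the draft. The following is a hedged sketch modeled on the test further down; the file paths are placeholders, the import path is the project's published one, and SetPerplexity(true) is assumed to be needed so the context keeps logits for every token (llama.cpp reuses the perplexity flag for that), which the verification step relies on.

package main

import (
	"fmt"
	"log"

	llama "github.com/go-skynet/go-llama.cpp"
)

func main() {
	// Target: the large model whose output quality you want.
	target, err := llama.New("models/target-13b.gguf", // placeholder path
		llama.EnableF16Memory, llama.SetContext(128), llama.SetPerplexity(true))
	if err != nil {
		log.Fatal(err)
	}
	// Draft: a much smaller model sharing the target's tokenizer.
	draft, err := llama.New("models/draft-1b.gguf", // placeholder path
		llama.EnableF16Memory, llama.SetContext(128), llama.SetPerplexity(true))
	if err != nil {
		log.Fatal(err)
	}

	out, err := target.SpeculativeSampling(draft, "[INST] How much is 2+2? [/INST]",
		llama.SetNDraft(16), // tokens the draft proposes per round
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}
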
33 changes: 33 additions & 0 deletions llama_test.go
@@ -3,6 +3,7 @@ package llama_test
import (
"os"

"github.com/go-skynet/go-llama.cpp"
. "github.com/go-skynet/go-llama.cpp"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -45,6 +46,38 @@ how much is 2+2?
Expect(text).To(ContainSubstring("4"), text)
})

It("speculative sampling predicts", func() {
if testModelPath == "" {
Skip("test skipped - only makes sense if the TEST_MODEL environment variable is set.")
}
model, err := New(
testModelPath,
EnableF16Memory,
SetContext(128),
SetMMap(true),
SetNBatch(512),
SetPerplexity(true),
)
Expect(err).ToNot(HaveOccurred())
Expect(model).ToNot(BeNil())
model2, err := New(
testModelPath,
EnableF16Memory,
SetContext(128),
SetMMap(true),
SetNBatch(512),
SetPerplexity(true),
)
Expect(err).ToNot(HaveOccurred())
Expect(model).ToNot(BeNil())
text, err := model.SpeculativeSampling(model2, `[INST] Answer to the following question:
how much is 2+2?
[/INST]`, llama.SetNDraft(16),
)
Expect(err).ToNot(HaveOccurred(), text)
Expect(text).To(ContainSubstring("4"), text)
})

It("tokenizes strings successfully", func() {
if testModelPath == "" {
Skip("test skipped - only makes sense if the TEST_MODEL environment variable is set.")
14 changes: 14 additions & 0 deletions options.go
@@ -18,11 +18,13 @@ type ModelOptions struct {
MulMatQ *bool
LoraBase string
LoraAdapter string
Perplexity bool
}

type PredictOptions struct {
Seed, Threads, Tokens, TopK, Repeat, Batch, NKeep int
TopP, Temperature, Penalty float32
NDraft int
F16KV bool
DebugMode bool
StopPrompts []string
@@ -193,6 +195,18 @@ func SetRopeFreqScale(rfs float32) PredictOption {
}
}

func SetNDraft(nd int) PredictOption {
return func(p *PredictOptions) {
p.NDraft = nd
}
}

func SetPerplexity(b bool) ModelOption {
return func(p *ModelOptions) {
p.Perplexity = b
}
}

func SetNegativePromptScale(nps float32) PredictOption {
return func(p *PredictOptions) {
p.NegativePromptScale = nps
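
Both setters follow the package's existing functional-options pattern, so they compose with any other option in the list; NDraft simply rides along in PredictOptions and is only consumed by the speculative path. A minimal check, assuming NewPredictOptions is usable from outside the package as its capitalization suggests:

package main

import (
	"fmt"

	llama "github.com/go-skynet/go-llama.cpp"
)

func main() {
	po := llama.NewPredictOptions(
		llama.SetNDraft(8),          // new in this PR: draft tokens per speculative round
		llama.SetRopeFreqScale(1.0), // pre-existing option, shown only to illustrate composition
	)
	fmt.Println(po.NDraft) // 8
}
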
13 changes: 7 additions & 6 deletions patches/1902-cuda.patch
@@ -1,8 +1,8 @@
diff --git a/common/common.cpp b/common/common.cpp
index ed09fc2..ced02e8 100644
index 3138213..af93a32 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1107,3 +1107,82 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
@@ -1257,3 +1257,83 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "typical_p: %f # default: 1.0\n", params.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
}
@@ -22,7 +22,7 @@ index ed09fc2..ced02e8 100644
+ return lparams;
+}
+
+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base) {
+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity) {
+ // load the model
+ gpt_params * lparams = create_gpt_params(fname, lora, lora_base);
+ llama_model * model;
@@ -35,6 +35,7 @@ index ed09fc2..ced02e8 100644
+ lparams->embedding = embeddings;
+ lparams->use_mlock = mlock;
+ lparams->n_gpu_layers = n_gpu_layers;
+ lparams->perplexity = perplexity;
+ lparams->use_mmap = mmap;
+
+ lparams->low_vram = low_vram;
@@ -87,10 +88,10 @@ index ed09fc2..ced02e8 100644
+}
\ No newline at end of file
diff --git a/common/common.h b/common/common.h
index 5a37968..8b09050 100644
index 105fb09..8f60434 100644
--- a/common/common.h
+++ b/common/common.h
@@ -165,3 +165,10 @@ std::string get_sortable_timestamp();
@@ -201,3 +201,10 @@ std::string get_sortable_timestamp();
void dump_non_result_info_yaml(
FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
@@ -100,4 +101,4 @@ index 5a37968..8b09050 100644
+ llama_model * model;
+};
+
+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base);
+void* load_binding_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, int n_gpu_layers, int n_batch, const char *maingpu, const char *tensorsplit, bool numa, float rope_freq_base, float rope_freq_scale, bool mul_mat_q, const char *lora, const char *lora_base, bool perplexity);