Skip to content

Commit bcebd7d

Browse files
dranger003iamlemecggerganov
authored
llama : add support for GritLM (#5959)
* add gritlm example * gritlm results match * tabs to spaces * comment out debug printing * rebase to new embed * gritlm embeddings are back babeee * add to gitignore * allow to toggle embedding mode * Clean-up GritLM sample code. * Fix types. * Flush stdout and output ending newline if streaming. * mostly style fixes; correct KQ_mask comment * add causal_attn flag to llama_cparams * gritml : minor * llama : minor --------- Co-authored-by: Douglas Hanley <[email protected]> Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 2960eae commit bcebd7d

File tree

7 files changed

+267
-4
lines changed

7 files changed

+267
-4
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ models-mnt
4545
/embedding
4646
/gguf
4747
/gguf-llama-simple
48+
/gritlm
4849
/imatrix
4950
/infill
5051
/libllama.so

Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
BUILD_TARGETS = \
33
main quantize quantize-stats perplexity imatrix embedding vdot q8dot train-text-from-scratch convert-llama2c-to-ggml \
44
simple batched batched-bench save-load-state server gguf llama-bench libllava.a llava-cli baby-llama beam-search \
5-
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey tests/test-c.o
5+
speculative infill tokenize benchmark-matmult parallel finetune export-lora lookahead lookup passkey gritlm tests/test-c.o
66

77
# Binaries only useful for tests
88
TEST_TARGETS = \
@@ -724,6 +724,10 @@ embedding: examples/embedding/embedding.cpp ggml.o llama.o $(C
724724
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
725725
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
726726

727+
gritlm: examples/gritlm/gritlm.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
728+
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
729+
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)
730+
727731
save-load-state: examples/save-load-state/save-load-state.cpp ggml.o llama.o $(COMMON_DEPS) $(OBJS)
728732
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
729733
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ else()
2020
add_subdirectory(convert-llama2c-to-ggml)
2121
add_subdirectory(embedding)
2222
add_subdirectory(finetune)
23+
add_subdirectory(gritlm)
2324
add_subdirectory(infill)
2425
add_subdirectory(llama-bench)
2526
add_subdirectory(llava)

examples/gritlm/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
set(TARGET gritlm)
2+
add_executable(${TARGET} gritlm.cpp)
3+
install(TARGETS ${TARGET} RUNTIME)
4+
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
5+
target_compile_features(${TARGET} PRIVATE cxx_std_11)

examples/gritlm/gritlm.cpp

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
#include "common.h"
2+
#include "llama.h"
3+
4+
#include <string>
5+
#include <vector>
6+
7+
// #define GRIT_DEBUG
8+
9+
static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
10+
float dot = 0.0f;
11+
for (uint64_t i = 0; i < v1.size(); ++i) {
12+
dot += v1[i] * v2[i];
13+
}
14+
return dot;
15+
}
16+
17+
static float norm(const std::vector<float> & v) {
18+
return std::sqrt(dot_product(v, v));
19+
}
20+
21+
static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
22+
return dot_product(v1, v2) / (norm(v1) * norm(v2));
23+
}
24+
25+
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
26+
std::vector<std::vector<float>> result;
27+
28+
const llama_model * mdl = llama_get_model(ctx);
29+
30+
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
31+
32+
for (uint64_t i = 0; i < sentences.size(); i++) {
33+
llama_batch_clear(batch);
34+
35+
const std::string input_string = instruction + sentences[i];
36+
37+
std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
38+
39+
const int32_t n_toks = inputs.size();
40+
41+
// GritLM seems to have EOS = ""
42+
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
43+
// inputs.push_back(llama_token_eos(mdl));
44+
45+
// we want to ignore instruction tokens for mean pooling
46+
const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
47+
48+
#ifdef GRIT_DEBUG
49+
// debug tokens - should be matching as referenced in the GritLM sample
50+
std::for_each(inputs.begin(), inputs.end(), [&ctx](llama_token t) {
51+
std::printf("[%u:%s]", t, llama_token_to_piece(ctx, t).c_str());
52+
});
53+
std::printf("\n");
54+
#endif
55+
56+
// add input to batch (this increments n_tokens)
57+
for (int32_t j = 0; j < n_toks; j++) {
58+
llama_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst);
59+
}
60+
61+
// clear previous kv_cache values (irrelevant for embeddings)
62+
llama_kv_cache_clear(ctx);
63+
llama_set_causal_attn(ctx, false);
64+
65+
// run model
66+
llama_decode(ctx, batch);
67+
68+
// get embedding dimensions
69+
uint64_t n_embd = llama_n_embd(mdl);
70+
71+
// allocate embedding output
72+
std::vector<float> emb_unorm(n_embd, 0.0f);
73+
74+
// sum up all token embeddings
75+
for (int32_t k = n_inst; k < n_toks; k++) {
76+
float * emb = llama_get_embeddings_ith(ctx, k);
77+
for (uint64_t j = 0; j < n_embd; j++) {
78+
emb_unorm[j] += emb[j];
79+
}
80+
}
81+
82+
// divide by number of tokens (mean pooling)
83+
{
84+
const uint64_t n_sent = n_toks - n_inst;
85+
86+
for (uint64_t j = 0; j < n_embd; j++) {
87+
emb_unorm[j] /= n_sent;
88+
}
89+
}
90+
91+
std::vector<float> emb_norm(emb_unorm.size());
92+
llama_embd_normalize(emb_unorm.data(), emb_norm.data(), n_embd);
93+
result.push_back(emb_norm);
94+
95+
#ifdef GRIT_DEBUG
96+
// print out emb_norm
97+
std::printf("embedding %ld: ", i);
98+
for (uint64_t j = 0; j < n_embd; j++) {
99+
std::printf("%.5f ", emb_norm[j]);
100+
}
101+
std::printf("\n\n");
102+
#endif
103+
}
104+
105+
llama_batch_free(batch);
106+
107+
return result;
108+
}
109+
110+
static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
111+
std::string result;
112+
113+
const llama_model * mdl = llama_get_model(ctx);
114+
llama_token eos_token = llama_token_eos(mdl);
115+
116+
llama_kv_cache_clear(ctx);
117+
llama_set_causal_attn(ctx, true);
118+
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
119+
120+
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
121+
int32_t i_current_token = 0;
122+
123+
while (true) {
124+
llama_batch_clear(bat);
125+
auto n_inputs = (int32_t)inputs.size();
126+
for (int32_t i = 0; i < n_inputs; i++) {
127+
llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
128+
}
129+
inputs.clear();
130+
131+
llama_decode(ctx, bat);
132+
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
133+
134+
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
135+
auto n_candidates = (int32_t)candidates.size();
136+
for (int32_t token = 0; token < n_candidates; token++) {
137+
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
138+
}
139+
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
140+
141+
llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
142+
if (token == eos_token) {
143+
break;
144+
}
145+
146+
std::string piece = llama_token_to_piece(ctx, token);
147+
if (stream) {
148+
std::printf("%s", piece.c_str());
149+
std::fflush(stdout);
150+
}
151+
152+
inputs.push_back(token);
153+
154+
result += piece;
155+
}
156+
157+
if (stream) {
158+
std::printf("\n");
159+
}
160+
161+
llama_batch_free(bat);
162+
163+
return result;
164+
}
165+
166+
static std::string gritlm_instruction(const std::string & instruction) {
167+
return !instruction.empty() ? "<|user|>\n" + instruction + "\n<|embed|>\n" : "<|embed|>\n";
168+
}
169+
170+
int main(int argc, char * argv[]) {
171+
gpt_params params;
172+
if (!gpt_params_parse(argc, argv, params)) {
173+
return 1;
174+
}
175+
176+
llama_model_params mparams = llama_model_params_from_gpt_params(params);
177+
llama_context_params cparams = llama_context_params_from_gpt_params(params);
178+
179+
llama_backend_init();
180+
181+
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
182+
183+
// create new context - set to embedding mode
184+
cparams.embeddings = true;
185+
llama_context * ctx = llama_new_context_with_model(mdl, cparams);
186+
187+
// ### Embedding/Representation ###
188+
// samples taken from: https://github.com/ContextualAI/gritlm#basic
189+
{
190+
const std::string instruction = "Given a scientific paper title, retrieve the paper's abstract";
191+
192+
const std::vector<std::string> queries = {
193+
"Bitcoin: A Peer-to-Peer Electronic Cash System",
194+
"Generative Representational Instruction Tuning",
195+
};
196+
197+
const std::vector<std::string> documents = {
198+
"A purely peer-to-peer version of electronic cash would allow online payments to be sent directly from one party to another without going through a financial institution. Digital signatures provide part of the solution, but the main benefits are lost if a trusted third party is still required to prevent double-spending. We propose a solution to the double-spending problem using a peer-to-peer network. The network timestamps transactions by hashing them into an ongoing chain of hash-based proof-of-work, forming a record that cannot be changed without redoing the proof-of-work. The longest chain not only serves as proof of the sequence of events witnessed, but proof that it came from the largest pool of CPU power. As long as a majority of CPU power is controlled by nodes that are not cooperating to attack the network, they'll generate the longest chain and outpace attackers. The network itself requires minimal structure. Messages are broadcast on a best effort basis, and nodes can leave and rejoin the network at will, accepting the longest proof-of-work chain as proof of what happened while they were gone.",
199+
"All text-based language problems can be reduced to either generation or embedding. Current models only perform well at one or the other. We introduce generative representational instruction tuning (GRIT) whereby a large language model is trained to handle both generative and embedding tasks by distinguishing between them through instructions. Compared to other open models, our resulting GritLM 7B sets a new state of the art on the Massive Text Embedding Benchmark (MTEB) and outperforms all models up to its size on a range of generative tasks. By scaling up further, GritLM 8X7B outperforms all open generative language models that we tried while still being among the best embedding models. Notably, we find that GRIT matches training on only generative or embedding data, thus we can unify both at no performance loss. Among other benefits, the unification via GRIT speeds up Retrieval-Augmented Generation (RAG) by > 60% for long documents, by no longer requiring separate retrieval and generation models. Models, code, etc. are freely available at https://github.com/ContextualAI/gritlm.",
200+
};
201+
202+
// No need to add instruction for retrieval documents
203+
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
204+
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
205+
206+
const float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
207+
const float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
208+
const float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
209+
const float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);
210+
211+
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
212+
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
213+
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[0].c_str(), cosine_sim_q1_d0);
214+
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[1].c_str(), documents[1].c_str(), cosine_sim_q1_d1);
215+
}
216+
217+
// ### Generation ###
218+
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
219+
{
220+
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
221+
std::string response = generate(ctx, prompt, true);
222+
}
223+
224+
llama_free(ctx);
225+
llama_free_model(mdl);
226+
llama_backend_free();
227+
228+
return 0;
229+
}

llama.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,6 +1744,7 @@ struct llama_cparams {
17441744
float defrag_thold;
17451745

17461746
bool embeddings;
1747+
bool causal_attn;
17471748
bool offload_kqv;
17481749

17491750
enum llama_pooling_type pooling_type;
@@ -3939,6 +3940,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
39393940
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
39403941
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
39413942
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
3943+
LLAMA_LOG_INFO("%s: causal attm = %d\n", __func__, hparams.causal_attn);
39423944
LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type);
39433945
LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type);
39443946
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
@@ -8532,7 +8534,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
85328534
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
85338535
}
85348536

8535-
if (hparams.causal_attn) {
8537+
GGML_ASSERT(
8538+
(hparams.causal_attn || !cparams.causal_attn) &&
8539+
"non-causal attention with generative models is not supported"
8540+
);
8541+
8542+
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
8543+
if (cparams.causal_attn) {
85368544
const int64_t n_kv = kv_self.n;
85378545
const int64_t n_tokens = batch.n_tokens;
85388546

@@ -8560,8 +8568,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
85608568
}
85618569
}
85628570
} else {
8563-
// non-causal attention attends only the tokens within the batch (i.e. the KV cache is not used)
8571+
// when using kv cache, the mask needs to match the kv cache size
85648572
const int64_t n_tokens = batch.n_tokens;
8573+
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
85658574

85668575
assert(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
85678576

@@ -8580,7 +8589,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
85808589
}
85818590
}
85828591

8583-
data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
8592+
data[h*(n_tokens*n_tokens) + j*n_stride + i] = f;
8593+
}
8594+
8595+
for (int i = n_tokens; i < n_stride; ++i) {
8596+
data[h*(n_tokens*n_tokens) + j*n_stride + i] = -INFINITY;
85848597
}
85858598
}
85868599
}
@@ -12733,6 +12746,8 @@ struct llama_context * llama_new_context_with_model(
1273312746
cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_YARN ? 1.0f : 0.0f;
1273412747
}
1273512748

12749+
cparams.causal_attn = hparams.causal_attn;
12750+
1273612751
if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
1273712752
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
1273812753
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
@@ -13767,6 +13782,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
1376713782
ctx->abort_callback_data = abort_callback_data;
1376813783
}
1376913784

13785+
void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
13786+
ctx->cparams.causal_attn = causal_attn;
13787+
}
13788+
1377013789
struct llama_batch llama_batch_get_one(
1377113790
llama_token * tokens,
1377213791
int32_t n_tokens,

llama.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,10 @@ extern "C" {
643643
// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
644644
LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
645645

646+
// Set whether to use causal attention or not
647+
// If set to true, the model will only attend to the past tokens
648+
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
649+
646650
// Set abort callback
647651
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
648652

0 commit comments

Comments
 (0)