
Commit 5ac58d9

fairydreaming authored and sszymczy committed
Add support for encoder-only T5 models (ggml-org#8900)
* gguf-py : add T5ENCODER model architecture
* common : call llama_decode() during warmup only if the model has decoder
* convert-hf : add T5EncoderModel
* llama : add llama_model_has_decoder() API function
* llama : split build_t5() into build_t5_encoder() and build_t5_decoder()
* llama : add support for LLM_ARCH_T5ENCODER
* llama-embedding : add support for LLAMA_POOLING_TYPE_NONE
* llama-embedding : add support for encoder-only models

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent 8560c77 commit 5ac58d9

6 files changed: +649 additions, -282 deletions

common/common.cpp

Lines changed: 3 additions & 1 deletion
@@ -2156,7 +2156,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
             tmp.clear();
             tmp.push_back(decoder_start_token_id);
         }
-        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+        if (llama_model_has_decoder(model)) {
+            llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+        }
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
         llama_reset_timings(lctx);
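A minimal sketch of the warmup logic this guard enables, assuming the surrounding code (not shown in the hunk) already runs the encoder pass for models that have one. The helper name warmup_model is made up for illustration; the llama.h calls it uses are the ones touched by this commit:

// Hypothetical helper sketching the dispatch idea behind the hunk above.
// Error handling and timing calls are omitted.
#include "llama.h"

#include <algorithm>
#include <vector>

static void warmup_model(llama_context * lctx, const llama_model * model, size_t n_batch) {
    std::vector<llama_token> tmp = { llama_token_bos(model) };

    if (llama_model_has_encoder(model)) {
        // T5-style models run the encoder first
        llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
    }
    if (llama_model_has_decoder(model)) {
        // decoder-only and encoder-decoder models also run a decode step;
        // encoder-only models (the new T5ENCODER arch) skip it entirely
        llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), n_batch), 0, 0));
    }

    llama_kv_cache_clear(lctx);
    llama_synchronize(lctx);
}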

convert_hf_to_gguf.py

Lines changed: 139 additions & 0 deletions
@@ -3324,6 +3324,145 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@Model.register("T5EncoderModel")
+class T5EncoderModel(Model):
+    model_arch = gguf.MODEL_ARCH.T5ENCODER
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.shared_token_embeddings_found = False
+
+    def set_vocab(self):
+        # to avoid TypeError: Descriptors cannot be created directly
+        # exception when importing sentencepiece_model_pb2
+        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+        from sentencepiece import SentencePieceProcessor
+        from sentencepiece import sentencepiece_model_pb2 as model
+
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        # many older models use spiece.model tokenizer model filename
+        if not tokenizer_path.is_file():
+            tokenizer_path = self.dir_model / 'spiece.model'
+
+        if not tokenizer_path.is_file():
+            raise FileNotFoundError(f"File not found: {tokenizer_path}")
+
+        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
+        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+
+        # some models like Pile-T5 family use BPE tokenizer instead of Unigram
+        if sentencepiece_model.trainer_spec.model_type == 2:  # BPE
+            # assure the tokenizer model file name is correct
+            assert tokenizer_path.name == 'tokenizer.model'
+            return self._set_vocab_sentencepiece()
+        else:
+            assert sentencepiece_model.trainer_spec.model_type == 1  # UNIGRAM
+
+        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+        remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+        precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
+
+        tokenizer = SentencePieceProcessor()
+        tokenizer.LoadFromFile(str(tokenizer_path))
+
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
+
+        for token_id in range(tokenizer.vocab_size()):
+            piece = tokenizer.IdToPiece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.GetScore(token_id)
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.IsUnknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.IsControl(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.IsUnused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.IsByte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype
+
+        added_tokens_file = self.dir_model / 'added_tokens.json'
+        if added_tokens_file.is_file():
+            with open(added_tokens_file, "r", encoding="utf-8") as f:
+                added_tokens_json = json.load(f)
+                for key in added_tokens_json:
+                    token_id = added_tokens_json[key]
+                    if token_id >= vocab_size:
+                        logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                        continue
+
+                    tokens[token_id] = key.encode("utf-8")
+                    scores[token_id] = -1000.0
+                    toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
+
+        if vocab_size > len(tokens):
+            pad_count = vocab_size - len(tokens)
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+            for i in range(1, pad_count + 1):
+                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(SentencePieceTokenTypes.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("t5")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_space_prefix(add_prefix)
+        self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
+        if precompiled_charsmap:
+            self.gguf_writer.add_precompiled_charsmap(precompiled_charsmap)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+        self.gguf_writer.add_add_bos_token(False)
+        self.gguf_writer.add_add_eos_token(True)
+
+    def set_gguf_parameters(self):
+        if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
+            logger.warning("Couldn't find context length in config.json, assuming default value of 512")
+            n_ctx = 512
+        self.gguf_writer.add_context_length(n_ctx)
+        self.gguf_writer.add_embedding_length(self.hparams["d_model"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
+        self.gguf_writer.add_block_count(self.hparams["num_layers"])
+        self.gguf_writer.add_head_count(self.hparams["num_heads"])
+        self.gguf_writer.add_key_length(self.hparams["d_kv"])
+        self.gguf_writer.add_value_length(self.hparams["d_kv"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
+        # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
+        # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
+        # and decoder and ignore the remaining ones.
+        if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
+            if not self.shared_token_embeddings_found:
+                name = "shared.weight"
+                self.shared_token_embeddings_found = True
+            else:
+                logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @Model.register("JAISLMHeadModel")
 class JaisModel(Model):
     model_arch = gguf.MODEL_ARCH.JAIS
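With the new T5EncoderModel class registered, a converted checkpoint is expected to load as an encoder-only model. A hedged sanity-check sketch, where "t5-encoder.gguf" is a placeholder path for the output of the converter rather than a file produced by this commit:

// Sketch: verify that a converted encoder-only GGUF reports the expected capabilities.
#include "llama.h"
#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("t5-encoder.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    // for a T5ENCODER conversion this is expected to print: encoder=1 decoder=0
    printf("encoder=%d decoder=%d\n",
           llama_model_has_encoder(model) ? 1 : 0,
           llama_model_has_decoder(model) ? 1 : 0);

    llama_free_model(model);
    llama_backend_free();
    return 0;
}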

examples/embedding/embedding.cpp

Lines changed: 98 additions & 42 deletions
@@ -31,25 +31,47 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
 }
 
 static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
+    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+    const struct llama_model * model = llama_get_model(ctx);
+
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
 
     // run model
     fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
-    if (llama_decode(ctx, batch) < 0) {
-        fprintf(stderr, "%s : failed to decode\n", __func__);
+    if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
+        // encoder-only model
+        if (llama_encode(ctx, batch) < 0) {
+            fprintf(stderr, "%s : failed to encode\n", __func__);
+        }
+    } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
+        // decoder-only model
+        if (llama_decode(ctx, batch) < 0) {
+            fprintf(stderr, "%s : failed to decode\n", __func__);
+        }
     }
 
     for (int i = 0; i < batch.n_tokens; i++) {
         if (!batch.logits[i]) {
             continue;
         }
 
-        // try to get sequence embeddings - supported only when pooling_type is not NONE
-        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+        const float * embd = nullptr;
+        int embd_pos = 0;
+
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            // try to get token embeddings
+            embd = llama_get_embeddings_ith(ctx, i);
+            embd_pos = i;
+            GGML_ASSERT(embd != NULL && "failed to get token embeddings");
+        } else {
+            // try to get sequence embeddings - supported only when pooling_type is not NONE
+            embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            embd_pos = batch.seq_id[i][0];
+            GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
+        }
 
-        float * out = output + batch.seq_id[i][0] * n_embd;
+        float * out = output + embd_pos * n_embd;
         llama_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }

@@ -93,8 +115,9 @@ int main(int argc, char ** argv) {
     const int n_ctx = llama_n_ctx(ctx);
 
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
-    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
-        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+
+    if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
+        fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__);
         return 1;
     }
 

@@ -153,13 +176,23 @@ int main(int argc, char ** argv) {
     const int n_prompts = prompts.size();
     struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
+    // count number of embeddings
+    int n_embd_count = 0;
+    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+        for (int k = 0; k < n_prompts; k++) {
+            n_embd_count += inputs[k].size();
+        }
+    } else {
+        n_embd_count = n_prompts;
+    }
+
     // allocate output
     const int n_embd = llama_n_embd(model);
-    std::vector<float> embeddings(n_prompts * n_embd, 0);
+    std::vector<float> embeddings(n_embd_count * n_embd, 0);
     float * emb = embeddings.data();
 
     // break into batches
-    int p = 0; // number of prompts processed already
+    int e = 0; // number of embeddings already stored
     int s = 0; // number of prompts in current batch
     for (int k = 0; k < n_prompts; k++) {
         // clamp to n_batch tokens

@@ -169,11 +202,11 @@ int main(int argc, char ** argv) {
 
         // encode if at capacity
        if (batch.n_tokens + n_toks > n_batch) {
-            float * out = emb + p * n_embd;
+            float * out = emb + e * n_embd;
             batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
-            llama_batch_clear(batch);
-            p += s;
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
             s = 0;
+            llama_batch_clear(batch);
         }
 
         // add to batch

@@ -182,39 +215,62 @@ int main(int argc, char ** argv) {
     }
 
     // final batch
-    float * out = emb + p * n_embd;
+    float * out = emb + e * n_embd;
     batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
 
     if (params.embd_out.empty()) {
-        // print the first part of the embeddings or for a single prompt, the full embedding
         fprintf(stdout, "\n");
-        for (int j = 0; j < n_prompts; j++) {
-            fprintf(stdout, "embedding %d: ", j);
-            for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
-                if (params.embd_normalize == 0) {
-                    fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
-                } else {
-                    fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+
+        if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+            for (int j = 0; j < n_embd_count; j++) {
+                fprintf(stdout, "embedding %d: ", j);
+                for (int i = 0; i < std::min(3, n_embd); i++) {
+                    if (params.embd_normalize == 0) {
+                        fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                    }
+                }
+                fprintf(stdout, " ... ");
+                for (int i = n_embd - 3; i < n_embd; i++) {
+                    if (params.embd_normalize == 0) {
+                        fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                    }
                 }
+                fprintf(stdout, "\n");
             }
-            fprintf(stdout, "\n");
-        }
-
-        // print cosine similarity matrix
-        if (n_prompts > 1) {
-            fprintf(stdout, "\n");
-            printf("cosine similarity matrix:\n\n");
-            for (int i = 0; i < n_prompts; i++) {
-                fprintf(stdout, "%6.6s ", prompts[i].c_str());
+        } else {
+            // print the first part of the embeddings or for a single prompt, the full embedding
+            for (int j = 0; j < n_prompts; j++) {
+                fprintf(stdout, "embedding %d: ", j);
+                for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
+                    if (params.embd_normalize == 0) {
+                        fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                    } else {
+                        fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                    }
+                }
+                fprintf(stdout, "\n");
            }
-            fprintf(stdout, "\n");
-            for (int i = 0; i < n_prompts; i++) {
-                for (int j = 0; j < n_prompts; j++) {
-                    float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
-                    fprintf(stdout, "%6.2f ", sim);
+
+            // print cosine similarity matrix
+            if (n_prompts > 1) {
+                fprintf(stdout, "\n");
+                printf("cosine similarity matrix:\n\n");
+                for (int i = 0; i < n_prompts; i++) {
+                    fprintf(stdout, "%6.6s ", prompts[i].c_str());
                 }
-                fprintf(stdout, "%1.10s", prompts[i].c_str());
                 fprintf(stdout, "\n");
+                for (int i = 0; i < n_prompts; i++) {
+                    for (int j = 0; j < n_prompts; j++) {
+                        float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                        fprintf(stdout, "%6.2f ", sim);
+                    }
+                    fprintf(stdout, "%1.10s", prompts[i].c_str());
+                    fprintf(stdout, "\n");
+                }
             }
         }
     }

@@ -233,23 +289,23 @@ int main(int argc, char ** argv) {
             }
             fprintf(stdout, notArray ? "]\n }" : "]");
             j++;
-            if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
+            if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break;
         }
         fprintf(stdout, notArray ? "\n ]" : "]\n");
 
         if (params.embd_out == "json+" && n_prompts > 1) {
             fprintf(stdout, ",\n \"cosineSimilarity\": [\n");
-            for (int i = 0;;) { // at least two iteration (n_prompts > 1)
+            for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
                 fprintf(stdout, " [");
-                for (int j = 0;;) { // at least two iteration (n_prompts > 1)
+                for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
                     float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
                     fprintf(stdout, "%6.2f", sim);
                     j++;
-                    if (j < n_prompts) fprintf(stdout, ", "); else break;
+                    if (j < n_embd_count) fprintf(stdout, ", "); else break;
                 }
                 fprintf(stdout, " ]");
                 i++;
-                if (i < n_prompts) fprintf(stdout, ",\n"); else break;
+                if (i < n_embd_count) fprintf(stdout, ",\n"); else break;
             }
             fprintf(stdout, "\n ]");
         }
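Taken together, the embedding changes allow unpooled, per-token embeddings from an encoder-only model. A hedged end-to-end sketch of that call sequence, where "t5-encoder.gguf" is a placeholder path and the single-BOS "prompt" stands in for real tokenization with llama_tokenize():

// Sketch of the LLAMA_POOLING_TYPE_NONE path for an encoder-only model.
#include "llama.h"

#include <cstdio>
#include <vector>

int main() {
    llama_backend_init();

    llama_model * model = llama_load_model_from_file("t5-encoder.gguf", llama_model_default_params()); // placeholder

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                       // we want embeddings, not logits
    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;    // keep one embedding per token
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // a single-token "prompt"; real code would tokenize actual input
    std::vector<llama_token> tokens = { llama_token_bos(model) };

    // build a batch that requests output for every token
    llama_batch batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
    for (size_t i = 0; i < tokens.size(); i++) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = (llama_pos) i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = true;
    }
    batch.n_tokens = (int32_t) tokens.size();

    // encoder-only models go through llama_encode(); others keep using llama_decode()
    if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
        llama_encode(ctx, batch);
    } else {
        llama_decode(ctx, batch);
    }

    const int n_embd = llama_n_embd(model);
    for (int i = 0; i < batch.n_tokens; i++) {
        const float * embd = llama_get_embeddings_ith(ctx, i);  // token-level embedding
        printf("token %d: first dim = %f (n_embd = %d)\n", i, embd ? embd[0] : 0.0f, n_embd);
    }

    llama_batch_free(batch);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}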
