1
1
/// Payload tokenization logic
2
2
use crate :: TextEmbeddingsError ;
3
+ use std:: collections:: HashMap ;
3
4
use tokenizers:: tokenizer:: Tokenizer ;
4
5
pub use tokenizers:: Encoding as RawEncoding ;
5
6
use tokenizers:: { TruncationDirection , TruncationParams , TruncationStrategy } ;
@@ -19,6 +20,8 @@ impl Tokenization {
19
20
tokenizer : Tokenizer ,
20
21
max_input_length : usize ,
21
22
position_offset : usize ,
23
+ default_prompt : Option < String > ,
24
+ prompts : Option < HashMap < String , String > > ,
22
25
) -> Self {
23
26
tracing:: info!( "Starting {workers} tokenization workers" ) ;
24
27
@@ -29,12 +32,16 @@ impl Tokenization {
29
32
for _ in 0 ..workers {
30
33
let tokenizer_clone = tokenizer. clone ( ) ;
31
34
let receiver_clone = receiver. clone ( ) ;
35
+ let default_prompt_clone = default_prompt. clone ( ) ;
36
+ let prompts_clone = prompts. clone ( ) ;
32
37
// Spawn worker
33
38
std:: thread:: spawn ( move || {
34
39
tokenizer_worker (
35
40
tokenizer_clone,
36
41
max_input_length,
37
42
position_offset,
43
+ default_prompt_clone,
44
+ prompts_clone,
38
45
receiver_clone,
39
46
)
40
47
} ) ;
@@ -49,6 +56,7 @@ impl Tokenization {
49
56
inputs : EncodingInput ,
50
57
truncate : bool ,
51
58
truncation_direction : TruncationDirection ,
59
+ prompt_name : Option < String > ,
52
60
) -> Result < ValidEncoding , TextEmbeddingsError > {
53
61
// Check if inputs is empty
54
62
if inputs. is_empty ( ) {
@@ -66,6 +74,7 @@ impl Tokenization {
66
74
inputs,
67
75
truncate,
68
76
truncation_direction,
77
+ prompt_name,
69
78
response_sender,
70
79
Span :: current ( ) ,
71
80
) )
@@ -82,7 +91,8 @@ impl Tokenization {
82
91
& self ,
83
92
inputs : EncodingInput ,
84
93
add_special_tokens : bool ,
85
- ) -> Result < RawEncoding , TextEmbeddingsError > {
94
+ prompt_name : Option < String > ,
95
+ ) -> Result < ( Option < String > , RawEncoding ) , TextEmbeddingsError > {
86
96
// Check if inputs is empty
87
97
if inputs. is_empty ( ) {
88
98
return Err ( TextEmbeddingsError :: Validation (
@@ -98,6 +108,7 @@ impl Tokenization {
98
108
. send ( TokenizerRequest :: Tokenize (
99
109
inputs,
100
110
add_special_tokens,
111
+ prompt_name,
101
112
response_sender,
102
113
Span :: current ( ) ,
103
114
) )
@@ -147,6 +158,8 @@ fn tokenizer_worker(
147
158
mut tokenizer : Tokenizer ,
148
159
max_input_length : usize ,
149
160
position_offset : usize ,
161
+ default_prompt : Option < String > ,
162
+ prompts : Option < HashMap < String , String > > ,
150
163
receiver : async_channel:: Receiver < TokenizerRequest > ,
151
164
) {
152
165
// Loop over requests
@@ -156,11 +169,17 @@ fn tokenizer_worker(
156
169
inputs,
157
170
truncate,
158
171
truncation_direction,
172
+ prompt_name,
159
173
response_tx,
160
174
parent_span,
161
175
) => {
162
176
parent_span. in_scope ( || {
163
177
if !response_tx. is_closed ( ) {
178
+ let default_prompt_clone = match prompt_name {
179
+ None => default_prompt. clone ( ) ,
180
+ Some ( _) => None ,
181
+ } ;
182
+
164
183
// It's possible that the user dropped its request resulting in a send error.
165
184
// We just discard the error
166
185
let _ = response_tx. send ( encode_input (
@@ -169,20 +188,37 @@ fn tokenizer_worker(
169
188
truncation_direction,
170
189
max_input_length,
171
190
position_offset,
191
+ default_prompt_clone,
192
+ prompt_name,
193
+ prompts. as_ref ( ) ,
172
194
& mut tokenizer,
173
195
) ) ;
174
196
}
175
197
} )
176
198
}
177
- TokenizerRequest :: Tokenize ( inputs, add_special_tokens, response_tx, parent_span) => {
199
+ TokenizerRequest :: Tokenize (
200
+ inputs,
201
+ add_special_tokens,
202
+ prompt_name,
203
+ response_tx,
204
+ parent_span,
205
+ ) => {
178
206
parent_span. in_scope ( || {
179
207
if !response_tx. is_closed ( ) {
208
+ let default_prompt_clone = match prompt_name {
209
+ None => default_prompt. clone ( ) ,
210
+ Some ( _) => None ,
211
+ } ;
212
+
180
213
// It's possible that the user dropped its request resulting in a send error.
181
214
// We just discard the error
182
215
let _ = response_tx. send ( tokenize_input (
183
216
inputs,
184
217
add_special_tokens,
185
218
None ,
219
+ default_prompt_clone,
220
+ prompt_name,
221
+ prompts. as_ref ( ) ,
186
222
& mut tokenizer,
187
223
) ) ;
188
224
}
@@ -212,40 +248,104 @@ fn decode_ids(
212
248
. decode ( & ids, skip_special_tokens) ?)
213
249
}
214
250
251
+ fn prepare_pre_prompt (
252
+ default_prompt : Option < String > ,
253
+ prompt_name : Option < String > ,
254
+ prompts : Option < & HashMap < String , String > > ,
255
+ ) -> Result < Option < String > , TextEmbeddingsError > {
256
+ let pre_prompt = if let Some ( prompt_name) = prompt_name. as_ref ( ) {
257
+ match prompts {
258
+ None => {
259
+ return Err ( TextEmbeddingsError :: Validation ( format ! ( "`default-prompt-name` is set to `{prompt_name}` but no prompts were found in the Sentence Transformers configuration" ) ) ) ;
260
+ }
261
+ Some ( prompts) if !prompts. contains_key ( prompt_name) => {
262
+ return Err ( TextEmbeddingsError :: Validation ( format ! ( "`default-prompt-name` is set to `{prompt_name}` but it was not found in the Sentence Transformers prompts. Available prompts: {:?}" , prompts. keys( ) ) ) ) ;
263
+ }
264
+ Some ( prompts) => prompts. get ( prompt_name) . cloned ( ) ,
265
+ }
266
+ } else {
267
+ default_prompt
268
+ } ;
269
+ Ok ( pre_prompt)
270
+ }
271
+
215
272
fn tokenize_input (
216
273
inputs : EncodingInput ,
217
274
add_special_tokens : bool ,
218
275
truncate_params : Option < TruncationParams > ,
276
+ default_prompt : Option < String > ,
277
+ prompt_name : Option < String > ,
278
+ prompts : Option < & HashMap < String , String > > ,
219
279
tokenizer : & mut Tokenizer ,
220
- ) -> Result < RawEncoding , TextEmbeddingsError > {
280
+ ) -> Result < ( Option < String > , RawEncoding ) , TextEmbeddingsError > {
281
+ let pre_prompt = prepare_pre_prompt ( default_prompt, prompt_name, prompts) ?;
282
+
221
283
let encoding = match inputs {
222
284
// encode input
223
- EncodingInput :: Single ( s) => tokenizer
224
- . with_truncation ( truncate_params) ?
225
- . encode :: < String > ( s, add_special_tokens) ?,
226
- EncodingInput :: Dual ( s1, s2) => {
227
- tokenizer
285
+ EncodingInput :: Single ( s) => {
286
+ let s = if let Some ( mut pre_prompt) = pre_prompt {
287
+ pre_prompt. push_str ( & s) ;
288
+ pre_prompt
289
+ } else {
290
+ s
291
+ } ;
292
+
293
+ let encoding = tokenizer
228
294
. with_truncation ( truncate_params) ?
229
- . encode :: < ( String , String ) > ( ( s1, s2) , add_special_tokens) ?
295
+ . encode :: < & str > ( & s, add_special_tokens) ?;
296
+
297
+ ( Some ( s) , encoding)
298
+ }
299
+ EncodingInput :: Dual ( s1, s2) => {
300
+ if pre_prompt. is_some ( ) {
301
+ return Err ( TextEmbeddingsError :: Validation (
302
+ "`prompt_name` cannot be set with dual inputs" . to_string ( ) ,
303
+ ) ) ;
304
+ }
305
+
306
+ (
307
+ None ,
308
+ tokenizer
309
+ . with_truncation ( truncate_params) ?
310
+ . encode :: < ( String , String ) > ( ( s1, s2) , add_special_tokens) ?,
311
+ )
230
312
}
231
313
// input is encoded -> convert to tokenizers Encoding
232
314
EncodingInput :: Ids ( ids) => {
233
- let text = tokenizer. decode ( & ids, false ) ?;
234
- tokenizer
235
- . with_truncation ( truncate_params) ?
236
- . encode :: < String > ( text, false ) ?
315
+ if let Some ( mut pre_prompt) = pre_prompt {
316
+ let text = tokenizer. decode ( & ids, true ) ?;
317
+ pre_prompt. push_str ( & text) ;
318
+
319
+ let encoding = tokenizer
320
+ . with_truncation ( truncate_params) ?
321
+ . encode :: < & str > ( & pre_prompt, true ) ?;
322
+
323
+ ( Some ( pre_prompt) , encoding)
324
+ } else {
325
+ let text = tokenizer. decode ( & ids, false ) ?;
326
+
327
+ let encoding = tokenizer
328
+ . with_truncation ( truncate_params) ?
329
+ . encode :: < & str > ( & text, false ) ?;
330
+
331
+ ( Some ( text) , encoding)
332
+ }
237
333
}
238
334
} ;
239
335
Ok ( encoding)
240
336
}
241
337
242
338
/// Get input length and optionally truncate it
339
+ #[ allow( clippy:: too_many_arguments) ]
243
340
fn encode_input (
244
341
inputs : EncodingInput ,
245
342
truncate : bool ,
246
343
truncation_direction : TruncationDirection ,
247
344
max_input_length : usize ,
248
345
position_offset : usize ,
346
+ default_prompt : Option < String > ,
347
+ prompt_name : Option < String > ,
348
+ prompts : Option < & HashMap < String , String > > ,
249
349
tokenizer : & mut Tokenizer ,
250
350
) -> Result < ValidEncoding , TextEmbeddingsError > {
251
351
// Default truncation params
@@ -256,7 +356,15 @@ fn encode_input(
256
356
stride : 0 ,
257
357
} ) ;
258
358
259
- let encoding = tokenize_input ( inputs, true , truncate_params, tokenizer) ?;
359
+ let ( _, encoding) = tokenize_input (
360
+ inputs,
361
+ true ,
362
+ truncate_params,
363
+ default_prompt,
364
+ prompt_name,
365
+ prompts,
366
+ tokenizer,
367
+ ) ?;
260
368
let seq_len = encoding. len ( ) ;
261
369
262
370
if seq_len > max_input_length {
@@ -315,13 +423,15 @@ enum TokenizerRequest {
315
423
EncodingInput ,
316
424
bool ,
317
425
TruncationDirection ,
426
+ Option < String > ,
318
427
oneshot:: Sender < Result < ValidEncoding , TextEmbeddingsError > > ,
319
428
Span ,
320
429
) ,
321
430
Tokenize (
322
431
EncodingInput ,
323
432
bool ,
324
- oneshot:: Sender < Result < RawEncoding , TextEmbeddingsError > > ,
433
+ Option < String > ,
434
+ oneshot:: Sender < Result < ( Option < String > , RawEncoding ) , TextEmbeddingsError > > ,
325
435
Span ,
326
436
) ,
327
437
Decode (
0 commit comments