Commit 0361479

feat: add default prompts (#312)
1 parent 35aefeb commit 0361479

File tree: 12 files changed, +545 −449 lines

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default.

core/src/download.rs

Lines changed: 7 additions & 0 deletions
```diff
@@ -102,3 +102,10 @@ pub async fn download_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
 
     Err(err)
 }
+
+#[instrument(skip_all)]
+pub async fn download_new_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
+    tracing::info!("Downloading `config_sentence_transformers.json`");
+    let pool_config_path = api.get("config_sentence_transformers.json").await?;
+    Ok(pool_config_path)
+}
```
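
`download_new_st_config` fetches the `config_sentence_transformers.json` that Sentence Transformers models publish; by that library's convention the file can carry a `prompts` map and a `default_prompt_name`. A minimal sketch of deserializing those two fields (the struct and helper are illustrative, not part of this commit, and `serde`/`serde_json` are assumed as dependencies):

```rust
use std::collections::HashMap;
use std::path::Path;

// Illustrative only: field names follow the sentence-transformers convention
// for `config_sentence_transformers.json`; unknown fields are ignored.
#[derive(Debug, serde::Deserialize)]
struct NewSTConfig {
    #[serde(default)]
    prompts: Option<HashMap<String, String>>,
    #[serde(default)]
    default_prompt_name: Option<String>,
}

fn load_new_st_config(path: &Path) -> Result<NewSTConfig, Box<dyn std::error::Error>> {
    let raw = std::fs::read_to_string(path)?;
    Ok(serde_json::from_str(&raw)?)
}
```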

core/src/infer.rs

Lines changed: 13 additions & 4 deletions
```diff
@@ -60,9 +60,10 @@ impl Infer {
         &self,
         inputs: I,
         add_special_tokens: bool,
-    ) -> Result<RawEncoding, TextEmbeddingsError> {
+        prompt_name: Option<String>,
+    ) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
         self.tokenization
-            .tokenize(inputs.into(), add_special_tokens)
+            .tokenize(inputs.into(), add_special_tokens, prompt_name)
             .await
             .map_err(|err| {
                 let counter = metrics::counter!("te_request_failure", "err" => "tokenization");
@@ -119,6 +120,7 @@ impl Infer {
         inputs: I,
         truncate: bool,
         truncation_direction: TruncationDirection,
+        prompt_name: Option<String>,
         permit: OwnedSemaphorePermit,
     ) -> Result<AllEmbeddingsInferResponse, TextEmbeddingsError> {
         let start_time = Instant::now();
@@ -138,6 +140,7 @@ impl Infer {
                 inputs,
                 truncate,
                 truncation_direction,
+                prompt_name,
                 false,
                 &start_time,
                 permit,
@@ -172,6 +175,7 @@ impl Infer {
         inputs: I,
         truncate: bool,
         truncation_direction: TruncationDirection,
+        prompt_name: Option<String>,
         permit: OwnedSemaphorePermit,
     ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
         let start_time = Instant::now();
@@ -191,6 +195,7 @@ impl Infer {
                 inputs,
                 truncate,
                 truncation_direction,
+                prompt_name,
                 true,
                 &start_time,
                 permit,
@@ -225,6 +230,7 @@ impl Infer {
         inputs: I,
         truncate: bool,
         truncation_direction: TruncationDirection,
+        prompt_name: Option<String>,
         normalize: bool,
         permit: OwnedSemaphorePermit,
     ) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
@@ -245,6 +251,7 @@ impl Infer {
                 inputs,
                 truncate,
                 truncation_direction,
+                prompt_name,
                 true,
                 &start_time,
                 permit,
@@ -290,11 +297,13 @@ impl Infer {
         Ok(response)
     }
 
+    #[allow(clippy::too_many_arguments)]
     async fn embed<I: Into<EncodingInput> + std::fmt::Debug>(
         &self,
         inputs: I,
         truncate: bool,
         truncation_direction: TruncationDirection,
+        prompt_name: Option<String>,
         pooling: bool,
         start_time: &Instant,
         _permit: OwnedSemaphorePermit,
@@ -315,7 +324,7 @@ impl Infer {
         // Tokenization
         let encoding = self
             .tokenization
-            .encode(inputs.into(), truncate, truncation_direction)
+            .encode(inputs.into(), truncate, truncation_direction, prompt_name)
             .await
             .map_err(|err| {
                 let counter = metrics::counter!("te_request_failure", "err" => "tokenization");
@@ -381,7 +390,7 @@ impl Infer {
         // Tokenization
         let encoding = self
             .tokenization
-            .encode(inputs.into(), truncate, truncation_direction)
+            .encode(inputs.into(), truncate, truncation_direction, None)
             .await
             .map_err(|err| {
                 let counter = metrics::counter!("te_request_failure", "err" => "tokenization");
```
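
Every public embedding path gains an optional `prompt_name` that is forwarded into tokenization; the raw-tokenization path in the last hunk passes `None` explicitly. A hypothetical call shape for the pooled path, with the method name, `infer`, and `permit` taken as illustrative assumptions rather than from this diff:

```rust
// Hypothetical call site: select the Sentence Transformers "query" prompt
// for this request. `infer`, `permit`, and the method name are assumptions.
let response = infer
    .embed_pooled(
        "What is Deep Learning?",
        true,                       // truncate
        TruncationDirection::Right,
        Some("query".to_string()),  // prompt_name: key into the prompts map
        true,                       // normalize
        permit,
    )
    .await?;
```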

core/src/tokenization.rs

Lines changed: 125 additions & 15 deletions
```diff
@@ -1,5 +1,6 @@
 /// Payload tokenization logic
 use crate::TextEmbeddingsError;
+use std::collections::HashMap;
 use tokenizers::tokenizer::Tokenizer;
 pub use tokenizers::Encoding as RawEncoding;
 use tokenizers::{TruncationDirection, TruncationParams, TruncationStrategy};
@@ -19,6 +20,8 @@ impl Tokenization {
         tokenizer: Tokenizer,
         max_input_length: usize,
         position_offset: usize,
+        default_prompt: Option<String>,
+        prompts: Option<HashMap<String, String>>,
     ) -> Self {
         tracing::info!("Starting {workers} tokenization workers");
 
@@ -29,12 +32,16 @@ impl Tokenization {
         for _ in 0..workers {
             let tokenizer_clone = tokenizer.clone();
             let receiver_clone = receiver.clone();
+            let default_prompt_clone = default_prompt.clone();
+            let prompts_clone = prompts.clone();
             // Spawn worker
             std::thread::spawn(move || {
                 tokenizer_worker(
                     tokenizer_clone,
                     max_input_length,
                     position_offset,
+                    default_prompt_clone,
+                    prompts_clone,
                     receiver_clone,
                 )
             });
@@ -49,6 +56,7 @@ impl Tokenization {
         inputs: EncodingInput,
         truncate: bool,
         truncation_direction: TruncationDirection,
+        prompt_name: Option<String>,
     ) -> Result<ValidEncoding, TextEmbeddingsError> {
         // Check if inputs is empty
         if inputs.is_empty() {
@@ -66,6 +74,7 @@ impl Tokenization {
                 inputs,
                 truncate,
                 truncation_direction,
+                prompt_name,
                 response_sender,
                 Span::current(),
             ))
@@ -82,7 +91,8 @@ impl Tokenization {
         &self,
         inputs: EncodingInput,
         add_special_tokens: bool,
-    ) -> Result<RawEncoding, TextEmbeddingsError> {
+        prompt_name: Option<String>,
+    ) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
         // Check if inputs is empty
         if inputs.is_empty() {
             return Err(TextEmbeddingsError::Validation(
@@ -98,6 +108,7 @@ impl Tokenization {
             .send(TokenizerRequest::Tokenize(
                 inputs,
                 add_special_tokens,
+                prompt_name,
                 response_sender,
                 Span::current(),
             ))
@@ -147,6 +158,8 @@ fn tokenizer_worker(
     mut tokenizer: Tokenizer,
     max_input_length: usize,
     position_offset: usize,
+    default_prompt: Option<String>,
+    prompts: Option<HashMap<String, String>>,
     receiver: async_channel::Receiver<TokenizerRequest>,
 ) {
     // Loop over requests
@@ -156,11 +169,17 @@ fn tokenizer_worker(
                 inputs,
                 truncate,
                 truncation_direction,
+                prompt_name,
                 response_tx,
                 parent_span,
             ) => {
                 parent_span.in_scope(|| {
                     if !response_tx.is_closed() {
+                        let default_prompt_clone = match prompt_name {
+                            None => default_prompt.clone(),
+                            Some(_) => None,
+                        };
+
                         // It's possible that the user dropped its request resulting in a send error.
                         // We just discard the error
                         let _ = response_tx.send(encode_input(
@@ -169,20 +188,37 @@ fn tokenizer_worker(
                             truncation_direction,
                             max_input_length,
                             position_offset,
+                            default_prompt_clone,
+                            prompt_name,
+                            prompts.as_ref(),
                             &mut tokenizer,
                         ));
                     }
                 })
             }
-            TokenizerRequest::Tokenize(inputs, add_special_tokens, response_tx, parent_span) => {
+            TokenizerRequest::Tokenize(
+                inputs,
+                add_special_tokens,
+                prompt_name,
+                response_tx,
+                parent_span,
+            ) => {
                 parent_span.in_scope(|| {
                     if !response_tx.is_closed() {
+                        let default_prompt_clone = match prompt_name {
+                            None => default_prompt.clone(),
+                            Some(_) => None,
+                        };
+
                         // It's possible that the user dropped its request resulting in a send error.
                         // We just discard the error
                         let _ = response_tx.send(tokenize_input(
                             inputs,
                             add_special_tokens,
                             None,
+                            default_prompt_clone,
+                            prompt_name,
+                            prompts.as_ref(),
                             &mut tokenizer,
                         ));
                     }
```
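
Note the precedence encoded in both worker arms: when a request carries an explicit `prompt_name`, the worker passes `None` as the default prompt, so a configured default can never stack on top of a named prompt; resolving the name against the prompts map happens later, in `prepare_pre_prompt` below.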
```diff
@@ -212,40 +248,104 @@ fn decode_ids(
         .decode(&ids, skip_special_tokens)?)
 }
 
+fn prepare_pre_prompt(
+    default_prompt: Option<String>,
+    prompt_name: Option<String>,
+    prompts: Option<&HashMap<String, String>>,
+) -> Result<Option<String>, TextEmbeddingsError> {
+    let pre_prompt = if let Some(prompt_name) = prompt_name.as_ref() {
+        match prompts {
+            None => {
+                return Err(TextEmbeddingsError::Validation(format!("`default-prompt-name` is set to `{prompt_name}` but no prompts were found in the Sentence Transformers configuration")));
+            }
+            Some(prompts) if !prompts.contains_key(prompt_name) => {
+                return Err(TextEmbeddingsError::Validation(format!("`default-prompt-name` is set to `{prompt_name}` but it was not found in the Sentence Transformers prompts. Available prompts: {:?}", prompts.keys())));
+            }
+            Some(prompts) => prompts.get(prompt_name).cloned(),
+        }
+    } else {
+        default_prompt
+    };
+    Ok(pre_prompt)
+}
+
 fn tokenize_input(
     inputs: EncodingInput,
     add_special_tokens: bool,
     truncate_params: Option<TruncationParams>,
+    default_prompt: Option<String>,
+    prompt_name: Option<String>,
+    prompts: Option<&HashMap<String, String>>,
     tokenizer: &mut Tokenizer,
-) -> Result<RawEncoding, TextEmbeddingsError> {
+) -> Result<(Option<String>, RawEncoding), TextEmbeddingsError> {
+    let pre_prompt = prepare_pre_prompt(default_prompt, prompt_name, prompts)?;
+
     let encoding = match inputs {
         // encode input
-        EncodingInput::Single(s) => tokenizer
-            .with_truncation(truncate_params)?
-            .encode::<String>(s, add_special_tokens)?,
-        EncodingInput::Dual(s1, s2) => {
-            tokenizer
+        EncodingInput::Single(s) => {
+            let s = if let Some(mut pre_prompt) = pre_prompt {
+                pre_prompt.push_str(&s);
+                pre_prompt
+            } else {
+                s
+            };
+
+            let encoding = tokenizer
                 .with_truncation(truncate_params)?
-                .encode::<(String, String)>((s1, s2), add_special_tokens)?
+                .encode::<&str>(&s, add_special_tokens)?;
+
+            (Some(s), encoding)
+        }
+        EncodingInput::Dual(s1, s2) => {
+            if pre_prompt.is_some() {
+                return Err(TextEmbeddingsError::Validation(
+                    "`prompt_name` cannot be set with dual inputs".to_string(),
+                ));
+            }
+
+            (
+                None,
+                tokenizer
+                    .with_truncation(truncate_params)?
+                    .encode::<(String, String)>((s1, s2), add_special_tokens)?,
+            )
         }
         // input is encoded -> convert to tokenizers Encoding
         EncodingInput::Ids(ids) => {
-            let text = tokenizer.decode(&ids, false)?;
-            tokenizer
-                .with_truncation(truncate_params)?
-                .encode::<String>(text, false)?
+            if let Some(mut pre_prompt) = pre_prompt {
+                let text = tokenizer.decode(&ids, true)?;
+                pre_prompt.push_str(&text);
+
+                let encoding = tokenizer
+                    .with_truncation(truncate_params)?
+                    .encode::<&str>(&pre_prompt, true)?;
+
+                (Some(pre_prompt), encoding)
+            } else {
+                let text = tokenizer.decode(&ids, false)?;
+
+                let encoding = tokenizer
+                    .with_truncation(truncate_params)?
+                    .encode::<&str>(&text, false)?;
+
+                (Some(text), encoding)
+            }
         }
     };
     Ok(encoding)
 }
 
 /// Get input length and optionally truncate it
+#[allow(clippy::too_many_arguments)]
 fn encode_input(
     inputs: EncodingInput,
     truncate: bool,
     truncation_direction: TruncationDirection,
     max_input_length: usize,
     position_offset: usize,
+    default_prompt: Option<String>,
+    prompt_name: Option<String>,
+    prompts: Option<&HashMap<String, String>>,
     tokenizer: &mut Tokenizer,
 ) -> Result<ValidEncoding, TextEmbeddingsError> {
     // Default truncation params
```
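
`prepare_pre_prompt` gives an explicit `prompt_name` priority over the default prompt and rejects names that cannot be resolved. A small self-contained demo of that precedence and of the prepending that follows (a standalone sketch with the error cases elided, not the crate's API):

```rust
use std::collections::HashMap;

// Simplified resolution: an explicit name wins; otherwise the default applies.
// Unlike prepare_pre_prompt, missing-name errors are elided in this sketch.
fn resolve(
    default: Option<&str>,
    name: Option<&str>,
    prompts: &HashMap<String, String>,
) -> Option<String> {
    match name {
        Some(n) => prompts.get(n).cloned(),
        None => default.map(str::to_owned),
    }
}

fn main() {
    let prompts = HashMap::from([("query".to_string(), "query: ".to_string())]);

    // An explicit prompt name overrides the default prompt:
    assert_eq!(
        resolve(Some("passage: "), Some("query"), &prompts).as_deref(),
        Some("query: ")
    );
    // Without a name, the default prompt applies:
    assert_eq!(
        resolve(Some("passage: "), None, &prompts).as_deref(),
        Some("passage: ")
    );
    // The resolved prompt is prepended to the raw input before tokenization:
    let mut text = resolve(None, Some("query"), &prompts).unwrap();
    text.push_str("What is Deep Learning?");
    assert_eq!(text, "query: What is Deep Learning?");
}
```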
```diff
@@ -256,7 +356,15 @@ fn encode_input(
         stride: 0,
     });
 
-    let encoding = tokenize_input(inputs, true, truncate_params, tokenizer)?;
+    let (_, encoding) = tokenize_input(
+        inputs,
+        true,
+        truncate_params,
+        default_prompt,
+        prompt_name,
+        prompts,
+        tokenizer,
+    )?;
     let seq_len = encoding.len();
 
     if seq_len > max_input_length {
@@ -315,13 +423,15 @@ enum TokenizerRequest {
         EncodingInput,
         bool,
         TruncationDirection,
+        Option<String>,
         oneshot::Sender<Result<ValidEncoding, TextEmbeddingsError>>,
         Span,
     ),
     Tokenize(
         EncodingInput,
         bool,
-        oneshot::Sender<Result<RawEncoding, TextEmbeddingsError>>,
+        Option<String>,
+        oneshot::Sender<Result<(Option<String>, RawEncoding), TextEmbeddingsError>>,
         Span,
     ),
     Decode(
```
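
With these changes, `Tokenize` responses carry `(Option<String>, RawEncoding)` instead of a bare encoding: single and pre-tokenized inputs return the exact text that was tokenized, prompt prefix included, while dual inputs (which reject prompts) return `None`, so callers can surface what the model actually embedded.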
