 import torch.nn.functional as F

 from executorch.examples.models.llama2.llama_transformer import ModelArgs

-from executorch.examples.models.llama2.tokenizer.tiktoken import (
-    Dialog,
-    Message,
-    Tokenizer,
-)
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer

 from executorch.extension.pybindings.portable_lib import _load_for_executorch

@@ -28,12 +24,6 @@ class CompletionPrediction(TypedDict, total=False):
     logprobs: List[float]  # not required


-class ChatPrediction(TypedDict, total=False):
-    generation: Message
-    tokens: List[str]  # not required
-    logprobs: List[float]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -225,72 +215,6 @@ def text_completion(
             ]
         return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

-    def chat_completion(
-        self,
-        dialogs: List[Dialog],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-    ) -> List[ChatPrediction]:
-        """
-        Generate assistant responses for a list of conversational dialogs using the language generation model.
-
-        Args:
-            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
-            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
-            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
-                If not provided, it's set to the model's maximum sequence length minus 1.
-            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
-
-        Returns:
-            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
-
-        Raises:
-            AssertionError: If the last message in a dialog is not from the user.
-            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.
-
-        Note:
-            This method generates assistant responses for the provided conversational dialogs.
-            It employs nucleus sampling to introduce controlled randomness in text generation.
-            If logprobs is True, token log probabilities are computed for each generated token.
-        """
-        if max_gen_len is None:
-            max_gen_len = self.model.params.max_seq_len - 1
-
-        prompt_tokens = [
-            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
-        ]
-        generation_tokens, generation_logprobs = self.generate(
-            prompt_tokens=prompt_tokens,
-            max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
-        )
-        if logprobs:
-            return [
-                {
-                    "generation": {
-                        "role": "assistant",
-                        "content": self.tokenizer.decode(t),
-                    },
-                    "tokens": [self.tokenizer.decode([x]) for x in t],
-                    "logprobs": logprobs_i,
-                }
-                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
-            ]
-        return [
-            {
-                "generation": {
-                    "role": "assistant",
-                    "content": self.tokenizer.decode(t),
-                },
-            }
-            for t in generation_tokens
-        ]
-

 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
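
For reference, the sample_top_p helper whose docstring appears in the context lines above performs nucleus sampling. Below is a minimal sketch of how such a routine is commonly written in PyTorch; it is an illustration consistent with that docstring, not necessarily the exact body of sample_top_p in this file, and the function name is chosen here only for clarity.

import torch

def sample_top_p_sketch(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Sort token probabilities in descending order and compute cumulative mass.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_cum = torch.cumsum(probs_sort, dim=-1)
    # Zero out tokens outside the nucleus, i.e. tokens whose cumulative mass
    # (excluding themselves) already exceeds p.
    mask = probs_cum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize the surviving probabilities and sample one token.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled index back to the original vocabulary ordering.
    return torch.gather(probs_idx, -1, next_token)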