
Commit 5e9bab8

Delete dead code
Differential Revision: D61166041
Pull Request resolved: #4678
Co-authored-by: helunwencser <[email protected]>

1 parent: b6de6ed

File tree: 1 file changed (+1 / -77 lines)


examples/models/llama2/runner/generation.py

Lines changed: 1 addition & 77 deletions
@@ -14,11 +14,7 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama2.llama_transformer import ModelArgs

-from executorch.examples.models.llama2.tokenizer.tiktoken import (
-    Dialog,
-    Message,
-    Tokenizer,
-)
+from executorch.examples.models.llama2.tokenizer.tiktoken import Tokenizer
 from executorch.extension.pybindings.portable_lib import _load_for_executorch


@@ -28,12 +24,6 @@ class CompletionPrediction(TypedDict, total=False):
     logprobs: List[float]  # not required


-class ChatPrediction(TypedDict, total=False):
-    generation: Message
-    tokens: List[str]  # not required
-    logprobs: List[float]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -225,72 +215,6 @@ def text_completion(
             ]
         return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]

-    def chat_completion(
-        self,
-        dialogs: List[Dialog],
-        temperature: float = 0.6,
-        top_p: float = 0.9,
-        max_gen_len: Optional[int] = None,
-        logprobs: bool = False,
-    ) -> List[ChatPrediction]:
-        """
-        Generate assistant responses for a list of conversational dialogs using the language generation model.
-
-        Args:
-            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
-            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
-            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
-            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
-                If not provided, it's set to the model's maximum sequence length minus 1.
-            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
-
-        Returns:
-            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
-
-        Raises:
-            AssertionError: If the last message in a dialog is not from the user.
-            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.
-
-        Note:
-            This method generates assistant responses for the provided conversational dialogs.
-            It employs nucleus sampling to introduce controlled randomness in text generation.
-            If logprobs is True, token log probabilities are computed for each generated token.
-        """
-        if max_gen_len is None:
-            max_gen_len = self.model.params.max_seq_len - 1
-
-        prompt_tokens = [
-            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
-        ]
-        generation_tokens, generation_logprobs = self.generate(
-            prompt_tokens=prompt_tokens,
-            max_gen_len=max_gen_len,
-            temperature=temperature,
-            top_p=top_p,
-            logprobs=logprobs,
-        )
-        if logprobs:
-            return [
-                {
-                    "generation": {
-                        "role": "assistant",
-                        "content": self.tokenizer.decode(t),
-                    },
-                    "tokens": [self.tokenizer.decode([x]) for x in t],
-                    "logprobs": logprobs_i,
-                }
-                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
-            ]
-        return [
-            {
-                "generation": {
-                    "role": "assistant",
-                    "content": self.tokenizer.decode(t),
-                },
-            }
-            for t in generation_tokens
-        ]
-

 def build_args_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
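
Note: the hunk context above shows that the file keeps the sample_top_p helper, whose docstring describes top-p (nucleus) sampling. For readers unfamiliar with the technique, the following is a minimal PyTorch sketch of how such a sampler is commonly written; it illustrates the idea and is not necessarily the exact body of sample_top_p in generation.py. The function name with the _sketch suffix and the commented usage lines are illustrative assumptions, not code from this repository.

import torch

def sample_top_p_sketch(probs: torch.Tensor, p: float) -> torch.Tensor:
    # Illustrative sketch only; the real helper in generation.py may differ.
    # Sort token probabilities in descending order and take the cumulative sum.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out every token whose preceding cumulative mass already exceeds p;
    # subtracting the current token's probability keeps at least one candidate.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize the surviving probabilities and sample one token from them.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    return torch.gather(probs_idx, -1, next_token)

# Hypothetical usage, assuming `logits` of shape (1, seq_len, vocab_size):
# probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
# next_token = sample_top_p_sketch(probs, p=0.9)

The key design point is that the cutoff is applied to the cumulative mass excluding the current token, so the nucleus always contains at least the single most likely token even when p is very small.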
