Commit 623a9a6

add the ability to have multi-round conversation with llama (#6769)
* update llama runner to decode single token

  Pull Request resolved: #6703

  Right now, we don't print the generated response in the eager runner until all tokens are generated. This is not a good experience, since we have to wait for the whole generation to finish before seeing any of the response. This PR updates the runner to decode and print each new token immediately after it is generated.

  ghstack-source-id: 252924039
  Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/)

* add the ability to have multi-round conversation with llama

  Add the ability to have multi-round conversations with the LLM. This will be helpful for testing long context lengths.

  Differential Revision: [D65771122](https://our.internmc.facebook.com/intern/diff/D65771122/)
  ghstack-source-id: 252934165
  Pull Request resolved: #6758

---------

Co-authored-by: Lunwen He <[email protected]>
1 parent 671f9c5 commit 623a9a6

2 files changed: +66 −6 lines
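The first part of the commit message describes streaming single-token decode (from #6703): print each token as soon as it is sampled instead of buffering the full response. Below is a minimal sketch of that pattern, not the ExecuTorch runner code; `model.sample_next`, `tokenizer.decode`, and `tokenizer.eos_id` are hypothetical stand-ins for the runner's forward pass and tokenizer.

```python
# Minimal sketch of single-token streaming decode (illustration only; the
# model/tokenizer interfaces below are assumed, not the ExecuTorch runner API).
from typing import List


def stream_generate(model, tokenizer, prompt_tokens: List[int], max_new_tokens: int = 128) -> List[int]:
    tokens = list(prompt_tokens)
    for _ in range(max_new_tokens):
        next_tok = model.sample_next(tokens)  # one decode step -> one new token id
        tokens.append(next_tok)
        # Print this token's text immediately instead of waiting for the full response.
        print(tokenizer.decode([next_tok]), end="", flush=True)
        if next_tok == tokenizer.eos_id:
            break
    print()
    return tokens
```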

examples/models/llama/runner/eager.py
Lines changed: 15 additions & 4 deletions

@@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--prompt",
         type=str,
-        default="Hello",
+        default=None,
     )
 
     parser.add_argument(
@@ -70,6 +70,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Show the tokens that were generated",
     )
 
+    parser.add_argument(
+        "--chat",
+        action="store_true",
+        default=False,
+        help="Have multi-turn chat with the model",
+    )
+
     return parser
 
 
@@ -78,9 +85,13 @@ def main() -> None:
     args = parser.parse_args()
 
     runner = EagerLlamaRunner(args)
-    generated_tokens = runner.text_completion(
-        prompt=args.prompt,
-        temperature=args.temperature,
+    generated_tokens = (
+        runner.chat_completion(temperature=args.temperature)
+        if args.chat
+        else runner.text_completion(
+            prompt=args.prompt,
+            temperature=args.temperature,
+        )
     )
     if args.show_tokens:
         print(f"Tokens: {generated_tokens}")

examples/models/llama/runner/generation.py
Lines changed: 51 additions & 2 deletions

@@ -67,12 +67,13 @@ def generate( # noqa: C901
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
+        pos_base: int = 0,
     ) -> List[int]:
         # prefill
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long, device=self.device)
+                torch.tensor([pos_base], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -89,7 +90,9 @@ def generate( # noqa: C901
                         [[current_token]], dtype=torch.long, device=self.device
                     ),
                     input_pos=torch.tensor(
-                        [len(tokens) - 1], dtype=torch.long, device=self.device
+                        [pos_base + len(tokens) - 1],
+                        dtype=torch.long,
+                        device=self.device,
                     ),
                 )
             else:
@@ -136,3 +139,49 @@ def text_completion(
             top_p=top_p,
             echo=echo,
         )
+
+    def chat_completion(
+        self,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+    ) -> List[int]:
+        """
+        Perform multi-turn chat with the language model.
+
+        Args:
+            prompt (str): Text prompt for completion.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+
+        Returns:
+            Generated list of tokens.
+
+        Note:
+            This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
+        """
+        exit_prompt = "exit"
+        tokens = []
+        prompt = input("Me: ")
+        while prompt and prompt != exit_prompt:
+            print("LLM: ", end="", flush=True)
+            new_tokens = self.generate(
+                prompt_tokens=self.tokenizer.encode(
+                    self._format_prompt(prompt), bos=True, eos=False
+                ),
+                temperature=temperature,
+                top_p=top_p,
+                echo=True,
+                pos_base=len(tokens),
+            )
+            tokens.extend(new_tokens)
+            prompt = input("Me: ")
+        return tokens
+
+    def _format_prompt(self, prompt: str) -> str:
+        return f"""
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
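The key to multi-turn chat with a KV cache is the new `pos_base` argument: each turn's prefill starts at the number of tokens already accumulated in the conversation, so cache positions stay contiguous across turns instead of restarting at 0. The snippet below only illustrates that bookkeeping; `simulate_positions` is a made-up helper and does not touch the model or cache.

```python
# Sketch of the position bookkeeping enabled by pos_base (illustration only,
# not the ExecuTorch generate() implementation). chat_completion() passes
# pos_base=len(tokens), so each turn writes to the KV cache right after the
# positions used by earlier turns.


def simulate_positions(turn_lengths):
    """Return the (start, end) cache positions used for each chat turn."""
    positions = []
    tokens_so_far = 0  # plays the role of len(tokens) in chat_completion
    for n_new in turn_lengths:
        pos_base = tokens_so_far  # prefill for this turn starts here, not at 0
        positions.append((pos_base, pos_base + n_new - 1))
        tokens_so_far += n_new  # mirrors tokens.extend(new_tokens)
    return positions


# Example: three turns contributing 12, 30, and 9 tokens (prompt + response each)
# occupy contiguous, non-overlapping slots in the KV cache.
print(simulate_positions([12, 30, 9]))  # [(0, 11), (12, 41), (42, 50)]
```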
