
Commit 8e2d359

add the ability to have multi-round conversations with llama
Add the ability to have multi-round conversations with the LLM. This will be helpful for testing long context lengths.

Differential Revision: [D65771122](https://our.internmc.facebook.com/intern/diff/D65771122/)
ghstack-source-id: 252934165
Pull Request resolved: #6758
1 parent 57973d8 commit 8e2d359


2 files changed: +66, -6 lines changed


examples/models/llama/runner/eager.py

Lines changed: 15 additions & 4 deletions
@@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--prompt",
         type=str,
-        default="Hello",
+        default=None,
     )

     parser.add_argument(
@@ -70,6 +70,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Show the tokens that were generated",
     )

+    parser.add_argument(
+        "--chat",
+        action="store_true",
+        default=False,
+        help="Have multi-turn chat with the model",
+    )
+
     return parser


@@ -78,9 +85,13 @@ def main() -> None:
     args = parser.parse_args()

     runner = EagerLlamaRunner(args)
-    generated_tokens = runner.text_completion(
-        prompt=args.prompt,
-        temperature=args.temperature,
+    generated_tokens = (
+        runner.chat_completion(temperature=args.temperature)
+        if args.chat
+        else runner.text_completion(
+            prompt=args.prompt,
+            temperature=args.temperature,
+        )
     )
     if args.show_tokens:
         print(f"Tokens: {generated_tokens}")
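In main(), the new --chat flag selects the interactive path, and --prompt now defaults to None so it is not required for chat. A minimal sketch of the same dispatch pattern, with a hypothetical StubRunner standing in for EagerLlamaRunner (not part of this commit):

# Sketch only: StubRunner is a hypothetical stand-in for EagerLlamaRunner.
import argparse
from typing import List


class StubRunner:
    def text_completion(self, prompt: str, temperature: float) -> List[int]:
        print(f"single-shot completion of {prompt!r}")
        return [101, 102, 103]

    def chat_completion(self, temperature: float) -> List[int]:
        print("interactive multi-turn chat")
        return [101, 102, 103, 104]


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.6)
    parser.add_argument("--chat", action="store_true", default=False)
    args = parser.parse_args()

    runner = StubRunner()
    # Same conditional expression as the diff above: --chat picks the chat loop.
    generated_tokens = (
        runner.chat_completion(temperature=args.temperature)
        if args.chat
        else runner.text_completion(prompt=args.prompt, temperature=args.temperature)
    )
    print(f"Tokens: {generated_tokens}")


if __name__ == "__main__":
    main()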

examples/models/llama/runner/generation.py

Lines changed: 51 additions & 2 deletions
@@ -67,12 +67,13 @@ def generate(  # noqa: C901
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
+        pos_base: int = 0,
     ) -> List[int]:
         # prefill
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long, device=self.device)
+                torch.tensor([pos_base], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -89,7 +90,9 @@ def generate(  # noqa: C901
                         [[current_token]], dtype=torch.long, device=self.device
                     ),
                     input_pos=torch.tensor(
-                        [len(tokens) - 1], dtype=torch.long, device=self.device
+                        [pos_base + len(tokens) - 1],
+                        dtype=torch.long,
+                        device=self.device,
                     ),
                 )
             else:
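Both hunks above thread the new pos_base argument through prefill and the decode loop, so KV-cache positions keep advancing across chat turns instead of restarting at 0 and overwriting earlier entries. A toy sketch of that position bookkeeping, in plain Python with made-up token counts (not ExecuTorch code):

# Toy illustration of pos_base: each turn's positions start where the
# previous turns ended. Token counts below are invented for the example.
from typing import List


def turn_positions(num_new_tokens: int, pos_base: int) -> List[int]:
    # Prefill starts at pos_base; each decode step then uses
    # pos_base + len(tokens) - 1, i.e. the next free cache slot.
    return list(range(pos_base, pos_base + num_new_tokens))


history: List[int] = []            # tokens accumulated over earlier turns
for turn_len in (9, 7, 12):        # three hypothetical chat turns
    positions = turn_positions(turn_len, pos_base=len(history))
    print(f"turn of {turn_len} tokens -> positions {positions[0]}..{positions[-1]}")
    history.extend([0] * turn_len)  # pretend these are the new tokens

# Without pos_base, every turn would start again at position 0 and clobber
# the entries the KV cache already holds.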
@@ -136,3 +139,49 @@ def text_completion(
             top_p=top_p,
             echo=echo,
         )
+
+    def chat_completion(
+        self,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+    ) -> List[int]:
+        """
+        Perform multi-turn chat with the language model.
+
+        Args:
+            prompt (str): Text prompt for completion.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+
+        Returns:
+            Generated list of tokens.
+
+        Note:
+            This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
+        """
+        exit_prompt = "exit"
+        tokens = []
+        prompt = input("Me: ")
+        while prompt and prompt != exit_prompt:
+            print("LLM: ", end="", flush=True)
+            new_tokens = self.generate(
+                prompt_tokens=self.tokenizer.encode(
+                    self._format_prompt(prompt), bos=True, eos=False
+                ),
+                temperature=temperature,
+                top_p=top_p,
+                echo=True,
+                pos_base=len(tokens),
+            )
+            tokens.extend(new_tokens)
+            prompt = input("Me: ")
+        return tokens
+
+    def _format_prompt(self, prompt: str) -> str:
+        return f"""
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
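chat_completion reads prompts from stdin until the user types "exit", wraps each one in Llama-3-style header tokens via _format_prompt, and passes pos_base=len(tokens) so every turn continues from where the previous one stopped. A standalone sketch that just renders the same template for a sample prompt, to show what actually gets tokenized (hypothetical helper, not part of the commit):

# Standalone sketch: render the Llama-3-style chat template that
# _format_prompt builds, for one hypothetical user prompt.
def format_prompt(prompt: str) -> str:
    return f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""


print(format_prompt("What is the capital of France?"))
# The assistant header is left open, so generation continues as the
# assistant's reply until the model emits <|eot_id|>.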
