Commit 301640f

Varun Puri committed:
Separate browser and API, fix context for chat models.
1 parent 7e6a1e7 commit 301640f

6 files changed (+192, -50 lines)

README.md

Lines changed: 29 additions & 1 deletion

@@ -54,7 +54,7 @@ source .venv/bin/activate
 ```
 [skip default]: end
 
-[shell default]: ./install_requirements.sh
+[shell default]: ./install_requirements.sh
 
 Installations can be tested by
 
@@ -118,6 +118,34 @@ python3 torchchat.py generate llama3 --prompt "write me a story about a boy and
 
 For more information run `python3 torchchat.py generate --help`
 
+The `Generator` class can also be imported into a Python program to generate responses.
+
+```
+from generate import Generator, GeneratorArgs
+from build.builder import (
+    BuilderArgs,
+    TokenizerArgs,
+)
+
+...
+
+# Load the model and tokenizer.
+gen = Generator(
+    builder_args,
+    speculative_builder_args,
+    tokenizer_args,
+    generator_args,
+    args.profile,
+    args.quantize,
+    args.draft_quantize,
+)
+
+# generate() is a Python generator that yields a torch.Tensor for each token as it is produced.
+for tok in gen.generate(generator_args):
+    print(gen.tokenizer.decode(tok.tolist()))
+
+```
+
 
 ### Browser
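
The README snippet above elides how the `*Args` objects are constructed. One way to build them, mirroring the argparse plumbing that `api/browser.py` uses elsewhere in this commit (a minimal sketch, not part of the diff; the model and flags come from whatever you pass on the command line):

```
import argparse

from build.builder import BuilderArgs, TokenizerArgs
from cli import add_arguments_for_verb, arg_init, check_args
from generate import Generator, GeneratorArgs

# Reuse the "generate" verb's CLI arguments, as api/browser.py's __main__ does.
parser = argparse.ArgumentParser(description="torchchat generate example")
add_arguments_for_verb(parser, "generate")
args = parser.parse_args()
check_args(args, "generate")
args = arg_init(args)

builder_args = BuilderArgs.from_args(args)
speculative_builder_args = BuilderArgs.from_speculative_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)

gen = Generator(
    builder_args,
    speculative_builder_args,
    tokenizer_args,
    generator_args,
    args.profile,
    args.quantize,
    args.draft_quantize,
)

# generate() yields one torch.Tensor per token; decode and print as they arrive.
for tok in gen.generate(generator_args):
    print(gen.tokenizer.decode(tok.tolist()))
```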

api/api.py

Lines changed: 55 additions & 12 deletions

@@ -146,17 +146,52 @@ class CompletionResponseChunk():
 class OpenAIAPIGenerator(Generator):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.start_pos = 0
+        self.max_seq_length = (
+            self.model.config.max_seq_length + self.speculative_builder_args.speculate_k + 1
+            if self.draft_model is not None
+            else self.model.config.max_seq_length
+        )
 
 
     def completion(self, completion_request: CompletionRequest):
+        device_sync(device=self.builder_args.device)
+
+        id = str(uuid.uuid4())
+        idx = 0
+        buffer = []
+        encoded = self.encode_tokens(
+            completion_request.prompt, bos=True, device=self.builder_args.device
+        )
         generator_args = GeneratorArgs(
-            prompt = completion_request.prompt,
-            temperature = completion_request.temperature,
+            completion_request.prompt,
+            encoded_prompt=encoded,
+            chat_mode=False,
         )
-        id = 12345678
-        idx = 0
-        for x, metrics in self.chat(generator_args):
-            content = "".join(self.tokenizer.decode([self.tokenizer.encode(".")[0]] + x.tolist())[1:])
+
+        def callback(x, *, done_generating=False):
+            return self._callback(
+                x,
+                buffer=buffer,
+                done_generating=done_generating,
+            )
+
+        for y, metrics in self.generate(
+            self.model,
+            encoded,
+            generator_args.max_new_tokens,
+            draft_model=self.draft_model,
+            speculate_k=generator_args.speculate_k,
+            chat_mode=generator_args.chat_mode,
+            callback=callback,
+            temperature=generator_args.temperature,
+            top_k=generator_args.top_k,
+            sequential_prefill=generator_args.sequential_prefill,
+            start_pos=self.start_pos,
+            max_seq_length=self.max_seq_length,
+
+        ):
+            content = "".join(self.tokenizer.decode([self.tokenizer.encode(".")[0]] + y.tolist())[1:])
             chunk_delta = ChunkDelta(
                 role = "assistant",
                 content = content,
@@ -174,6 +209,7 @@ def completion(self, completion_request: CompletionRequest):
                 system_fingerprint = uuid.UUID(int=uuid.getnode()),
             )
             yield chunk_response
+            self.start_pos += y.size(0)
             idx += 1
 
         end_chunk: CompletionChoiceChunk(
@@ -190,10 +226,14 @@ def completion(self, completion_request: CompletionRequest):
             system_fingerprint = uuid.UUID(int=uuid.getnode()),
         )
 
-    def _callback(self, x, *, buffer, period_id, done_generating, tokenizer, is_llama3_model):
-        if x.item() == tokenizer.eos_id():
+    def _callback(self, x, *, buffer, done_generating):
+        period_id = self.tokenizer.encode(".")[0]
+        buffer.append(
+            self.tokenizer.decode([period_id] + x.tolist())[1:]
+        )
+        if x.item() == self.tokenizer.eos_id():
             done_generating = True
-        if is_llama3_model and x.item() == tokenizer.special_tokens["<|eot_id|>"]:
+        if self.is_llama3_model and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]:
             done_generating = True
 
 def main(args):
@@ -248,9 +288,12 @@ def initialize_generator() -> OpenAIAPIGenerator:
             )
 
             def unwrap(completion_generator):
-                for obj in completion_generator:
-                    yield obj.choices[0].delta.content
-
+                for chunk_response in completion_generator:
+                    content = chunk_response.choices[0].delta.content
+                    if not gen.is_llama3_model or content not in set(gen.tokenizer.special_tokens.keys()):
+                        yield content
+                yield "."
+
             response = st.write_stream(unwrap(gen.completion(req)))
 
             # Add assistant response to chat history
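
Taken together, `completion()` now streams one OpenAI-style chunk per decoded token and advances `self.start_pos` after each yield, which is the "fix context for chat models" part of this commit: a later request continues from the cached position instead of restarting at 0. A minimal sketch of driving it directly (assumes `gen` is an `OpenAIAPIGenerator` built as in `initialize_generator()`, that `api/api.py` is importable as `api.api`, and a made-up prompt; not part of the diff):

```
from api.api import CompletionRequest  # assumption: api/api.py importable as api.api

# `gen` is assumed to be an OpenAIAPIGenerator, built as in initialize_generator().
req = CompletionRequest(
    model=gen.builder_args.checkpoint_path,
    prompt="Write me a haiku about speculative decoding",  # hypothetical prompt
    temperature=0.8,
    messages=[],
)

pieces = []
for chunk in gen.completion(req):
    # Each chunk carries one decoded token in choices[0].delta.content;
    # gen.start_pos advances internally so the next request keeps this context.
    pieces.append(chunk.choices[0].delta.content)

print("".join(pieces))
```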

api/browser.py

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
+
+import streamlit as st
+
+from build.builder import (
+    _initialize_model,
+    _initialize_tokenizer,
+    BuilderArgs,
+    TokenizerArgs,
+)
+from build.model import Transformer
+from build.utils import device_sync, set_precision
+from cli import add_arguments_for_verb, arg_init, check_args, logger
+
+from generate import GeneratorArgs
+from .api import *
+
+
+
+def main(args):
+    builder_args = BuilderArgs.from_args(args)
+    speculative_builder_args = BuilderArgs.from_speculative_args(args)
+    tokenizer_args = TokenizerArgs.from_args(args)
+    generator_args = GeneratorArgs.from_args(args)
+    generator_args.chat_mode = False
+
+    @st.cache_resource
+    def initialize_generator() -> OpenAIAPIGenerator:
+        return OpenAIAPIGenerator(
+            builder_args,
+            speculative_builder_args,
+            tokenizer_args,
+            generator_args,
+            args.profile,
+            args.quantize,
+            args.draft_quantize,
+        )
+
+    gen = initialize_generator()
+
+    tokens_generated = 0
+    st.title("Simple chat")
+
+    # Initialize chat history
+    if "messages" not in st.session_state:
+        st.session_state.messages = []
+
+    # Display chat messages from history on app rerun
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+    # Accept user input
+    if prompt := st.chat_input("What is up?"):
+        # Add user message to chat history
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        # Display user message in chat message container
+        with st.chat_message("user"):
+            st.markdown(prompt)
+
+        # Display assistant response in chat message container
+        with st.chat_message("assistant"):
+
+            req = CompletionRequest(
+                model = gen.builder_args.checkpoint_path,
+                prompt = prompt,
+                temperature = generator_args.temperature,
+                messages = [],
+            )
+
+            def unwrap(completion_generator):
+                for chunk_response in completion_generator:
+                    content = chunk_response.choices[0].delta.content
+                    if not gen.is_llama3_model or content not in set(gen.tokenizer.special_tokens.keys()):
+                        yield content
+                yield "."
+
+            response = st.write_stream(unwrap(gen.completion(req)))
+
+            # Add assistant response to chat history
+            st.session_state.messages.append({"role": "assistant", "content": response})
+
+
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="torchchat generate CLI")
+    verb = "generate"
+    add_arguments_for_verb(parser, verb)
+    args = parser.parse_args()
+    check_args(args, verb)
+    args = arg_init(args)
+    main(args)
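
The `unwrap` helper strips llama3 special tokens (such as `<|eot_id|>`) from the streamed text before `st.write_stream` renders it, then emits a closing `"."`. A toy illustration of that filtering with made-up chunks and special tokens (not torchchat code):

```
# Toy stand-ins; the real code consults gen.tokenizer.special_tokens.
special_tokens = {"<|begin_of_text|>", "<|eot_id|>"}
chunks = ["Once", " upon", " a", " time", "<|eot_id|>"]

def filter_stream(chunks):
    for content in chunks:
        if content not in special_tokens:
            yield content
    yield "."

print("".join(filter_stream(chunks)))  # -> "Once upon a time."
```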

cli.py

Lines changed: 1 addition & 7 deletions

@@ -24,7 +24,7 @@
 ).expanduser()
 
 
-KNOWN_VERBS = ["chat", "browser", "download", "generate", "eval", "export", "list", "server", "remove", "where"]
+KNOWN_VERBS = ["chat", "browser", "download", "generate", "eval", "export", "list","remove", "where"]
 
 # Handle CLI arguments that are common to a majority of subcommands.
 def check_args(args, verb: str) -> None:
@@ -256,12 +256,6 @@ def add_arguments_for_verb(parser, verb: str):
         default=default_model_dir,
         help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
     )
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=5000,
-        help="Port for the web server in browser mode",
-    )
 
 
 def arg_init(args):

generate.py

Lines changed: 14 additions & 24 deletions

@@ -228,8 +228,6 @@ def __init__(
         generator_args.validate_build(self.builder_args)
         generator_args.validate_build(self.speculative_builder_args, "draft model")
 
-        for _ in self.chat(generator_args):
-            continue
 
     def multinomial_sample_one_no_sync(
         self,
@@ -337,10 +335,11 @@ def decode_n_tokens(
             input_pos += 1
             new_tokens.append(next_token.clone())
             callback(new_tokens[-1], done_generating=_i == num_new_tokens - 2)
-            if need_probs:
+            if need_probs or not next_prob:
+                yield cur_token.clone(), None
+            else:
                 new_probs.append(next_prob.clone())
-
-            yield cur_token.clone(), next_prob.clone()
+                yield cur_token.clone(), next_prob.clone()
             cur_token = next_token
 
         # encountered eos
@@ -365,7 +364,7 @@ def decode_n_tokens(
                 model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs
            )
            input_pos += 1
-            yield eos_token.clone(), next_prob.clone()
+            yield eos_token.clone(), (next_prob.clone() if next_prob else None)
 
        # return new_tokens, new_probs
 
@@ -450,9 +449,7 @@ def generate(
         speculate_k: Optional[int] = 8,
         sequential_prefill=True,
         callback=lambda x: x,
-        tokenizer=None,
         max_seq_length: int,
-        is_llama3_model: bool = False,
         **sampling_kwargs,
     ) -> torch.Tensor:
         """
@@ -536,8 +533,8 @@ def generate(
             max_new_tokens - 1,
             callback=callback,
             need_probs=False,
-            eos_token_id=tokenizer.eos_id() if tokenizer else 2,
-            eot_id=tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
+            eos_token_id=self.tokenizer.eos_id() if self.tokenizer else 2,
+            eot_id=self.tokenizer.special_tokens["<|eot_id|>"] if self.is_llama3_model else None,
             **sampling_kwargs,
         ):
             generated_tokens.append(generated_token)
@@ -555,19 +552,20 @@
         return seq, generate_stats
 
 
-    def encode_tokens(self, tokenizer, string, bos=True, device="cpu"):
+    def encode_tokens(self, string, bos=True, device="cpu"):
         tokens = self.tokenizer.encode(string)
         if bos:
-            tokens = [tokenizer.bos_id()] + tokens
+            tokens = [self.tokenizer.bos_id()] + tokens
         return torch.tensor(tokens, dtype=torch.int, device=device)
 
-    def _callback(self, x, *, buffer, period_id, done_generating, tokenizer, is_llama3_model):
+    def _callback(self, x, *, buffer, done_generating):
+        period_id = self.tokenizer.encode(".")[0]
         buffer.append(
             self.tokenizer.decode([period_id] + x.tolist())[1:]
         ) # I think this results in the first output token being dropped from the display which is wrong.
         if x.item() == self.tokenizer.eos_id():
             done_generating = True
-        if is_llama3_model and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]:
+        if self.is_llama3_model and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]:
             done_generating = True
             buffer = buffer[:-1] # drop the eot_id from the output buffer
         if len(buffer) == 4 or done_generating:
@@ -581,7 +579,7 @@ def chat(
     ):
         print("Starting Interactive Chat")
         encoded = self.encode_tokens(
-            self.tokenizer, generator_args.prompt, bos=True, device=self.builder_args.device
+            generator_args.prompt, bos=True, device=self.builder_args.device
         )
         logging.debug(encoded)
         prompt_length = encoded.size(0)
@@ -664,7 +662,7 @@ def chat(
                 else:
                     prompt = f"{B_INST} {prompt.strip()} {E_INST}"
                 encoded = self.encode_tokens(
-                    self.tokenizer, prompt, bos=True, device=self.builder_args.device
+                    prompt, bos=True, device=self.builder_args.device
                 )
             else:
                 if self.system_prompt is not None:
@@ -707,10 +705,7 @@ def callback(x, *, done_generating=False):
                     return self._callback(
                         x,
                         buffer=buffer,
-                        period_id=period_id,
                         done_generating=done_generating,
-                        tokenizer=self.tokenizer,
-                        is_llama3_model=self.is_llama3_model,
                     )
 
             else:
@@ -722,10 +717,7 @@ def callback(x, *, done_generating=False):
                    return self._callback(
                        x,
                        buffer=buffer,
-                        period_id=period_id,
                        done_generating=done_generating,
-                        tokenizer=self.tokenizer,
-                        is_llama3_model=self.is_llama3_model,
                    )
 
            if (i != generator_args.num_samples - 1 or not self.profile) or (
@@ -751,9 +743,7 @@ def callback(x, *, done_generating=False):
                 top_k=generator_args.top_k,
                 sequential_prefill=generator_args.sequential_prefill,
                 start_pos=start_pos,
-                tokenizer=self.tokenizer,
                 max_seq_length=max_seq_length,
-                is_llama3_model=self.is_llama3_model,
             ):
                 if metrics:
                     aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
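
With this change the token-decoding generators can yield `None` in place of a probability tensor (see `yield cur_token.clone(), None` and the eos path's `next_prob.clone() if next_prob else None`), so callers must guard before using the second element. A minimal sketch of that consumption pattern, with a dummy generator standing in for the real method (not torchchat code):

```
import torch

def fake_decode_n_tokens():
    # Stand-in for decode_n_tokens: yields (token, prob-or-None) pairs.
    yield torch.tensor([42]), None                     # no probability produced
    yield torch.tensor([7]), torch.tensor([0.1, 0.9])  # probability available

for token, prob in fake_decode_n_tokens():
    if prob is None:
        print(f"token {token.item()}, no probs")
    else:
        print(f"token {token.item()}, probs {prob.tolist()}")
```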
