
Commit ce31856

Author: Varun Puri
Commit message: Browser tweaks, fix chat
Parent: 3e77e6f

3 files changed: +10, -79 lines

api/api.py

Lines changed: 0 additions & 75 deletions
@@ -235,78 +235,3 @@ def _callback(self, x, *, buffer, done_generating):
             done_generating = True
         if self.is_llama3_model and x.item() == self.tokenizer.special_tokens["<|eot_id|>"]:
             done_generating = True
-
-def main(args):
-    builder_args = BuilderArgs.from_args(args)
-    speculative_builder_args = BuilderArgs.from_speculative_args(args)
-    tokenizer_args = TokenizerArgs.from_args(args)
-    generator_args = GeneratorArgs.from_args(args)
-    generator_args.chat_mode = False
-
-    @st.cache_resource
-    def initialize_generator() -> OpenAIAPIGenerator:
-        return OpenAIAPIGenerator(
-            builder_args,
-            speculative_builder_args,
-            tokenizer_args,
-            generator_args,
-            args.profile,
-            args.quantize,
-            args.draft_quantize,
-        )
-
-    gen = initialize_generator()
-
-    tokens_generated = 0
-    st.title("Simple chat")
-
-    # Initialize chat history
-    if "messages" not in st.session_state:
-        st.session_state.messages = []
-
-    # Display chat messages from history on app rerun
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            st.markdown(message["content"])
-
-    # Accept user input
-    if prompt := st.chat_input("What is up?"):
-        # Add user message to chat history
-        st.session_state.messages.append({"role": "user", "content": prompt})
-        # Display user message in chat message container
-        with st.chat_message("user"):
-            st.markdown(prompt)
-
-        # Display assistant response in chat message container
-        with st.chat_message("assistant"):
-
-            req = CompletionRequest(
-                model = gen.builder_args.checkpoint_path,
-                prompt = prompt,
-                temperature = generator_args.temperature,
-                messages = [],
-            )
-
-            def unwrap(completion_generator):
-                for chunk_response in completion_generator:
-                    content = chunk_response.choices[0].delta.content
-                    if not gen.is_llama3_model or content not in set(gen.tokenizer.special_tokens.keys()):
-                        yield content
-                    yield "."
-
-            response = st.write_stream(unwrap(gen.completion(req)))
-
-        # Add assistant response to chat history
-        st.session_state.messages.append({"role": "assistant", "content": response})
-
-
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="torchchat generate CLI")
-    verb = "generate"
-    add_arguments_for_verb(parser, verb)
-    args = parser.parse_args()
-    check_args(args, verb)
-    args = arg_init(args)
-    main(args)
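The deleted block was a copy of the Streamlit chat UI that lives in api/browser.py (next file), so api/api.py goes back to being a pure API module. For orientation, a minimal, self-contained sketch of the same Streamlit chat pattern — session-state history, chat containers, streamed response — assuming a recent Streamlit (st.chat_input, st.write_stream). fake_stream is a hypothetical stand-in for gen.completion(req); none of this is torchchat code.

import time

import streamlit as st


def fake_stream(prompt: str):
    # Hypothetical stand-in for a token-by-token completion stream.
    for word in ("You said: " + prompt).split():
        yield word + " "
        time.sleep(0.05)


st.title("Chat sketch")

# Chat history must live in session_state to survive Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the history on every rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("What is up?"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)
    with st.chat_message("assistant"):
        # st.write_stream renders chunks as they arrive and returns the full text.
        response = st.write_stream(fake_stream(prompt))
    st.session_state.messages.append({"role": "assistant", "content": response})

Saved as sketch.py, this runs with: streamlit run sketch.py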

api/browser.py

Lines changed: 8 additions & 3 deletions
@@ -38,7 +38,7 @@ def initialize_generator() -> OpenAIAPIGenerator:
 gen = initialize_generator()

 tokens_generated = 0
-st.title("Simple chat")
+st.title("TorchChat")

 # Initialize chat history
 if "messages" not in st.session_state:
@@ -58,7 +58,7 @@ def initialize_generator() -> OpenAIAPIGenerator:
         st.markdown(prompt)

     # Display assistant response in chat message container
-    with st.chat_message("assistant"):
+    with st.chat_message("assistant"), st.status("Generating... ", expanded=True) as status:

         req = CompletionRequest(
             model = gen.builder_args.checkpoint_path,
@@ -68,11 +68,16 @@ def initialize_generator() -> OpenAIAPIGenerator:
         )

         def unwrap(completion_generator):
+            start = time.time()
+            tokcount = 0
             for chunk_response in completion_generator:
                 content = chunk_response.choices[0].delta.content
                 if not gen.is_llama3_model or content not in set(gen.tokenizer.special_tokens.keys()):
                     yield content
-                yield "."
+                if content == gen.tokenizer.eos_id():
+                    yield "."
+                tokcount+=1
+            status.update(label="Done, averaged {:.2f} tokens/second".format(tokcount/(time.time()-start)), state="complete")

         response = st.write_stream(unwrap(gen.completion(req)))
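The browser changes retitle the page, wrap generation in an st.status container, and report throughput once the stream is drained (the new code references time.time(), so browser.py presumably already imports time; the import is outside this hunk). Below is a standalone sketch of that throughput pattern under the same Streamlit assumptions as above; stream_with_stats and dummy_chunks are illustrative names standing in for unwrap and the OpenAI-style completion generator.

import time

import streamlit as st


def stream_with_stats(chunks, status):
    # Pass chunks through unchanged while counting them; report
    # chunks/second on the status widget once the stream is exhausted.
    start = time.time()
    tokcount = 0
    for content in chunks:
        yield content
        tokcount += 1
    status.update(
        label="Done, averaged {:.2f} tokens/second".format(tokcount / (time.time() - start)),
        state="complete",
    )


def dummy_chunks():
    # Stand-in for the real completion stream.
    for word in "hello from the sketch".split():
        yield word + " "
        time.sleep(0.1)


with st.status("Generating... ", expanded=True) as status:
    st.write_stream(stream_with_stats(dummy_chunks(), status))

Putting status.update at the end of the wrapper ties the "complete" state to exhaustion of the generator, which st.write_stream guarantees by draining it.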

generate.py

Lines changed: 2 additions & 1 deletion
@@ -823,7 +823,8 @@ def main(args):
         args.quantize,
         args.draft_quantize,
     )
-    gen.chat(generator_args)
+    for _ in gen.chat(generator_args):
+        pass


 if __name__ == "__main__":
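The generate.py fix suggests that Generator.chat now yields rather than returns: calling a generator function only creates a generator object, so a bare gen.chat(generator_args) would no longer execute the chat loop at all. Draining it with a for loop restores the side effects. A minimal illustration with a toy chat function (not the torchchat one):

def chat():
    # The body of a generator function runs only when it is iterated.
    print("chat loop runs")
    yield "token"


chat()            # creates a generator object; prints nothing
for _ in chat():  # iterates, so "chat loop runs" is printed
    pass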
