set buffer size to 8192 as default, decode precision as a string, lint #476

Merged · 5 commits · Apr 25, 2024
8 changes: 4 additions & 4 deletions build/builder.py
@@ -41,7 +41,7 @@ class BuilderArgs:
def __post_init__(self):
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"

if not (
(self.checkpoint_path and self.checkpoint_path.is_file())
or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
@@ -408,10 +408,10 @@ def _initialize_model(
print(f"Time to quantize model: {time.time() - t0q:.02f} seconds")

if builder_args.setup_caches:
# TODO: get this from args?
max_seq_length = 2048
with torch.device(builder_args.device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
model.setup_caches(
max_batch_size=1, max_seq_length=model.config.max_seq_length
)

model.to(dtype=builder_args.precision)

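The hunk above replaces the hard-coded `max_seq_length = 2048` with the model's configured context length. A minimal sketch of the resulting call pattern, assuming a model object that exposes `config.max_seq_length` and a `setup_caches(max_batch_size, max_seq_length)` method as in the diff (the wrapper function name is ours, not from the repo):

```python
import torch


def setup_model_caches(model, device: str = "cpu") -> None:
    # Size the KV cache from the model config instead of a hard-coded 2048.
    with torch.device(device):
        model.setup_caches(
            max_batch_size=1,
            max_seq_length=model.config.max_seq_length,
        )
```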
1 change: 1 addition & 0 deletions build/model.py
@@ -37,6 +37,7 @@ class ModelArgs:
multiple_of: int = 256
ffn_dim_multiplier: Optional[int] = None
use_tiktoken: bool = False
max_seq_length: int = 8192

def __post_init__(self):
if self.n_local_heads == -1:
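For illustration, a trimmed stand-in for `ModelArgs` showing how the new field behaves: configs that do not specify a context length now default to 8192 tokens, and an explicit value still overrides it. This is a sketch, not the real dataclass from build/model.py:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgsSketch:
    # Trimmed stand-in for build.model.ModelArgs; only the fields shown in the hunk.
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[int] = None
    use_tiktoken: bool = False
    max_seq_length: int = 8192  # new default added by this PR


default_args = ModelArgsSketch()                    # max_seq_length == 8192
short_ctx = ModelArgsSketch(max_seq_length=2048)    # explicit override still works
```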
3 changes: 1 addition & 2 deletions chat_in_browser.py
@@ -20,7 +20,6 @@ def create_app(*args):
["python3", "generate.py", *args], stdin=subprocess.PIPE, stdout=subprocess.PIPE
)


@app.route("/")
def main():
print("Starting chat session.")
@@ -93,7 +92,7 @@ def chat():
# Strip "Model: " from output
model_prefix = "Model: "
if output.startswith(model_prefix):
output = output[len(model_prefix):]
output = output[len(model_prefix) :]

global convo

11 changes: 4 additions & 7 deletions download.py
@@ -36,19 +36,16 @@ def _download_hf_snapshot(
ignore_patterns="*safetensors*",
)
except HTTPError as e:
if e.response.status_code == 401: # Missing HuggingFace CLI login.
if e.response.status_code == 401: # Missing HuggingFace CLI login.
print(
"Access denied. Create a HuggingFace account and run 'pip3 install huggingface_hub' and 'huggingface-cli login' to authenticate.",
file=sys.stderr
file=sys.stderr,
)
exit(1)
elif e.response.status_code == 403: # No access to the specific model.
elif e.response.status_code == 403: # No access to the specific model.
# The error message includes a link to request access to the given model. This prints nicely and does not include
# a traceback.
print(
str(e),
file=sys.stderr
)
print(str(e), file=sys.stderr)
exit(1)
else:
raise e
137 changes: 97 additions & 40 deletions generate.py
@@ -11,7 +11,7 @@
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, List
from typing import List, Optional, Tuple

import torch
import torch._dynamo.config
@@ -32,6 +32,7 @@
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"


class ChatFormat:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
@@ -62,7 +63,6 @@ def encode_dialog_prompt(self, dialog) -> List[int]:
return tokens



@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
@@ -210,11 +210,17 @@ def decode_n_tokens(
):
new_tokens, new_probs = [], []
encountered_eos = False
for i in range(num_new_tokens - 1): # -1 to save space to run an EoS if dont generate it naturally
for i in range(
num_new_tokens - 1
): # -1 to save space to run an EoS if dont generate it naturally
# Actually better for Inductor to codegen attention here
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
next_token, next_prob = decode_one_token(
model, cur_token.clone(), input_pos, need_probs=need_probs, **sampling_kwargs
model,
cur_token.clone(),
input_pos,
need_probs=need_probs,
**sampling_kwargs,
)
input_pos += 1
new_tokens.append(next_token.clone())
@@ -223,15 +229,25 @@ def decode_n_tokens(
new_probs.append(next_prob.clone())
cur_token = next_token.view(1, -1)
# encountered eos
if (next_token.item() == eos_token_id or (eot_id is not None and next_token.item() == eot_id)):
if next_token.item() == eos_token_id or (
eot_id is not None and next_token.item() == eot_id
):
encountered_eos = True
_, _ = decode_one_token(model, cur_token, input_pos, need_probs, **sampling_kwargs)
_, _ = decode_one_token(
model, cur_token, input_pos, need_probs, **sampling_kwargs
)
input_pos += 1
break
if not encountered_eos:
eos_token = torch.tensor([eos_token_id if eot_id is None else eot_id], dtype=cur_token.dtype, device=cur_token.device)
eos_token = torch.tensor(
[eos_token_id if eot_id is None else eot_id],
dtype=cur_token.dtype,
device=cur_token.device,
)
new_tokens.append(eos_token.clone())
_, _ = decode_one_token(model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs)
_, _ = decode_one_token(
model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs
)
input_pos += 1

return new_tokens, new_probs
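The reformatted loop above stops on either the tokenizer's EOS token or, for Llama 3 models, the `<|eot_id|>` end-of-turn token, and it iterates only `num_new_tokens - 1` times so one slot is always left to decode an EOS even when the model never emits one. A minimal sketch of the stop condition (the helper name is ours, not from generate.py):

```python
from typing import Optional


def is_stop_token(token_id: int, eos_token_id: int, eot_id: Optional[int]) -> bool:
    # Stop on the tokenizer's EOS, or on <|eot_id|> when one is provided (Llama 3).
    return token_id == eos_token_id or (eot_id is not None and token_id == eot_id)
```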
@@ -337,7 +353,9 @@ def generate(
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if is_speculative and draft_model is not model:
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
draft_model.setup_caches(
max_batch_size=1, max_seq_length=max_seq_length
)

# create an empty tensor of the expected final shape and
# fill in the current tokens
@@ -366,7 +384,9 @@ def generate(

num_tokens_generated = 0
input_pos = torch.tensor([start_pos + T], device=device, dtype=torch.int)
accept_counts = [0] * (speculate_k + 1) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long
accept_counts = [0] * (
speculate_k + 1
) # creates array of [0, 0, 0, ...] that is speculate_k + 1 long

if is_speculative:
input_pos = input_pos.item() # for speculative decoding easier to keep on host
@@ -392,12 +412,14 @@ def generate(
max_new_tokens - 1,
callback=callback,
need_probs=False,
eos_token_id = tokenizer.eos_id() if tokenizer else 2,
eot_id = tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
eos_token_id=tokenizer.eos_id() if tokenizer else 2,
eot_id=tokenizer.special_tokens["<|eot_id|>"] if is_llama3_model else None,
**sampling_kwargs,
)
seq[T + 1 : T + 1 + len(generated_tokens)] = torch.cat(generated_tokens)
seq = seq[:T + 1 + len(generated_tokens)] # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.
seq = seq[
: T + 1 + len(generated_tokens)
] # If we dont generate all the way to max_new_tokens slice off the extra space we allocated.

generate_stats = {"accept_counts": accept_counts}
return seq, generate_stats
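For context on the `seq` slicing reformatted above: `generate` preallocates a tensor sized for the full `max_new_tokens` budget and trims it once decoding stops early at EOS/`<|eot_id|>`. A sketch of that pattern, with `prompt_len` standing in for `T` and the helper name ours:

```python
from typing import List

import torch


def trim_to_generated(
    seq: torch.Tensor, prompt_len: int, generated: List[torch.Tensor]
) -> torch.Tensor:
    # Write the decoded tokens after the prompt (+1 for the token produced by prefill),
    # then drop the unused tail of the preallocated buffer.
    seq[prompt_len + 1 : prompt_len + 1 + len(generated)] = torch.cat(generated)
    return seq[: prompt_len + 1 + len(generated)]
```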
@@ -410,7 +432,6 @@ def encode_tokens(tokenizer, string, bos=True, device="cpu"):
return torch.tensor(tokens, dtype=torch.int, device=device)



def get_device_info(name: str) -> str:
import platform
from subprocess import check_output
@@ -481,7 +502,9 @@ def _main(
# Piggy backing off of this flag then for now to identify llama3 without prompting user.
is_llama3_model = tokenizer_args.is_tiktoken
if generator_args.chat_mode and is_llama3_model:
logging.debug("Llama3 model detected in chat mode. Using updated sentence schemas")
logging.debug(
"Llama3 model detected in chat mode. Using updated sentence schemas"
)

builder_args.setup_caches = False
model = _initialize_model(builder_args, quantize, tokenizer)
@@ -534,20 +557,29 @@ def _main(
if generator_args.compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

system_prompt=None
system_prompt = None
# Set up our max_seq_length
if generator_args.chat_mode:
max_seq_length = 2048
print(f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye")
system_prompt = input("System Prompt [Optional]: ")
max_seq_length = model.config.max_seq_length
print(
f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
)
get_system_prompt = input(
"Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"
)
if get_system_prompt == "y" or get_system_prompt == "Y":
system_prompt = input("What is your system prompt? \n")
if is_llama3_model:
chat_formatter = ChatFormat(tokenizer)
else:
max_seq_length = min(encoded.size(0) + generator_args.max_new_tokens, model.config.block_size)

max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens, model.config.block_size
)

max_seq_length = (
max_seq_length + speculate_k + 1 if draft_model is not None else max_seq_length
max_seq_length + speculative_builder_args.speculate_k + 1
if draft_model is not None
else max_seq_length
)

aggregate_metrics = {
@@ -557,39 +589,59 @@
start = -1 if generator_args.compile else 0
start_pos = 0


# arbitrarily large number as chat mode goes until max_seq length or user exits
num_samples = generator_args.num_samples if not generator_args.chat_mode else 100000
i = -1 # long loop and Im scared someone will add a continue in it, so start at -1 and increment at the start
while (i < num_samples):
i = (
-1
) # long loop and Im scared someone will add a continue in it, so start at -1 and increment at the start
while i < num_samples:
i += 1
device_sync(device=builder_args.device)
if i >= 0 and generator_args.chat_mode:
prompt = input("User: ")
if (prompt == "/bye"):
if prompt == "/bye":
print("Exiting Chat.\n")
break
if not is_llama3_model:
if system_prompt:
prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
system_prompt = None # can only provide system prompt on first interaction
system_prompt = (
None # can only provide system prompt on first interaction
)
else:
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
encoded = encode_tokens(
tokenizer, prompt, bos=True, device=builder_args.device
)
else:
if system_prompt:
encoded = chat_formatter.encode_dialog_prompt([{"role" : "system", "content" : system_prompt}, {"role" : "user", "content" : prompt}])
if system_prompt is not None:
encoded = chat_formatter.encode_dialog_prompt(
[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
]
)
system_prompt = None
elif(i == 0):
encoded = chat_formatter.encode_dialog_prompt([{"role" : "user", "content" : prompt}])
elif i == 0:
encoded = chat_formatter.encode_dialog_prompt(
[{"role": "user", "content": prompt}]
)
else:
encoded = chat_formatter.encode_message({"role" : "user", "content" : prompt})
encoded.extend(chat_formatter.encode_header({"role": "assistant", "content": ""}))
encoded = torch.tensor(encoded, dtype=torch.int, device=builder_args.device)
if (encoded.size(0) + start_pos > max_seq_length):
print("This prompt would take us past the max_seq_length. Ending Conversation.")
encoded = chat_formatter.encode_message(
{"role": "user", "content": prompt}
)
encoded.extend(
chat_formatter.encode_header(
{"role": "assistant", "content": ""}
)
)
encoded = torch.tensor(
encoded, dtype=torch.int, device=builder_args.device
)
if encoded.size(0) + start_pos > max_seq_length:
print(
"This prompt would take us past the max_seq_length. Ending Conversation."
)
break

if generator_args.chat_mode and i >= 0:
@@ -604,12 +656,17 @@ def callback(
):
if done_generating:
return
buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) # I think this results in the first output token being dropped from the display which is wrong.
buffer.append(
tokenizer.decode([period_id] + x.tolist())[1:]
) # I think this results in the first output token being dropped from the display which is wrong.
if x.item() == tokenizer.eos_id():
done_generating = True
if (is_llama3_model and x.item() == tokenizer.special_tokens["<|eot_id|>"]):
if (
is_llama3_model
and x.item() == tokenizer.special_tokens["<|eot_id|>"]
):
done_generating = True
buffer = buffer[:-1] # drop the eot_id from the output buffer
buffer = buffer[:-1] # drop the eot_id from the output buffer
if len(buffer) == 4 or done_generating:
print("".join(buffer), end="", flush=True)
buffer.clear()
@@ -672,7 +729,7 @@ def callback(x):
)
logging.debug(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")

if (start_pos >= max_seq_length):
if start_pos >= max_seq_length:
print("Max Sequence Length Reached. Ending Conversation.")
break

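Two behavioral pieces of the `_main` hunks above are easy to miss among the lint churn: chat mode now derives its context limit from `model.config.max_seq_length` instead of a hard-coded 2048, and the system prompt is only requested after an explicit yes/no question. A sketch of that prompt flow, reusing the wording from the diff (the wrapper function is ours):

```python
from typing import Optional


def maybe_read_system_prompt() -> Optional[str]:
    # Ask before reading a system prompt, instead of always showing an
    # "[Optional]" input field as the old code did.
    answer = input(
        "Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"
    )
    if answer in ("y", "Y"):
        return input("What is your system prompt? \n")
    return None
```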
11 changes: 7 additions & 4 deletions quantize.py
@@ -15,7 +15,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from build.utils import find_multiple, get_precision, use_et_backend
from build.utils import find_multiple, get_precision, name_to_dtype, use_et_backend


#########################################################################
@@ -97,11 +97,14 @@ def quantized_model(self) -> nn.Module:


class PrecisionHandler(QuantHandler):
def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
def __init__(self, model: nn.Module, device="cpu", tokenizer=None, *, dtype):
self.model_ = model
self.device = device
self.tokenizer = tokenizer
self.kwargs = kwargs

if isinstance(dtype, str):
dtype = name_to_dtype(dtype)
self.dtype = dtype

def create_quantized_state_dict(self) -> Dict: # "StateDict"
pass
@@ -110,7 +113,7 @@ def convert_for_runtime(self) -> nn.Module:
pass

def quantized_model(self) -> nn.Module:
return self.model_.to(device=self.device, **self.kwargs)
return self.model_.to(device=self.device, dtype=self.dtype)


#########################################################################
Expand Down