
Code formatting #457


Merged: 6 commits, Apr 24, 2024
24 changes: 14 additions & 10 deletions GPTQ.py
@@ -132,15 +132,17 @@ def cuda(self):

class GenericGPTQRunner(fx.Interpreter):
"""
This is a generic GPTQ runner that takes an existing model and applies GPTQ.
It uses torch._dynamo.export to obtain a graph of the model and then hooks
into function calls and when it detects a linear, it applies GPTQ to the weight
given the calibration of inputs passed in at initialization. It puts the results
into the state_dict so that the quantized model weights/qparams can be loaded
directly into the model.
This is a generic GPTQ runner that takes an existing model and
applies GPTQ. It uses torch._dynamo.export to obtain a graph of
the model and then hooks into function calls and when it detects a
linear, it applies GPTQ to the weight given the calibration of
inputs passed in at initialization. It puts the results into the
state_dict so that the quantized model weights/qparams can be
loaded directly into the model.

This class is expected to work in concert with a GPTQSimpleQuantizer
class to define the specific type of quantization being done.

"""

def __init__(
@@ -206,7 +208,7 @@ def get_quantized_state_dict(self):
self.gptq_done
), "need to run GPTQRunner before you can get_quantized_state_dict"
quantized_state_dict = self.new_state_dict
# Don't want to store/load the kv_cache so remove it from the state_dict

del_list = []
for param_fqn in quantized_state_dict:
if "kv_cache" in param_fqn:
@@ -224,7 +226,8 @@ def tensors_to_cuda(args):

# flatten args and kwargs together
flat_args, spec = tree_flatten((args, kwargs))
# move all single tensors to cuda, will move MultiInputs to cuda one at a time
# move all single tensors to cuda, will move MultiInputs
# to cuda one at a time
flat_args = tensors_to_cuda(flat_args)

has_multi_input = MultiInput in [type(x) for x in flat_args]
@@ -421,8 +424,9 @@ def faster_quant(self, H, W):
if all_qparams == []:
all_qparams.append(cur_qparams)

# convert a list of qparams objects into a single one. enerally by
# concatenating a bunch of n,1 scale/zeros tensors into a n,num_groups tensor
# convert a list of qparams objects into a single
# one. generally by concatenating a bunch of n,1 scale/zeros
# tensors into a n,num_groups tensor
all_qparams = self.combine_qparams_list_func(all_qparams)
Q = self.quantize_func(DQ, all_qparams)
return Q, DQ.to(orig_dtype), all_qparams
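For context on the docstring rewrapped above, here is a minimal sketch of the interception pattern it describes, assuming a graph obtained via torch._dynamo.export. The class and attribute names are illustrative only, not the repo's implementation, and a real GPTQ runner would accumulate a Hessian from the calibration inputs and quantize the weight rather than just record it.

```python
import torch
import torch.fx as fx


class LinearInterceptingRunner(fx.Interpreter):
    """Sketch: re-execute an exported graph and hook every linear call."""

    def __init__(self, gm: fx.GraphModule):
        super().__init__(gm)
        self.seen_linear_weights = []

    def call_function(self, target, args, kwargs):
        # The exact target depends on how the graph was captured
        # (e.g. torch.nn.functional.linear vs. torch.ops.aten.linear.default).
        if target is torch.nn.functional.linear:
            # args = (input, weight, bias); record the weight that a
            # GPTQ pass would quantize using the calibration inputs.
            self.seen_linear_weights.append(args[1])
        return super().call_function(target, args, kwargs)
```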
8 changes: 5 additions & 3 deletions build/builder.py
@@ -180,11 +180,13 @@ def validate_model(
if model is None:
return

condition = False # not (self.is_tiktoken == model.config.use_tiktoken) or not (self.is_sentencepiece == not model.config.use_tiktoken)
is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
use_tiktoken = model.config.use_tiktoken

if condition:
if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
raise RuntimeError(
f"model-specified tokenizer ({tokenizer_setting_to_name(model.config.use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(self.is_tiktoken)} for {model_description}"
f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)} for {model_description}"
)

return
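Expressed on its own, the condition introduced above reduces to the following helper; the function name is a stand-in for illustration, not code from the PR.

```python
def tokenizers_match(is_tiktoken: bool, is_sentencepiece: bool, use_tiktoken: bool) -> bool:
    # The provided tokenizer must be tiktoken exactly when the model expects
    # tiktoken, and sentencepiece exactly when the model does not.
    return (is_tiktoken == use_tiktoken) and (is_sentencepiece != use_tiktoken)


# validate_model raises when this is False, i.e. when
# not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken).
```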
2 changes: 1 addition & 1 deletion build/gguf_loader.py
@@ -17,7 +17,7 @@
from quantize import pack_scales_and_zeros, WeightOnlyInt4Linear

from build.gguf_util import Q4_0, to_float
from .model import ModelArgs, Transformer
from build.model import ModelArgs, Transformer

logger: logging.Logger = logging.getLogger(__name__)

2 changes: 1 addition & 1 deletion build/model.py
@@ -36,7 +36,7 @@ class ModelArgs:
norm_eps: float = 1e-5
multiple_of: int = 256
ffn_dim_multiplier: Optional[int] = None
use_tiktoken: Optional[bool] = None
use_tiktoken: bool = False

def __post_init__(self):
if self.n_local_heads == -1:
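A small self-contained illustration of what the default change above buys; the class names here are hypothetical. The field is now strictly two-valued instead of a tri-state Optional, which keeps the tokenizer checks in build/builder.py purely boolean.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class OldArgs:
    use_tiktoken: Optional[bool] = None  # None, True, or False


@dataclass
class NewArgs:
    use_tiktoken: bool = False  # always a plain bool


assert OldArgs().use_tiktoken is None   # callers had to handle a third state
assert NewArgs().use_tiktoken is False  # equality checks stay boolean
```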
3 changes: 2 additions & 1 deletion chat_in_browser.py
@@ -38,7 +38,8 @@ def main():

@app.route("/chat", methods=["GET", "POST"])
def chat():
# Retrieve the HTTP POST request parameter value from 'request.form' dictionary
# Retrieve the HTTP POST request parameter value from
# 'request.form' dictionary
_prompt = request.form.get("prompt", "")
proc.stdin.write((_prompt + "\n").encode("utf-8"))
proc.stdin.flush()
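For reference, a request against the /chat route shown above could look like this; the client code is hypothetical and the host and port are assumptions based on Flask defaults.

```python
import requests

# POST the form field the route reads via request.form.get("prompt", "")
resp = requests.post(
    "http://127.0.0.1:5000/chat",
    data={"prompt": "Hello, how are you?"},
)
print(resp.text)
```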
3 changes: 1 addition & 2 deletions cli.py
@@ -12,8 +12,7 @@
from build.utils import allowable_dtype_names, allowable_params_table
from download import download_and_convert, is_model_downloaded

# CPU is always available and also exportable to ExecuTorch
default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'
default_device = "cpu"


# Handle CLI arguments that are common to a majority of subcommands.
7 changes: 4 additions & 3 deletions download.py
@@ -66,9 +66,10 @@ def download_and_convert(
model_config = resolve_model_config(model)
model_dir = models_dir / model_config.name

# Download into a temporary directory. We'll move to the final location once
# the download and conversion is complete. This allows recovery in the event
# that the download or conversion fails unexpectedly.
# Download into a temporary directory. We'll move to the final
# location once the download and conversion is complete. This
# allows recovery in the event that the download or conversion
# fails unexpectedly.
temp_dir = models_dir / "downloads" / model_config.name
if os.path.isdir(temp_dir):
shutil.rmtree(temp_dir)
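A minimal sketch of the recovery strategy the rewrapped comment describes, with hypothetical helper names: the final directory only ever appears once both the download and the conversion have finished, so a failed attempt can simply be wiped and retried.

```python
import shutil
from pathlib import Path


def download_with_recovery(models_dir: Path, name: str, fetch_and_convert) -> Path:
    final_dir = models_dir / name
    temp_dir = models_dir / "downloads" / name

    # Wipe any leftovers from a previously failed attempt.
    if temp_dir.is_dir():
        shutil.rmtree(temp_dir)
    temp_dir.mkdir(parents=True)

    fetch_and_convert(temp_dir)  # download + convert into temp_dir

    # Publish only on success (assumes final_dir does not already exist).
    shutil.move(str(temp_dir), str(final_dir))
    return final_dir
```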
3 changes: 2 additions & 1 deletion eval.py
@@ -78,7 +78,8 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
max_seq_length = min(T_new, model.config.block_size)

device, dtype = prompt.device, prompt.dtype
# create an empty tensor of the expected final shape and fill in the current tokens
# create an empty tensor of the expected final shape and
# fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty[:T] = prompt
seq = empty
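The comment rewrapped above (and the matching ones in generate.py below) describes a preallocate-then-fill pattern; a tiny standalone example with made-up sizes:

```python
import torch

prompt = torch.tensor([1, 2, 3])  # current tokens (T = 3)
T, T_new = prompt.size(0), 8      # 8 = expected final sequence length

empty = torch.empty(T_new, dtype=prompt.dtype, device=prompt.device)
empty[:T] = prompt                # existing tokens go up front
seq = empty                       # remaining slots are filled during decoding
```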
13 changes: 5 additions & 8 deletions generate.py
@@ -103,7 +103,7 @@ def validate_build(
)

@classmethod
def from_args(cls, args): # -> GeneratorArgs:
def from_args(cls, args):
return cls(
prompt=args.prompt,
encoded_prompt=None,
@@ -326,7 +326,8 @@ def generate(
is_speculative = draft_model is not None
device, dtype = prompt.device, prompt.dtype

# create an empty tensor of the expected final shape and fill in the current tokens
# create an empty tensor of the expected final shape and
# fill in the current tokens
T = prompt.size(0)
max_new_tokens = min(max_new_tokens, max_seq_length - start_pos - T)
T_new = T + max_new_tokens
@@ -338,7 +339,8 @@
if is_speculative and draft_model is not model:
draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

# create an empty tensor of the expected final shape and fill in the current tokens
# create an empty tensor of the expected final shape and
# fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty[:T] = prompt
seq = empty
@@ -461,8 +463,6 @@ def _main(
is_speculative = speculative_builder_args.checkpoint_path is not None

if generator_args.chat_mode and not builder_args.is_chat_model:
# This is not a log message, it's a dangerous condition message
# that we must ensure is displayed
print(
"""
*******************************************************
@@ -486,8 +486,6 @@
builder_args.setup_caches = False
model = _initialize_model(builder_args, quantize, tokenizer)

# will add a version of _initialize_model in future
# (need additional args)
if is_speculative:
draft_model = _initialize_model(
speculative_builder_args,
@@ -533,7 +531,6 @@ def _main(
decode_one_token, mode="reduce-overhead", fullgraph=True
)

# Uncomment to squeeze more perf out of prefill
if generator_args.compile_prefill:
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)

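To summarize the compilation setup retained in the hunks above: only the torch.compile flags mirror the diff, and the function bodies here are stand-ins, not the repo's code.

```python
import torch


def decode_one_token(*args, **kwargs):
    ...  # stand-in for the real single-token decode step


def prefill(*args, **kwargs):
    ...  # stand-in for the real prefill step


# mode="reduce-overhead" uses CUDA graphs to cut per-step launch overhead,
# which suits the many short, identical decode steps.
decode_one_token = torch.compile(
    decode_one_token, mode="reduce-overhead", fullgraph=True
)

# dynamic=True enables dynamic shapes, which suits variable-length prompts.
prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
```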