
Commit 50781ac

mikekgfb authored and malfet committed
Import AO quantizer api at top level (#327)
* refactor quantizer entry point quantize_model to be table driven, and scalable
* add tokenizer arg consistently
* code beautification
* refactor and import ao api wholesale
* code beautification
* tab->spc
1 parent bd0fcfe commit 50781ac

File tree

15 files changed: +250 -197 lines
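
The headline change is making the quantize_model entry point table driven, with the tokenizer argument passed consistently to every quantizer. As a rough sketch of that shape (not the repo's actual implementation: the handler class, table name, and keyword arguments below are placeholders for illustration), a table-driven dispatcher might look like this:

# Illustrative sketch of a table-driven quantize_model entry point.
# The handler class and table entries are placeholders, not the repo's real ones.
import json
from typing import Any, Dict


class NoOpQuantHandler:
    # A real handler would rewrite the model's weights for its scheme.
    def __init__(self, model, device: str = "cpu", tokenizer=None, **kwargs):
        self.model = model

    def quantized_model(self):
        return self.model


# One entry per quantization mode; adding a scheme means adding a row here
# instead of growing an if/elif chain.
quantizer_table: Dict[str, Any] = {
    "embedding": NoOpQuantHandler,
    "linear:int8": NoOpQuantHandler,
    "precision": NoOpQuantHandler,
}


def quantize_model(model, device: str, quantize_options, tokenizer=None):
    # quantize_options is the parsed --quantize JSON, e.g. {"linear:int8": {...}}.
    if isinstance(quantize_options, str):
        quantize_options = json.loads(quantize_options)
    for mode, kwargs in quantize_options.items():
        if mode not in quantizer_table:
            raise RuntimeError(f"unknown quantizer mode: {mode}")
        handler = quantizer_table[mode](
            model, device=device, tokenizer=tokenizer, **kwargs
        )
        model = handler.quantized_model()
    return model

Threading the tokenizer through every handler means calibration-based modes (GPTQ-style) can use it without special-casing the call site.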

.github/workflows/et-gguf.yml

Lines changed: 2 additions & 2 deletions
@@ -67,11 +67,11 @@ jobs:
           mkdir gguf_files
           wget -O gguf_files/llama-2-7b.Q4_0.gguf "https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true"
           ./llama.cpp/quantize --allow-requantize gguf_files/llama-2-7b.Q4_0.gguf gguf_files/llama-2-7b.Q4_0.requant_F32.gguf F32
-          wget -O gguf_files/tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
+          wget -O gguf_files/tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
       - name: Run inference
         run: |
           export GGUF_PATH=${PWD}/gguf_files/llama-2-7b.Q4_0.gguf
-          export TOKENIZER_PATH=${PWD}/gguf_files/tokenizer.model
+          export TOKENIZER_PATH=${PWD}/gguf_files/tokenizer.model
           export MODEL_NAME=llama-2-7b_Q4_0_gguf

           python generate.py --tokenizer-path ${TOKENIZER_PATH} --gguf-path ${GGUF_PATH} --temperature 0 > ${PWD}/output_eager

GPTQ.py

Lines changed: 20 additions & 10 deletions
@@ -14,8 +14,8 @@
 aten = torch.ops.aten

 from eval import (
+    GPTFastEvalWrapper,
     setup_cache_padded_seq_input_pos_max_seq_length_for_prefill,
-    GPTFastEvalWrapper
 )


@@ -64,7 +64,6 @@ def __init__(
         )
         self.pad_calibration_inputs = False

-
     def add_input(self, args):
         if self.inputs is None:
             self.inputs = [MultiInput([arg]) for arg in args]
@@ -114,7 +113,6 @@ def _model_call(self, inps):
         )


-
 class MultiInput:
     def __init__(self, inputs):
         self.values = list(inputs)
@@ -127,7 +125,9 @@ def __getitem__(self, slice):
         return MultiInput(self.values[slice])

     def cuda(self):
-        self.values = [val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values]
+        self.values = [
+            val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
+        ]


 class GenericGPTQRunner(fx.Interpreter):
@@ -236,7 +236,14 @@ def tensors_to_cuda(args):
             )
             transposed_args = list(
                 zip(
-                    *[x.values if isinstance(x, MultiInput) else [x] * multi_input_count for x in flat_args]
+                    *[
+                        (
+                            x.values
+                            if isinstance(x, MultiInput)
+                            else [x] * multi_input_count
+                        )
+                        for x in flat_args
+                    ]
                 )
             )
         else:
@@ -245,8 +252,8 @@ def tensors_to_cuda(args):

         # check whether we apply GPTQ to this module
         quantize_linear = (
-            (target == aten.linear.default) # if its a linear
-            and id(args[1]) in self.id_to_name # and if we know the layer name
+            (target == aten.linear.default)  # if its a linear
+            and id(args[1]) in self.id_to_name  # and if we know the layer name
             and not skip_quant  # and if we weren't told to skip quantization
             # and if the skip_layer_func doesn't say we should skip
             and not (self.skip_layer_func is not None and self.skip_layer_func(args[1]))
@@ -334,11 +341,14 @@ def SQNR(x, y):
                 target, (args[0][:2], DQ2, *args[2:]), kwargs, skip_quant=True
             )

-            print("SQNR for output without GPTQ (should be less than above)",
-                torch.cat([
+            print(
+                "SQNR for output without GPTQ (should be less than above)",
+                torch.cat(
+                    [
                         SQNR(old.cpu(), old_q.cpu()).unsqueeze(0)
                         for (old, old_q) in zip(old_out.values, old_q_out.values)
-                ]).mean(),
+                    ]
+                ).mean(),
             )
             return new_out

build/builder.py

Lines changed: 10 additions & 3 deletions
@@ -147,7 +147,9 @@ def from_args(cls, args): # -> TokenizerArgs:
             tokenizer_path = args.tokenizer_path
         elif args.model:  # Using a named, well-known model
             model_config = resolve_model_config(args.model)
-            tokenizer_path = Path(args.model_directory) / model_config.name / "tokenizer.model"
+            tokenizer_path = (
+                Path(args.model_directory) / model_config.name / "tokenizer.model"
+            )
         elif args.checkpoint_path:
             tokenizer_path = args.checkpoint_path.parent / "tokenizer.model"
         elif hasattr(args, "checkpoint_dir") and args.checkpoint_dir:
@@ -297,7 +299,7 @@ def _load_model(builder_args):
 def _initialize_model(
     builder_args,
     quantize,
-    tokenizer = None,
+    tokenizer=None,
 ):
     print("Loading model ...")
     t0 = time.time()
@@ -364,17 +366,22 @@ def _initialize_model(

     return model

+
 def tokenizer_setting_to_name(tiktoken: bool = False) -> str:
     return "TikToken" if tiktoken else "SentencePiece"

+
 def validate_args(model: Transformer, tokenizer_args: TokenizerArgs):
     use_tiktoken = model.config.use_tiktoken
     is_tiktoken = tokenizer_args.is_tiktoken

     if use_tiktoken is None:
         model.config.use_tiktoken = is_tiktoken
     elif use_tiktoken != is_tiktoken:
-        raise RuntimeError(f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}")
+        raise RuntimeError(
+            f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)} does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}"
+        )
+

 def resolve_model_name(model: str) -> str:
     # If the provided model name is an alias, retrieve the full path.

build/model.py

Lines changed: 3 additions & 11 deletions
@@ -58,8 +58,7 @@ def __post_init__(self):
         self.hidden_dim = find_multiple(hidden_dim, multiple_of)
         self.head_dim = self.dim // self.n_heads
         if isinstance(self.use_tiktoken, str):
-            self.use_tiktoken = (self.use_tiktoken == "True")
-
+            self.use_tiktoken = self.use_tiktoken == "True"

     @classmethod
     def from_params(cls, params_path):
@@ -118,7 +117,6 @@ def from_name(cls, name: str):

 class KVCache(nn.Module):
     def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=None):
-        # torch.float): # bfloat16 ):
         super().__init__()
         if not dtype:
             dtype = get_precision()
@@ -180,11 +178,6 @@ def setup_caches(self, max_batch_size, max_seq_length):
         self.register_buffer("causal_mask", causal_mask, persistent=True)

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
-        # print ("*")
-        # print (f"* shape idx: {idx.shape}")
-        # print (f"* shape pos: {input_pos.shape}")
-        # print("@")
-
         assert self.freqs_cis is not None, "Caches must be initialized first"
         mask = self.causal_mask[None, None, input_pos]
         freqs_cis = self.freqs_cis[input_pos]
@@ -194,7 +187,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
             x = layer(x, input_pos, freqs_cis, mask)
         x = self.norm(x)
         logits = self.output(x)
-        # print(f"******** logits shape: {logits.shape}")
+        # print(f"logits shape: {logits.shape}")
         return logits

     @classmethod
@@ -360,7 +353,6 @@ def forward(self, x: Tensor) -> Tensor:
         return output * self.weight


-# transpsoed first two arguments to align with model in ET
 def precompute_freqs_cis(
     n_elem: int, seq_len: int, base: int = 10000, dtype=None
 ) -> Tensor:
@@ -373,7 +365,7 @@ def precompute_freqs_cis(
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
     cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
-    return cache.to(dtype=dtype) # bfloat16)
+    return cache.to(dtype=dtype)


 def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:

chat_in_browser.py

Lines changed: 24 additions & 17 deletions
@@ -1,60 +1,67 @@
 # -*- coding: UTF-8 -*-
-from flask import Flask, render_template, request
-from cli import add_arguments_for_generate, arg_init, check_args
-from generate import main as generate_main
 import subprocess
 import sys

+from cli import add_arguments_for_generate, arg_init, check_args
+from flask import Flask, render_template, request
+from generate import main as generate_main
+

 convo = ""
 disable_input = False

+
 def create_app(*args):
     app = Flask(__name__)

     import subprocess
+
     # create a new process and set up pipes for communication
-    proc = subprocess.Popen(["python", "generate.py", *args],
-                            stdin=subprocess.PIPE,
-                            stdout=subprocess.PIPE)
+    proc = subprocess.Popen(
+        ["python", "generate.py", *args], stdin=subprocess.PIPE, stdout=subprocess.PIPE
+    )

-    @app.route('/')
+    @app.route("/")
     def main():
         output = ""
         global disable_input

         while True:
             line = proc.stdout.readline()
-            if line.decode('utf-8').startswith("What is your prompt?"):
+            if line.decode("utf-8").startswith("What is your prompt?"):
                 break
-            output += line.decode('utf-8').strip() + "\n"
-        return render_template('chat.html', convo="Hello! What is your prompt?", disable_input=disable_input)
+            output += line.decode("utf-8").strip() + "\n"
+        return render_template(
+            "chat.html",
+            convo="Hello! What is your prompt?",
+            disable_input=disable_input,
+        )

-    @app.route('/chat', methods=['GET', 'POST'])
+    @app.route("/chat", methods=["GET", "POST"])
     def chat():
         # Retrieve the HTTP POST request parameter value from 'request.form' dictionary
-        _prompt = request.form.get('prompt', '')
-        proc.stdin.write((_prompt + "\n").encode('utf-8'))
+        _prompt = request.form.get("prompt", "")
+        proc.stdin.write((_prompt + "\n").encode("utf-8"))
         proc.stdin.flush()

         output = ""
         global disable_input

         while True:
             line = proc.stdout.readline()
-            if line.decode('utf-8').startswith("What is your prompt?"):
+            if line.decode("utf-8").startswith("What is your prompt?"):
                 break
-            if line.decode('utf-8').startswith("=========="):
+            if line.decode("utf-8").startswith("=========="):
                 disable_input = True
                 break
-            output += line.decode('utf-8').strip() + "\n"
+            output += line.decode("utf-8").strip() + "\n"

         global convo

         if _prompt:
             convo += "Your prompt:\n" + _prompt + "\n\n"
             convo += "My response:\n" + output + "\n\n"

-        return render_template('chat.html', convo=convo, disable_input=disable_input)
+        return render_template("chat.html", convo=convo, disable_input=disable_input)

     return app

cli.py

Lines changed: 15 additions & 14 deletions
@@ -12,9 +12,11 @@
 # CPU is always available and also exportable to ExecuTorch
 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'

+
 def check_args(args, name: str) -> None:
     pass

+
 def add_arguments_for_chat(parser):
     # Only chat specific options should be here
     _add_arguments_common(parser)
@@ -24,10 +26,7 @@ def add_arguments_for_browser(parser):
     # Only browser specific options should be here
     _add_arguments_common(parser)
     parser.add_argument(
-        "--port",
-        type=int,
-        default=5000,
-        help="Port for the web server in browser mode"
+        "--port", type=int, default=5000, help="Port for the web server in browser mode"
     )
     _add_arguments_common(parser)

@@ -122,10 +121,7 @@ def add_arguments(parser):
         help="Top-k for sampling",
     )
     parser.add_argument(
-        "--temperature",
-        type=float,
-        default=0.8,
-        help="Temperature for sampling"
+        "--temperature", type=float, default=0.8, help="Temperature for sampling"
     )
     parser.add_argument(
         "--compile",
@@ -204,20 +200,25 @@ def add_arguments(parser):
         help="Use the specified ExecuTorch .pte model file",
     )
     parser.add_argument(
-        "-d", "--dtype",
+        "-d",
+        "--dtype",
         default="float32",
         help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
     )
     parser.add_argument(
-        "-v", "--verbose",
+        "-v",
+        "--verbose",
         action="store_true",
         help="Verbose output",
     )
     parser.add_argument(
-        "--quantize", type=str, default="{ }", help=(
-            'Quantization options. pass in as {"<mode>" : {"<argname1>" : <argval1>, "<argname2>" : <argval2>,...},} '+
-            'modes are: embedding, linear:int8, linear:int4, linear:int4-gptq, linear:int4-hqq, linear:a8w4dq, precision.'
-        )
+        "--quantize",
+        type=str,
+        default="{ }",
+        help=(
+            'Quantization options. pass in as {"<mode>" : {"<argname1>" : <argval1>, "<argname2>" : <argval2>,...},} '
+            + "modes are: embedding, linear:int8, linear:int4, linear:int4-gptq, linear:int4-hqq, linear:a8w4dq, precision."
+        ),
     )
     parser.add_argument(
         "--params-table",

download.py

Lines changed: 1 addition & 4 deletions
@@ -9,10 +9,7 @@
 from typing import Optional, Sequence

 from build.convert_hf_checkpoint import convert_hf_checkpoint
-from config.model_config import (
-    ModelDistributionChannel,
-    resolve_model_config,
-)
+from config.model_config import ModelDistributionChannel, resolve_model_config

 from requests.exceptions import HTTPError