
Commit b04747b

mikekgfb authored and malfet committed
Update state dict and model together (#573)
* code beautification
* code beautification, move functions together
* rewrite model rewriter
* rewrite quantizers
* weights is none check
* typo
* not weight -> weight is not None
* fix dimensions for parallel prefill
* test
* typo
* bfloat16 on ARM with MacOS 14
* precision for a8w4
* sdpa_kv
* fixes
* inline qlq definition
* trial and error
* qdq not working
* ci
* not so fast with bf16=fast
* typo, and handle fast across macOS versions...
* typo
* type cast
1 parent a5d83fc · commit b04747b

9 files changed: +131 additions, -95 deletions


.github/workflows/pull.yml

Lines changed: 9 additions & 3 deletions
@@ -631,7 +631,7 @@ jobs:
   test-mps:
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     with:
-      runner: macos-m1-stable
+      runner: macos-m1-stable # needs MPS, was macos-m1-stable
       script: |
         set -x
         # NS: Remove previous installation of torch first
@@ -740,7 +740,7 @@ jobs:
   test-mps-dtype:
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     with:
-      runner: macos-m1-stable
+      runner: macos-m1-stable # needs MPS, was macos-m1-stable
       script: |
         set -x
         # NS: Remove previous installation of torch first
@@ -918,7 +918,13 @@ jobs:
 
         python torchchat.py export stories15M --output-pte-path ./model.pte
         ./cmake-out/et_run ./model.pte -z ./tokenizer.bin -t 0 -i "${PRMT}"
-
+
+        for dtype in fp32 fp16; do # bf16 needs to be supported
+          echo "Testing export + runner with dtype=$dtype"
+          python torchchat.py export stories15M --dtype $dtype --output-pte-path ./model.pte
+          ./cmake-out/et_run ./model.pte -z ./tokenizer.bin -t 0 -i "${PRMT}"
+        done
+
         echo "Tests complete."
   runner-aoti:
     name: test-runner-aoti (${{ matrix.platform }}, ${{ matrix.model_name }})

build/builder.py

Lines changed: 4 additions & 1 deletion
@@ -194,13 +194,16 @@ def validate_model(
         if model is None:
             return
 
+        if self.is_tiktoken == self.is_sentencepiece:
+            raise RuntimeError("no tokenizer was found")
+
         is_tiktoken = self.is_tiktoken
         is_sentencepiece = self.is_sentencepiece
         use_tiktoken = model.config.use_tiktoken
 
         if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
             raise RuntimeError(
-                f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)} for {model_description}"
+                f"model-specified tokenizer ({tokenizer_setting_to_name(use_tiktoken)}) does not match provided tokenizer ({tokenizer_setting_to_name(is_tiktoken)}) for {model_description}"
             )
 
         return
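For context, a minimal sketch of the logic the two checks above implement. The helper name and its standalone boolean arguments are illustrative, not part of torchchat; in the real method the flags come from the tokenizer files and the model config. Once the first guard guarantees that exactly one tokenizer type was detected, is_sentencepiece equals not is_tiktoken, so the compound condition reduces to a single mismatch test.

# Hypothetical standalone version of the validate_model checks above.
def check_tokenizer(is_tiktoken: bool, is_sentencepiece: bool, use_tiktoken: bool) -> None:
    # Exactly one tokenizer type should have been detected; if both flags are
    # equal (both True or both False), no usable tokenizer was found.
    if is_tiktoken == is_sentencepiece:
        raise RuntimeError("no tokenizer was found")
    # With the guard above, is_sentencepiece == (not is_tiktoken), so the
    # model/tokenizer consistency check reduces to a single comparison.
    if is_tiktoken != use_tiktoken:
        raise RuntimeError(
            "model-specified tokenizer does not match provided tokenizer"
        )

check_tokenizer(is_tiktoken=True, is_sentencepiece=False, use_tiktoken=True)  # passes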

build/model.py

Lines changed: 2 additions & 0 deletions
@@ -124,8 +124,10 @@ def __init__(
         dtype=None,
     ):
         super().__init__()
+        print(f"dtype on entry {dtype}")
         if not dtype:
             dtype = get_precision()
+        print(f"dtype on get_prec {dtype}")
         cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
         self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
         self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

build/utils.py

Lines changed: 4 additions & 4 deletions
@@ -134,12 +134,12 @@ def get_precision():
 
 def name_to_dtype(name):
     if (name == "fast") or (name == "fast16"):
+        # MacOS now supports bfloat16
         import platform
-
         if platform.processor() == "arm":
-            return torch.float16
-        else:
-            return torch.bfloat16
+            if int(platform.mac_ver()[0].split('.')[0]) < 14:
+                return torch.float16
+        return torch.bfloat16
 
     if name in name_to_dtype_dict:
         return name_to_dtype_dict[name]
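In effect, the "fast"/"fast16" aliases now resolve to bfloat16 everywhere except Apple Silicon machines running a macOS release older than 14, which still get float16. A standalone sketch of that resolution; the fast_dtype name is illustrative, and the guard on mac_ver() is a defensive tweak for non-macOS ARM machines:

import platform
import torch

def fast_dtype() -> torch.dtype:
    # Mirrors the "fast"/"fast16" branch of name_to_dtype above.
    if platform.processor() == "arm":
        mac_major = platform.mac_ver()[0].split(".")[0] or "0"  # "" off macOS
        if int(mac_major) < 14:
            return torch.float16
    return torch.bfloat16

print(fast_dtype())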

cli.py

Lines changed: 1 addition & 1 deletion
@@ -296,7 +296,7 @@ def _add_arguments_common(parser):
 def arg_init(args):
     if not (torch.__version__ > "2.3"):
         raise RuntimeError(
-            "You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
+            f"You are using PyTorch {torch.__version__}. At this time, torchchat uses the latest PyTorch technology with high-performance kernels only available in PyTorch nightly until the PyTorch 2.4 release"
         )
 
     if hasattr(args, "quantize") and Path(args.quantize).is_file():
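The only change here is the missing f prefix: without it, Python keeps the braces as literal text instead of interpolating the installed version. A two-line illustration with a made-up version string:

version = "2.3.0"  # hypothetical value
print("You are using PyTorch {version}")   # -> You are using PyTorch {version}
print(f"You are using PyTorch {version}")  # -> You are using PyTorch 2.3.0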

export_et.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,7 @@ def export_model(model, device, output_path, args=None) -> str:  # noqa: C901
         _skip_type_promotion=bool(target_precision == torch.float16),
     )
 
-    if target_precision == torch.float16:
+    if target_precision == torch.float16 or target_precision == torch.bfloat16:
         if state_dict_dtype != torch.float16:
             print("model.to torch.float16")
             model = model.to(dtype=torch.float16)
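The branch now treats a bfloat16 target the same as float16: in either case the model is cast to float16 before the ExecuTorch export. A hedged sketch of that normalization step; the function name and arguments are illustrative, not the export_model signature:

import torch

def normalize_export_dtype(model: torch.nn.Module,
                           target_precision: torch.dtype,
                           state_dict_dtype: torch.dtype) -> torch.nn.Module:
    # Both fp16 and bf16 targets end up as a float16 model for the .pte export.
    if target_precision in (torch.float16, torch.bfloat16):
        if state_dict_dtype != torch.float16:
            model = model.to(dtype=torch.float16)
    return model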

export_et_util.py

Lines changed: 7 additions & 6 deletions
@@ -63,15 +63,16 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
 
-        q = apply_rotary_emb(q, freqs_cis)
-        k = apply_rotary_emb(k, freqs_cis)
-
+        q = apply_rotary_emb(q, freqs_cis).to(dtype=torch.float)
+        k = apply_rotary_emb(k, freqs_cis).to(dtype=torch.float)
+        v = v.to(dtype=torch.float)
+
         # KV cache should always be enabled
         assert self.kv_cache is not None
         output = torch.ops.llama.sdpa_with_kv_cache(
-            q.float(),
-            k.float(),
-            v.float(),
+            q,
+            k,
+            v,
             self.kv_cache.k_cache,
             self.kv_cache.v_cache,
             input_pos[-1].item(),
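The casts are only moved upstream: q, k, and v are converted to float32 right after the rotary embedding instead of at the call site, so the custom SDPA op still receives fp32 tensors. A tiny check of the equivalence, with a random stand-in activation:

import torch

q_half = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # stand-in activation

# .float() and .to(dtype=torch.float) perform the same upcast; the commit simply
# applies it to the rotary-embedding outputs rather than to the op arguments.
assert q_half.float().dtype == q_half.to(dtype=torch.float).dtype == torch.float32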

qops.py

Lines changed: 67 additions & 10 deletions
@@ -15,17 +15,20 @@
 from torch.nn.parameter import Parameter
 
 
-def linear_int8(input, weight, scales):
+def linear_int8_aoti(input, weight, scales):
     n_groups = scales.numel() // scales.shape[0]
 
     # we special-case channel-wise, because we know how to make that fast
     if n_groups == 1:
+        scales = scales.view(-1)
         if (
             torch.compiler.is_compiling()
             or input.device.type != "cpu"
             or torch.__version__ < "2.4"
         ):
-            return F.linear(input, weight.to(dtype=input.dtype)) * scales
+            lin = F.linear(input, weight.to(dtype=input.dtype))
+            # print(f"linear shape {lin.shape}, scales shape {scales.shape}")
+            return lin * scales
         # Use int8pack_mm for CPU eager
         return torch.ops.aten._weight_int8pack_mm(
             input.reshape(-1, input.shape[-1]),
@@ -42,6 +45,55 @@ def linear_int8(input, weight, scales):
     )
 
 
+def _qdq_dynamic_quantized_linear(
+    x_fp32, x_quant_min, x_quant_max, x_eps,
+    weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max,
+    bias_fp32,
+):
+    x_scale, x_zero_point = torch.ops.quantized_decomposed.choose_qparams(x_fp32, x_quant_min, x_quant_max, x_eps, torch.int8)
+    x_i8 = torch.ops.quantized_decomposed.quantize_per_tensor(
+        x_fp32, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    x_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        x_i8, x_scale, x_zero_point, x_quant_min, x_quant_max, torch.int8)
+    weight_fp32 = torch.ops.quantized_decomposed.dequantize_per_tensor(
+        weight_i8, weight_scale, weight_zero_point, weight_quant_min, weight_quant_max, torch.int8)
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+def linear_int8_et(input, weight, scales):
+    n_groups = scales.numel() // scales.shape[0]
+
+    # we special-case channel-wise, because we know how to make that fast
+    if n_groups == 1:
+        scales = scales.view(-1)
+
+        if True:
+            lin = F.linear(input, weight.to(dtype=input.dtype))
+            # print(f"linear shape {lin.shape}, scales shape {scales.shape}")
+            return lin * scales
+
+        return _qdq_dynamic_quantized_linear(
+            x_fp32=input.float(),
+            x_quant_min=-128,
+            x_quant_max=127,
+            x_eps=torch.finfo(input.dtype).eps,
+            weight_i8=weight,
+            weight_scale=scales.float(),
+            weight_zero_point=0,
+            weight_quant_min=-128,
+            weight_quant_max=127,
+            bias_fp32=None,
+        ).to(dtype=input.dtype)
+
+    return F.linear(
+        input,
+        (
+            weight.to(dtype=input.dtype).view(weight.shape[0], n_groups, -1)
+            * scales.view(weight.shape[0], n_groups, -1)
+        ).view(weight.shape[0], -1),
+    )
+
+
 class LinearInt8(nn.Module):
     __constants__ = ["in_features", "out_features"]
     in_features: int
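linear_int8_et currently short-circuits to the plain scaled F.linear path (note the if True), with the quantize-dequantize route parked behind it. For readers who want the math the quantized_decomposed ops encode, here is a hedged plain-PyTorch sketch of dynamic per-tensor activation quantization followed by a dequantized matmul; it illustrates the pattern, not the decomposed ops themselves:

import torch
import torch.nn.functional as F

def qdq_linear_sketch(x_fp32: torch.Tensor,
                      weight_i8: torch.Tensor,
                      weight_scale: torch.Tensor) -> torch.Tensor:
    qmin, qmax = -128, 127
    # Dynamic activation quantization: derive scale/zero-point from the
    # observed range, quantize to int8, then immediately dequantize.
    x_min, x_max = x_fp32.min(), x_fp32.max()
    x_scale = ((x_max - x_min) / (qmax - qmin)).clamp(min=torch.finfo(x_fp32.dtype).eps)
    x_zp = torch.clamp(qmin - torch.round(x_min / x_scale), qmin, qmax)
    x_i8 = torch.clamp(torch.round(x_fp32 / x_scale) + x_zp, qmin, qmax)
    x_dq = (x_i8 - x_zp) * x_scale
    # Weight is already int8 with per-output-channel scales and zero-point 0.
    w_dq = weight_i8.to(torch.float32) * weight_scale.view(-1, 1)
    return F.linear(x_dq, w_dq)

x = torch.randn(2, 8)
w = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
print(qdq_linear_sketch(x, w, torch.full((4,), 0.05)).shape)  # torch.Size([2, 4])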
@@ -68,17 +120,14 @@ def __init__(
         if device is None:
             device = "cpu"
 
-        if device == "einputecutorch":
-            device = "cpu"
-
         assert not bias, "Bias is not supported by LinearInt8"
         self.in_features = in_features
         self.out_features = out_features
 
-        assert bool(weight) == bool(
-            scales
+        assert (weight is None) == bool(
+            scales is None
         ), "must specify both weights and scales, or neither"
-        if not weight:
+        if weight is None:
             weight = torch.empty(
                 (out_features, in_features), dtype=torch.int8, device=device
             )
@@ -91,8 +140,16 @@ def __init__(
         self.register_buffer("weight", weight.to(device))
         self.register_buffer("scales", scales.to(device))
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        return linear_int8(input, self.weight, self.scales)
+        if use_et_backend():
+            self.forward = self.et_forward
+        else:
+            self.forward = self.aoti_forward
+
+    def aoti_forward(self, input: torch.Tensor) -> torch.Tensor:
+        return linear_int8_aoti(input, self.weight, self.scales)
+
+    def et_forward(self, input: torch.Tensor) -> torch.Tensor:
+        return linear_int8_et(input, self.weight, self.scales)
 
 
 class QuantizedEmbedding(torch.nn.Module):
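LinearInt8 now picks its forward implementation once, at construction time, based on use_et_backend(), instead of branching on every call. A minimal sketch of that dispatch pattern, with a plain boolean flag standing in for the backend query; the class and its math are illustrative only:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DispatchingLinear(nn.Module):
    # Picks one of two forward implementations at construction, the way
    # LinearInt8 now selects et_forward or aoti_forward.
    def __init__(self, in_features, out_features, use_et_backend: bool = False):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.forward = self.et_forward if use_et_backend else self.aoti_forward

    def aoti_forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)

    def et_forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same math here; the real class routes this path to an ExecuTorch-friendly kernel.
        return x @ self.weight.t()

layer = DispatchingLinear(16, 32, use_et_backend=True)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 32])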

quantize.py

Lines changed: 36 additions & 69 deletions
@@ -93,7 +93,9 @@ def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
         self.model_ = model
         self.device = device
         self.tokenizer = tokenizer
-        self.quantizer = quant_api.Int8DynActInt4WeightQuantizer(**kwargs)
+        self.quantizer = quant_api.Int8DynActInt4WeightQuantizer(
+            **kwargs, precision=get_precision(), scales_precision=get_precision()
+        )
 
     def create_quantized_state_dict(self) -> Dict:  # "StateDict"
         pass
@@ -362,39 +364,6 @@ def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
 ##### Weight-only int8 per-channel quantized code ######
 
 
-def replace_linear_weight_only_int8_per_channel(
-    module, device, node_type, groupsize=None
-):
-    if groupsize is not None and groupsize != 0:
-        pass
-
-    for name, child in module.named_children():
-        # print(f"name: {name}")
-        if isinstance(child, nn.Linear):
-            if (
-                (node_type == "*")
-                or (node_type == "output" and name == "output")
-                or (node_type == "!output" and name != "output")
-            ):
-                # print(f"{name, child}")
-                # print(f"in_features: {child.in_features}")
-                # print(f"out_features: {child.out_features}")
-                setattr(
-                    module,
-                    name,
-                    WeightOnlyInt8Linear(
-                        in_features=child.in_features,
-                        out_features=child.out_features,
-                        device=device,
-                        groupsize=groupsize,
-                    ),
-                )
-        else:
-            replace_linear_weight_only_int8_per_channel(
-                child, device, node_type, groupsize
-            )
-
-
 class WeightOnlyInt8QuantHandler(QuantHandler):
     def __init__(
         self,
@@ -416,9 +385,11 @@ def __init__(
         self.bitwidth = bitwidth
 
     @torch.no_grad()
-    def create_quantized_state_dict(self) -> Dict:
-        cur_state_dict = state_dict_device(self.model_.state_dict())
-        dict_device = "cpu"  # self.device
+    def quantize(self, module):
+        # cur_state_dict = state_dict_device(self.model_.state_dict())
+        # dict_device = "cpu"  # self.device
+
+        device = self.device
 
         if self.bitwidth == 4:
             range_min = -8
@@ -429,24 +400,19 @@ def create_quantized_state_dict(self) -> Dict:
         else:
             raise ValueError(f"Unsupported bitwidth {self.bitwidth}")
 
-        for fqn, mod in self.model_.named_modules():
-            # print(f"maybe? quantize {fqn}...{type(mod)}")
-            if isinstance(mod, torch.nn.Linear):
-                # print(f"candidate {fqn}, nodetype {self.node_type}")
+        for name, child in module.named_children():
+            # print(f"name: {name}")
+            if isinstance(child, nn.Linear):
                 if (
                     (self.node_type == "*")
-                    or (self.node_type == "output" and fqn in ["output", "final_proj"])
-                    or (
-                        self.node_type == "!output"
-                        and fqn not in ["output", "final_proj"]
-                    )
+                    or (self.node_type == "output" and name == "output")
+                    or (self.node_type == "!output" and name != "output")
                 ):
-                    # print(
-                    #     f"quantize {self.node_type} {fqn, mod} with groupsize {self.groupsize}, bitwidth {self.bitwidth}"
-                    # )
-
-                    # print(f"initial weight shape {mod.weight.shape}")
-                    input_weight = mod.weight.float()
+                    # print(f"{name, child}")
+                    input_weight = child.weight.float()
+                    # print(f"{name, child}")
+                    # print(f"in_features: {child.in_features}")
+                    # print(f"out_features: {child.out_features}")
 
                     # print(f"expanded weight shape {input_weight.shape}")
                     weight, scales, _ = dynamically_quantize_per_channel(
@@ -455,28 +421,29 @@ def create_quantized_state_dict(self) -> Dict:
                         range_max,
                         torch.int8,
                         self.groupsize,
-                        scales_dtype=mod.weight.dtype,
+                        scales_dtype=child.weight.dtype,
                     )
 
-                    weight = weight.to(device=dict_device)
-                    scales = scales.to(device=dict_device)
-                    cur_state_dict[f"{fqn}.weight"] = weight
-                    # squeeze makes groupsize=rowsize unidimensional
-                    cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)
-
-        return cur_state_dict
+                    setattr(
+                        module,
+                        name,
+                        WeightOnlyInt8Linear(
+                            in_features=child.in_features,
+                            out_features=child.out_features,
+                            device=self.device,
+                            # update variables from quantization
+                            weight=weight,
+                            scales=scales,
+                            groupsize=self.groupsize,
+                        ),
+                    )
+            else:
+                self.quantize(module)
 
-    def convert_for_runtime(self) -> nn.Module:
-        replace_linear_weight_only_int8_per_channel(
-            self.model_, self.device, self.node_type, self.groupsize
-        )
-        return self.model_
+        return module
 
     def quantized_model(self) -> nn.Module:
-        model_updated_state_dict = self.create_quantized_state_dict()
-        self.convert_for_runtime()
-        self.model_.load_state_dict(model_updated_state_dict)
-        return self.model_
+        return self.quantize(self.model_)
 
 
 #########################################################################
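The handler now quantizes and rewrites the model in a single recursive pass: each nn.Linear child is replaced by a WeightOnlyInt8Linear that already carries its quantized weight and scales, rather than first materializing a quantized state dict and loading it back afterwards, which is what the commit title refers to. A minimal sketch of that rewrite pattern, with a simplified stand-in for the replacement module and a naive symmetric per-channel quantizer; the sketch recurses into each non-Linear child so nested transformer blocks are reached:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Int8LinearSketch(nn.Module):
    # Simplified stand-in for WeightOnlyInt8Linear: int8 weight + per-channel scales.
    def __init__(self, weight_i8: torch.Tensor, scales: torch.Tensor):
        super().__init__()
        self.register_buffer("weight", weight_i8)
        self.register_buffer("scales", scales)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight.to(x.dtype)) * self.scales

def quantize_in_place(module: nn.Module) -> nn.Module:
    # Walk direct children: swap Linear layers, recurse into everything else.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            w = child.weight.detach().float()
            scales = (w.abs().amax(dim=1) / 127.0).clamp(min=1e-8)  # per output channel
            w_i8 = torch.clamp(torch.round(w / scales.unsqueeze(1)), -128, 127).to(torch.int8)
            setattr(module, name, Int8LinearSketch(w_i8, scales))  # bias ignored for brevity
        else:
            quantize_in_place(child)
    return module

model = nn.Sequential(nn.Linear(16, 32, bias=False), nn.ReLU(), nn.Linear(32, 4, bias=False))
quantize_in_place(model)
print(model(torch.randn(2, 16)).shape)  # torch.Size([2, 4])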
