Commit ff6207f ("up")
Parent: 7c77a5b
6 files changed: +39 -39 lines changed


backends/xnnpack/operators/quant_params.py

Lines changed: 3 additions & 9 deletions

@@ -185,13 +185,11 @@ def from_q_dq_node(
         quant_node_args = extract_qdq_affine_op_args_for_decomposed_ops(quant_node)

         scale = quant_node_args[1]
-        zp = quant_node_args[2] if len(quant_node_args) > 2 else None
+        zp = quant_node_args[2]
         axis = 0
         if per_channel:
             assert isinstance(scale, torch.fx.Node) and isinstance(scale.target, str)
-            assert zp is None or (
-                isinstance(zp, torch.fx.Node) and isinstance(zp.target, str)
-            )
+            assert isinstance(zp, torch.fx.Node) and isinstance(zp.target, str)
             assert (
                 ep is not None
             ), "ExportedProgram must be provided to extract per channel params"
@@ -202,11 +200,7 @@ def _get_tensor(node):
                 return cast(torch.Tensor, param)

             scale = _get_tensor(scale)
-            zp = (
-                _get_tensor(zp)
-                if zp is not None
-                else torch.zeros_like(scale, dtype=torch.int8)
-            )
+            zp = _get_tensor(zp)
             axis = cast(int, quant_node_args[3])

             if _groupwise:
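In effect, the per-channel q/dq path now requires an explicit zero-point argument instead of synthesizing a zeros tensor. A minimal sketch of the resulting extraction logic, where `get_tensor` stands in for the file's `_get_tensor` helper (the helper name and wrapper function below are illustrative, not code from the repo):

```python
# Sketch of the per-channel path after this change: scale, zero point, and axis
# are all mandatory. `get_tensor` stands in for the file's _get_tensor helper.
import torch


def extract_per_channel_qparams(quant_node_args, get_tensor):
    scale_node = quant_node_args[1]
    zp_node = quant_node_args[2]  # previously optional, defaulting to a zeros tensor
    assert isinstance(scale_node, torch.fx.Node) and isinstance(scale_node.target, str)
    assert isinstance(zp_node, torch.fx.Node) and isinstance(zp_node.target, str)
    scale = get_tensor(scale_node)  # resolved from the ExportedProgram parameters
    zp = get_tensor(zp_node)
    axis = int(quant_node_args[3])
    return scale, zp, axis
```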

backends/xnnpack/utils/quant_utils.py

Lines changed: 9 additions & 15 deletions

@@ -58,24 +58,18 @@ def is_dynamic_qdq(node: torch.fx.Node) -> bool:
     node_input_args = extract_qdq_affine_op_args_for_decomposed_ops(node)

     scale = node_input_args[1]
-    if not isinstance(scale, torch.fx.Node):
+    zp = node_input_args[2]
+    if not (isinstance(scale, torch.fx.Node) and isinstance(zp, torch.fx.Node)):
         return False
-    if not (scale.target == operator.getitem):
-        return False
-    scale_choose_qparam = scale.all_input_nodes[0]
-    if not is_qparam(scale_choose_qparam):
+
+    if not (scale.target == operator.getitem and zp.target == operator.getitem):
         return False

-    if len(node_input_args) > 2:
-        zp = node_input_args[2]
-        if not isinstance(zp, torch.fx.Node):
-            return False
+    scale_choose_qparam = scale.all_input_nodes[0]
+    zp_choose_qparam = zp.all_input_nodes[0]

-        if not (zp.target == operator.getitem):
-            return False
-        zp_choose_qparam = zp.all_input_nodes[0]
-        if not is_qparam(zp_choose_qparam):
-            return False
+    if not (is_qparam(scale_choose_qparam) and is_qparam(zp_choose_qparam)):
+        return False

     return True

@@ -229,7 +223,7 @@ def extract_qdq_affine_op_args_for_decomposed_ops(node: torch.fx.Node):
     # add target_dtype_node after quant_min/quant_max
     args.append(target_dtype)
     # zero_point_domain
-    if len(node.args) > 7 and node.args[7] not in ["INT", "NONE"]:
+    if len(node.args) > 7 and node.args[7] != "INT":
         return None, None

     if is_per_channel_group(node):
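The streamlined check encodes one graph pattern: a dynamic q/dq node must read both its scale and its zero point via `operator.getitem` from a `choose_qparams` node. A hedged summary of that condition as a single predicate, with `is_qparam` standing in for the helper defined earlier in this file (the wrapper function name is illustrative):

```python
# Summary of the pattern the updated is_dynamic_qdq expects: both qparam inputs
# are getitem results over a choose_qparams node (`is_qparam` is the file's helper).
import operator

import torch


def qparams_come_from_choose_qparams(scale, zp, is_qparam) -> bool:
    if not (isinstance(scale, torch.fx.Node) and isinstance(zp, torch.fx.Node)):
        return False
    if not (scale.target == operator.getitem and zp.target == operator.getitem):
        return False
    return is_qparam(scale.all_input_nodes[0]) and is_qparam(zp.all_input_nodes[0])
```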

examples/models/llama/README.md

Lines changed: 1 addition & 1 deletion

@@ -416,7 +416,7 @@ python -m examples.models.llama.export_llama \
 ```

 A few notes:
-- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized with weight zeros or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and uses weight zeros (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32, but is quantized with scales-only. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
+- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized asymmetrically or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
 - To do channelwise quantization, specify group_size to 0. This works for both linear and embedding layers.

 Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
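To illustrate the reworded note: the third argument now selects torchao's mapping type (asymmetric vs. symmetric). A simplified sketch of how a spec like `torchao:4,32,false` maps onto torchao types, mirroring the `get_quant_embedding_transform` change later in this commit; the parsing helper itself is illustrative, not the exact code in the repo:

```python
# Illustrative mapping of the `-E "torchao:<bitwidth>,<group_size>[,<asymmetric>]"`
# spec onto torchao types; asymmetric == scales + zeros, symmetric == scales-only.
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import MappingType


def parse_embedding_spec(spec: str):
    bitwidth, group_size, *rest = spec.split(":")[1].split(",")
    is_asymmetric = rest[0].lower() == "true" if rest else True  # default: asymmetric
    group_size = int(group_size)
    granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
    mapping_type = MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
    return int(bitwidth), granularity, mapping_type


# Example: parse_embedding_spec("torchao:4,32,false") -> (4, PerGroup(32), SYMMETRIC)
```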

examples/models/llama/source_transformation/quantize.py

Lines changed: 23 additions & 12 deletions

@@ -112,9 +112,13 @@ def quantize( # noqa C901
         assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
         bitwidth = int(matches[0][0])

-        from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig
-        from torchao.quantization.granularity import PerGroup, PerRow
-        from torchao.quantization.quant_api import quantize_
+        from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
+        from torchao.quantization.granularity import PerAxis, PerGroup
+        from torchao.quantization.quant_api import (
+            Int8DynamicActivationIntxWeightConfig,
+            MappingType,
+            quantize_,
+        )
         from torchao.utils import unwrap_tensor_subclass

         with torch.no_grad():
@@ -124,8 +128,11 @@ def quantize( # noqa C901
                 model,
                 Int8DynamicActivationIntxWeightConfig(
                     weight_dtype=getattr(torch, f"int{bitwidth}"),
-                    granularity=(PerRow() if group_size == 0 else PerGroup(group_size)),
-                    has_weight_zeros=False,
+                    granularity=(
+                        PerAxis(0) if group_size == 0 else PerGroup(group_size)
+                    ),
+                    mapping_type=MappingType.SYMMETRIC,
+                    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
                 ),
             )
             model = unwrap_tensor_subclass(model)
@@ -777,38 +784,42 @@ def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
         EmbeddingQuantizer,
         SharedEmbeddingQuantizer,
     )
-    from torchao.quantization.granularity import PerGroup, PerRow
+    from torchao.quantization.granularity import PerAxis, PerGroup
+    from torchao.quantization.quant_api import MappingType

     quant_args = args.embedding_quantize.split(":")[1].split(",")
     if len(quant_args) == 2:
         bitwidth, group_size = quant_args
-        has_weight_zeros = True
+        is_asymmetric = True
     else:
-        bitwidth, group_size, has_weight_zeros = quant_args
+        bitwidth, group_size, is_asymmetric = quant_args

     if group_size in ["none", "None", "0"]:
         group_size = 0

     group_size = int(group_size)
     bitwidth = int(bitwidth)
-    has_weight_zeros = bool(has_weight_zeros)
+    is_asymmetric = bool(is_asymmetric)
     weight_dtype = getattr(torch, f"int{bitwidth}")
-    granularity = PerRow() if group_size == 0 else PerGroup(group_size)
+    granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
+    mapping_type = (
+        MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
+    )

     def _torchao_embedding_quantizer(model):
         with torch.no_grad():
             if not args.use_shared_embedding:
                 EmbeddingQuantizer(
                     weight_dtype=weight_dtype,
                     granularity=granularity,
-                    has_weight_zeros=has_weight_zeros,
+                    mapping_type=mapping_type,
                     use_fallback=False,
                 ).quantize(model)
             else:
                 SharedEmbeddingQuantizer(
                     weight_dtype=weight_dtype,
                     granularity=granularity,
-                    has_weight_zeros=has_weight_zeros,
+                    mapping_type=mapping_type,
                 ).quantize(model)
         return model
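For reference, a minimal usage sketch of the config-style torchao API this file migrates to. The toy `nn.Sequential` module is an assumption for illustration; the keyword arguments follow the hunks above:

```python
# Minimal sketch of the new config-style torchao API used above. The toy model
# is illustrative; kwargs mirror the diff for dynamically quantized linear layers.
import torch
import torch.nn as nn
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)

model = nn.Sequential(nn.Linear(256, 256))
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,                      # e.g. 4-bit weights
        granularity=PerGroup(32),                     # or PerAxis(0) for channelwise
        mapping_type=MappingType.SYMMETRIC,           # scales-only weight quantization
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    ),
)
```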

examples/models/phi_4_mini/convert_weights.py

Lines changed: 2 additions & 1 deletion

@@ -110,7 +110,8 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:


 def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
-    # Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
+    # If input_dir_or_checkpoint is a directory downloaded from HF, FullModelHFCheckpointer is used to extract the state dict.
+    # If input_dir_or_checkpoint is a checkpoint (from the eager model), it is loaded directly.
     if os.path.isdir(input_dir_or_checkpoint):
         checkpointer = FullModelHFCheckpointer(
             checkpoint_dir=input_dir_or_checkpoint,
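A hedged sketch of the branching the updated comment describes. The HF-directory branch delegates to torchtune's FullModelHFCheckpointer (as in the file) and is abstracted behind a callable here so the sketch stays self-contained; the `torch.load` arguments in the file-path branch are assumptions:

```python
# Hedged sketch of the two load paths in the new comment; function name, callable
# parameter, and torch.load kwargs are illustrative assumptions.
import os
from typing import Callable, Dict

import torch


def load_phi4_state_dict(
    input_dir_or_checkpoint: str,
    load_from_hf_dir: Callable[[str], Dict[str, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    if os.path.isdir(input_dir_or_checkpoint):
        # Directory downloaded from Hugging Face: let the checkpointer assemble the state dict.
        return load_from_hf_dir(input_dir_or_checkpoint)
    # Single checkpoint file saved from the eager model: load it directly.
    return torch.load(input_dir_or_checkpoint, map_location="cpu", weights_only=True)
```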

third-party/ao

Submodule ao updated 99 files
