Bump torchao pin, adjust llama export to support pre-quantization via quantize_ (phi4-mini load/export) #10142

Merged: 9 commits, Apr 17, 2025
2 changes: 1 addition & 1 deletion backends/xnnpack/operators/op_dynamic_dequantize_ops.py
@@ -78,7 +78,7 @@ def define_node(

@register_node_visitor
class OpDequantizeAffine(NodeVisitor):
target = "quant.dequantize_affine.default"
target = "torchao.dequantize_affine.default"

def __init__(self, *args) -> None:
super().__init__(*args)
2 changes: 1 addition & 1 deletion backends/xnnpack/operators/op_dynamic_quantize_ops.py
@@ -127,7 +127,7 @@ def define_node(

@register_node_visitor
class OpQuantizeAffine(NodeVisitor):
target = "quant.quantize_affine.default"
target = "torchao.quantize_affine.default"

def define_node(
self,
2 changes: 1 addition & 1 deletion backends/xnnpack/operators/op_skip_ops.py
@@ -85,7 +85,7 @@ class OpChooseQparamsAffine(OpSkipOps):
do nothing if node is choose_qparams_affine.default
"""

target = "quant.choose_qparams_affine.default"
target = "torchao.choose_qparams_affine.default"


@register_node_visitor
6 changes: 3 additions & 3 deletions backends/xnnpack/partition/config/quant_affine_configs.py
@@ -36,7 +36,7 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
try:
import torchao.quantization.quant_primitives # noqa

return torch.ops.quant.quantize_affine.default
return torch.ops.torchao.quantize_affine.default
except:
return None

@@ -48,7 +48,7 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
try:
import torchao.quantization.quant_primitives # noqa

return torch.ops.quant.dequantize_affine.default
return torch.ops.torchao.dequantize_affine.default
except:
return None

@@ -60,6 +60,6 @@ def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
try:
import torchao.quantization.quant_primitives # noqa

return torch.ops.quant.choose_qparams_affine.default
return torch.ops.torchao.choose_qparams_affine.default
except:
return None
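
For context, a minimal sketch (not part of the diff) of the guarded-lookup pattern these configs rely on after the rename: the affine quantization ops now live under the torch.ops.torchao namespace and are only registered once torchao's quant primitives have been imported. The helper name below is illustrative.

```
from typing import Optional

import torch


def get_torchao_quantize_affine() -> Optional[torch._ops.OpOverload]:
    # The op registers under torch.ops.torchao only after
    # torchao.quantization.quant_primitives has been imported.
    try:
        import torchao.quantization.quant_primitives  # noqa: F401

        return torch.ops.torchao.quantize_affine.default
    except Exception:
        return None
```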
6 changes: 3 additions & 3 deletions backends/xnnpack/test/ops/test_linear.py
@@ -402,9 +402,9 @@ def _test_groupwise_dq_linear(
.export()
.check_count(
{
"torch.ops.quant.choose_qparams_affine.default": 1 * num_linears,
"torch.ops.quant.quantize_affine.default": 1 * num_linears,
"torch.ops.quant.dequantize_affine.default": 2 * num_linears,
"torch.ops.torchao.choose_qparams_affine.default": 1 * num_linears,
"torch.ops.torchao.quantize_affine.default": 1 * num_linears,
"torch.ops.torchao.dequantize_affine.default": 2 * num_linears,
"torch.ops.aten.linear.default": 1 * num_linears,
}
)
2 changes: 1 addition & 1 deletion examples/models/llama/README.md
@@ -416,7 +416,7 @@ python -m examples.models.llama.export_llama \
```

A few notes:
- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized with weight zeros or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and uses weight zeros (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32, but is quantized with scales-only. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
- If your model shares embedding/unembedding weights (like Llama1B and Llama3B do), you can add `--use_shared_embedding` to take advantage of this and reduce memory. When this option is enabled, you can specify whether embeddings are quantized asymmetrically or not by specifying a third argument. For example, `-E "torchao:4,32,true"` means that the embedding is quantized to 4-bits with group_size=32 and is asymmetric (this is the default behavior if you simply use `-E "torchao:4,32"`), whereas `-E "torchao:4,32,false"` means that the embedding is quantized to 4-bits with group_size=32 and is symmetric. If `--use_shared_embedding` is specified, the unembedding (i.e., the final linear layer) is quantized in the same way, but also uses 8-bit dynamically quantized activations.
- To do channelwise quantization, specify group_size to 0. This works for both linear and embedding layers.

Once the model is exported, we need to build ExecuTorch and the runner with the low-bit kernels.
41 changes: 29 additions & 12 deletions examples/models/llama/source_transformation/quantize.py
@@ -107,14 +107,24 @@ def quantize( # noqa C901
print("quantized model:", model)
return model
elif qmode.startswith("torchao:8da"):
# Check for required args
if group_size is None:
raise Exception(
"For torchao:8daxw quantization, group size must be specified."
)

pattern = r"torchao:8da(\d+)w"
matches = re.findall(pattern, qmode)
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
bitwidth = int(matches[0][0])

from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.quant_api import quantize_
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
MappingType,
quantize_,
)
from torchao.utils import unwrap_tensor_subclass

with torch.no_grad():
@@ -124,8 +134,11 @@
model,
Int8DynamicActivationIntxWeightConfig(
weight_dtype=getattr(torch, f"int{bitwidth}"),
granularity=(PerRow() if group_size == 0 else PerGroup(group_size)),
has_weight_zeros=False,
weight_granularity=(
PerAxis(0) if group_size == 0 else PerGroup(group_size)
),
weight_mapping_type=MappingType.SYMMETRIC,
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
),
)
model = unwrap_tensor_subclass(model)
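
For reference, a self-contained sketch of the quantize_ flow this hunk switches to. The toy Linear module and the int4/group-size-32 settings are illustrative, and the sketch assumes the torchao APIs imported above are available at the bumped pin.

```
import torch
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)
from torchao.utils import unwrap_tensor_subclass

# Toy stand-in for the exported llama module.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False))
group_size = 32  # 0 selects per-axis (channelwise) weight granularity instead

quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=(
            PerAxis(0) if group_size == 0 else PerGroup(group_size)
        ),
        weight_mapping_type=MappingType.SYMMETRIC,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    ),
)
model = unwrap_tensor_subclass(model)
```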
@@ -777,38 +790,42 @@ def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
EmbeddingQuantizer,
SharedEmbeddingQuantizer,
)
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import MappingType

quant_args = args.embedding_quantize.split(":")[1].split(",")
if len(quant_args) == 2:
bitwidth, group_size = quant_args
has_weight_zeros = True
is_asymmetric = True
else:
bitwidth, group_size, has_weight_zeros = quant_args
bitwidth, group_size, is_asymmetric = quant_args

if group_size in ["none", "None", "0"]:
group_size = 0

group_size = int(group_size)
bitwidth = int(bitwidth)
has_weight_zeros = bool(has_weight_zeros)
is_asymmetric = bool(is_asymmetric)
weight_dtype = getattr(torch, f"int{bitwidth}")
granularity = PerRow() if group_size == 0 else PerGroup(group_size)
granularity = PerAxis(0) if group_size == 0 else PerGroup(group_size)
mapping_type = (
MappingType.ASYMMETRIC if is_asymmetric else MappingType.SYMMETRIC
)

def _torchao_embedding_quantizer(model):
with torch.no_grad():
if not args.use_shared_embedding:
EmbeddingQuantizer(
weight_dtype=weight_dtype,
granularity=granularity,
has_weight_zeros=has_weight_zeros,
mapping_type=mapping_type,
use_fallback=False,
).quantize(model)
else:
SharedEmbeddingQuantizer(
weight_dtype=weight_dtype,
granularity=granularity,
has_weight_zeros=has_weight_zeros,
mapping_type=mapping_type,
).quantize(model)
return model
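
As a usage illustration (not from the diff), here is how a spec like `-E "torchao:4,32,false"` maps onto the EmbeddingQuantizer arguments above; the torchao.experimental.quant_api import path is an assumption, matching the quantizer names used in this file.

```
import torch
from torchao.experimental.quant_api import EmbeddingQuantizer  # assumed import path
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import MappingType

# Toy embedding standing in for the model's token embedding table.
model = torch.nn.Sequential(torch.nn.Embedding(1000, 64))

EmbeddingQuantizer(
    weight_dtype=torch.int4,             # "4"
    granularity=PerGroup(32),            # "32"; 0 would mean PerAxis(0), i.e. channelwise
    mapping_type=MappingType.SYMMETRIC,  # "false" -> symmetric, "true" -> asymmetric
    use_fallback=False,
).quantize(model)
```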

98 changes: 82 additions & 16 deletions examples/models/phi_4_mini/convert_weights.py
@@ -1,4 +1,5 @@
import argparse
import os
from typing import Dict

import torch
@@ -7,6 +8,63 @@

from torchtune.training import FullModelHFCheckpointer

_HF_PHI_4_FROM_META = {
"tok_embeddings.weight": "model.embed_tokens.weight",
"norm.weight": "model.norm.weight",
"layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
"layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
"layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
"layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
"layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
"layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
"layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
"layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
"layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
"output.weight": "lm_head.weight",
}


def phi_4_hf_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert a state dict from hf's format to Meta's format.

Args:
state_dict (Dict[str, torch.Tensor]): State dict in hf's format.

Returns:
Dict[str, torch.Tensor]: State dict in Meta's format.
"""
converted_state_dict = {}
inverted_mapping_dict = {v: k for k, v in _HF_PHI_4_FROM_META.items()}

for key, value in state_dict.items():
if key.endswith("mlp.gate_up_proj.weight"):
# Split the gate_up_proj into gate_proj and up_proj
hidden_dim = value.shape[0] // 2
assert 2 * hidden_dim == value.shape[0]
gate = value[0:hidden_dim, :]
up = value[hidden_dim:, :]
for new_key, new_value in [("gate_proj", gate), ("up_proj", up)]:
new_key = key.replace("gate_up_proj", new_key)
new_key = get_mapped_key(new_key, inverted_mapping_dict)
converted_state_dict[new_key] = new_value
elif key.endswith("self_attn.qkv_proj.weight"):
# Split the qkv_proj into q_proj, k_proj, and v_proj
q_dim = value.shape[1]
kv_dim = (value.shape[0] - q_dim) // 2
assert 2 * kv_dim + q_dim == value.shape[0]
q = value[0:q_dim, :]
k = value[q_dim : (q_dim + kv_dim), :]
v = value[(q_dim + kv_dim) :, :]
for new_key, new_value in [("q_proj", q), ("k_proj", k), ("v_proj", v)]:
new_key = key.replace("qkv_proj", new_key)
new_key = get_mapped_key(new_key, inverted_mapping_dict)
converted_state_dict[new_key] = new_value
else:
new_key = get_mapped_key(key, inverted_mapping_dict)
converted_state_dict[new_key] = value
return converted_state_dict
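
A quick sanity check of the fused-projection split above, using synthetic shapes rather than real phi-4-mini dimensions. Note that the conversion takes q_dim from value.shape[1], i.e. it assumes the query projection is square.

```
import torch

hidden, q_dim, kv_dim = 64, 64, 16
qkv = torch.randn(q_dim + 2 * kv_dim, hidden)  # fused qkv_proj.weight layout

q = qkv[0:q_dim, :]
k = qkv[q_dim : q_dim + kv_dim, :]
v = qkv[q_dim + kv_dim :, :]
assert q.shape == (q_dim, hidden)
assert k.shape == v.shape == (kv_dim, hidden)
```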


# Standard _FROM_META weight mapping of Meta weights to TorchTune.
_PHI_4_FROM_META = {
@@ -51,22 +109,30 @@ def phi_4_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.T
return converted_state_dict


def convert_weights(input_dir: str, output_file: str) -> None:
# Don't necessarily need to use TorchTune checkpointer, can just aggregate checkpoint files by ourselves.
checkpointer = FullModelHFCheckpointer(
checkpoint_dir=input_dir,
checkpoint_files=[
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
],
output_dir=".",
model_type="PHI4",
)
def convert_weights(input_dir_or_checkpoint: str, output_file: str) -> None:
# If input_dir_or_checkpoint is a directory downloaded from HF, FullModelHFCheckpointer is used to extract the state dict
# If input_dir_or_checkpoint is a checkpoint (from an eager model), it is loaded directly
if os.path.isdir(input_dir_or_checkpoint):
Reviewer comment (Contributor): Can we add a comment somewhere explicitly detailing that:

  1. FullModelHFCheckpointer is for a directory (which would be straight from HF)
  2. phi_4_hf_to_meta is used for a single checkpoint, and the use case is for prequantized checkpoints

checkpointer = FullModelHFCheckpointer(
checkpoint_dir=input_dir_or_checkpoint,
checkpoint_files=[
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
],
output_dir=".",
model_type="PHI4",
)
print("Loading checkpoint from directory...")
sd = checkpointer.load_checkpoint()
sd = sd["model"]
print("Converting checkpoint...")
sd = phi_4_tune_to_meta(sd)
else:
print("Loading checkpoint from file...")
sd = torch.load(input_dir_or_checkpoint, map_location="cpu", weights_only=True)
print("Converting checkpoint...")
sd = phi_4_hf_to_meta(sd)

print("Loading checkpoint...")
sd = checkpointer.load_checkpoint()
print("Converting checkpoint...")
sd = phi_4_tune_to_meta(sd["model"])
print("Saving checkpoint...")
torch.save(sd, output_file)
print("Done.")
@@ -79,7 +145,7 @@ def main():
parser.add_argument(
"input_dir",
type=str,
help="Path to directory containing checkpoint files",
help="Path to directory containing checkpoint files, or path to a single checkpoint file.",
)
parser.add_argument("output", type=str, help="Path to the output checkpoint")

2 changes: 1 addition & 1 deletion exir/program/_program.py
@@ -986,7 +986,7 @@ def keep(op):
try:
# Ops in torch.ops.quant are not always loaded, so we use try/except
# Aliases output, but we need to allow it for XNNPACK
allow_list.append(torch.ops.quant.choose_qparams_affine.default)
allow_list.append(torch.ops.torchao.choose_qparams_affine.default)
except:
pass

2 changes: 1 addition & 1 deletion third-party/ao
Submodule ao updated 103 files