Refactor quantize.py functions to remove args #10893

Merged: 1 commit, May 15, 2025

49 changes: 12 additions & 37 deletions examples/models/llama/export_llama_lib.py
@@ -1341,23 +1341,11 @@ def _get_source_transforms( # noqa
transformations based on the given checkpoint first. In those cases,
this will be a no-op.
"""

# Create a mock args object with the necessary attributes
class Args:
pass

args = Args()
args.checkpoint = checkpoint
args.tokenizer_path = tokenizer_path
args.embedding_quantize = embedding_quantize
args.use_shared_embedding = use_shared_embedding
args.use_qat = use_qat
args.use_lora = use_lora
args.preq_mode = preq_mode
args.preq_group_size = preq_group_size
args.preq_embedding_quantize = preq_embedding_quantize

transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
transforms.append(
get_quant_embedding_transform(
embedding_quantize, use_shared_embedding, checkpoint_dtype
)
)

# quantization_mode should be applied after embedding_quantize
# to support shared_embedding
@@ -1375,30 +1363,17 @@ class Args:
There are cases where this may be a no-op, namely, if all linears are
quantized in the checkpoint.
"""

# Create a mock args object with the necessary attributes
class Args:
pass

args = Args()
args.checkpoint = checkpoint
args.tokenizer_path = tokenizer_path
args.quantization_mode = quantization_mode
args.group_size = group_size
args.use_shared_embedding = use_shared_embedding
args.calibration_tasks = calibration_tasks
args.calibration_limit = calibration_limit
args.calibration_seq_length = calibration_seq_length
args.use_shared_embedding = use_shared_embedding
args.use_qat = use_qat
args.use_lora = use_lora
args.preq_mode = preq_mode

transforms.append(
get_quant_weight_transform(
args=args,
quantization_mode=quantization_mode,
group_size=group_size,
computation_dtype=dtype_override,
checkpoint_dtype=checkpoint_dtype,
checkpoint_path=checkpoint,
tokenizer_path=tokenizer_path,
calibration_tasks=calibration_tasks,
calibration_limit=calibration_limit,
calibration_seq_length=calibration_seq_length,
)
)

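Condensed, the refactored _get_source_transforms body now passes the relevant values straight through instead of synthesizing a mock Args namespace. A read-through sketch of the additions in the two hunks above (illustrative, not a verbatim excerpt; every name used here is a parameter of the surrounding function):

transforms.append(
    get_quant_embedding_transform(
        embedding_quantize, use_shared_embedding, checkpoint_dtype
    )
)

# quantization_mode is applied after embedding_quantize to support shared_embedding.
transforms.append(
    get_quant_weight_transform(
        quantization_mode=quantization_mode,
        group_size=group_size,
        computation_dtype=dtype_override,
        checkpoint_dtype=checkpoint_dtype,
        checkpoint_path=checkpoint,
        tokenizer_path=tokenizer_path,
        calibration_tasks=calibration_tasks,
        calibration_limit=calibration_limit,
        calibration_seq_length=calibration_seq_length,
    )
)
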
55 changes: 26 additions & 29 deletions examples/models/llama/source_transformation/quantize.py
@@ -41,7 +41,7 @@ def quantize( # noqa C901
checkpoint_dtype: Optional[DType] = None,
checkpoint_path: Optional[Path] = None,
# following arguments only available when setting int4 or gptq quantization.
group_size: Optional[int] = 128,
group_size: Optional[int] = None,
# following arguments are only used for GPTQ
calibration_tasks: Optional[list] = None,
calibration_limit: Optional[int] = None,
@@ -146,9 +146,9 @@ def quantize( # noqa C901
print("quantized model:", model)
return model
elif qmode == "8da4w":
# Check for required args
if group_size is None:
raise Exception("For 8da4w quantization, group size must be specified.")
# TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
group_size = 128

from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.utils import unwrap_tensor_subclass
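
This in-function fallback matters because the refactored get_quant_weight_transform below forwards group_size unconditionally, whereas the old wrapper dropped None-valued arguments before calling quantize, and quantize's signature no longer defaults group_size to 128. A hypothetical illustration, assuming the model object is the first parameter of quantize (this excerpt does not show its full signature):

# group_size is not supplied, so quantize() falls back to 128 for 8da4w.
transform = get_quant_weight_transform(quantization_mode="8da4w")
model = transform(model)  # roughly quantize(model, qmode="8da4w", group_size=None, ...)
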
@@ -784,16 +784,20 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
############################ Source Transform Start #######################


def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
if args.embedding_quantize.startswith("torchao:"):
def get_quant_embedding_transform(
embedding_quantize: str,
use_shared_embedding: bool = False,
dtype_override: Optional[DType] = None,
):
if embedding_quantize.startswith("torchao:"):
from torchao.experimental.quant_api import (
EmbeddingQuantizer,
SharedEmbeddingQuantizer,
)
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import MappingType

quant_args = args.embedding_quantize.split(":")[1].split(",")
quant_args = embedding_quantize.split(":")[1].split(",")
if len(quant_args) == 2:
bitwidth, group_size = quant_args
is_asymmetric = True
@@ -814,7 +818,7 @@ def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):

def _torchao_embedding_quantizer(model):
with torch.no_grad():
if not args.use_shared_embedding:
if not use_shared_embedding:
EmbeddingQuantizer(
weight_dtype=weight_dtype,
granularity=granularity,
@@ -831,7 +835,7 @@ def _torchao_embedding_quantizer(model):

return _torchao_embedding_quantizer

bitwidth, group_size = args.embedding_quantize.split(",")
bitwidth, group_size = embedding_quantize.split(",")
if group_size == "none" or group_size == "None" or group_size == "0":
group_size = None
else:
@@ -848,34 +852,27 @@ def _torchao_embedding_quantizer(model):


def get_quant_weight_transform(
args,
quantization_mode: str,
group_size: Optional[int] = None,
computation_dtype: Optional[DType] = None,
checkpoint_dtype: Optional[DType] = None,
checkpoint_path: Optional[Path] = None,
tokenizer_path: Optional[Path] = None,
calibration_tasks: Optional[list] = None,
calibration_limit: Optional[int] = None,
calibration_seq_length: Optional[int] = None,
):
# If these optional args are None, don't provide them to quantize().
quant_args_str = [
"group_size",
"calibration_tasks",
"calibration_limit",
"calibration_seq_length",
]
arg_dict = vars(args)
quant_args = {
param: val
for param in quant_args_str
if (val := arg_dict.get(param)) is not None
}

return partial(
quantize,
**quant_args,
qmode=args.quantization_mode,
qmode=quantization_mode,
computation_dtype=computation_dtype,
checkpoint_dtype=checkpoint_dtype,
checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
tokenizer_path=(
Path(path) if (path := args.tokenizer_path) is not None else None
),
checkpoint_path=(Path(path) if (path := checkpoint_path) is not None else None),
group_size=group_size,
calibration_tasks=calibration_tasks,
calibration_limit=calibration_limit,
calibration_seq_length=calibration_seq_length,
tokenizer_path=(Path(path) if (path := tokenizer_path) is not None else None),
)


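Taken together, a minimal usage sketch of the two refactored helpers. The import path is inferred from the file location, the model object is assumed to exist, the "4,32" spec (4-bit embedding weights, group size 32) appears in the llava example below, and applying the returned callables directly to the model follows the pattern in the Qualcomm script further down.

from executorch.examples.models.llama.source_transformation.quantize import (
    get_quant_embedding_transform,
    get_quant_weight_transform,
)

# Quantize embedding weights to 4 bits with group size 32.
model = get_quant_embedding_transform(embedding_quantize="4,32")(model)

# 8-bit dynamic activation, 4-bit weight ("8da4w") quantization of the linear layers.
model = get_quant_weight_transform(quantization_mode="8da4w", group_size=128)(model)
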
11 changes: 10 additions & 1 deletion examples/models/llava/export_llava.py
@@ -107,7 +107,16 @@ def forward(self, input_pos, embeddings):
"4,32",
]
)
quant_transform = get_quant_weight_transform(args, dtype_override)
quant_transform = get_quant_weight_transform(
quantization_mode=args.quantization_mode,
group_size=args.group_size,
computation_dtype=dtype_override,
checkpoint_path=args.checkpoint_path,
tokenizer_path=args.tokenizer_path,
calibration_tasks=args.calibration_tasks,
calibration_limit=args.calibration_limit,
calibration_seq_length=args.calibration_seq_length,
)
_, quantizers, _ = get_quantizer_and_quant_params(args)
source_transforms = []
if llava.use_sdpa_with_kv_cache_op:
6 changes: 3 additions & 3 deletions examples/qualcomm/oss_scripts/llama/llama.py
@@ -603,9 +603,9 @@ def permute(w, heads):

for i in range(len(llama_instance_list)):
if args.embedding_quantize:
llama_instance_list[i] = get_quant_embedding_transform(args)(
llama_instance_list[i]
)
llama_instance_list[i] = get_quant_embedding_transform(
embedding_quantize=args.embedding_quantize
)(llama_instance_list[i])
llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i])
llama_instance_list[i] = SingleLlama(
llama_instance_list[i].eval(), pte_filename