
Commit 00d8784

Merge branch 'pytorch:main' into validate_same_dtype
2 parents aacee59 + c913634 commit 00d8784

File tree: 4 files changed, +51 −70 lines

examples/models/llama/export_llama_lib.py

Lines changed: 12 additions & 37 deletions
@@ -1341,23 +1341,11 @@ def _get_source_transforms( # noqa
         transformations based on the given checkpoint first. In those cases,
         this wil be a no-op.
         """
-
-        # Create a mock args object with the necessary attributes
-        class Args:
-            pass
-
-        args = Args()
-        args.checkpoint = checkpoint
-        args.tokenizer_path = tokenizer_path
-        args.embedding_quantize = embedding_quantize
-        args.use_shared_embedding = use_shared_embedding
-        args.use_qat = use_qat
-        args.use_lora = use_lora
-        args.preq_mode = preq_mode
-        args.preq_group_size = preq_group_size
-        args.preq_embedding_quantize = preq_embedding_quantize
-
-        transforms.append(get_quant_embedding_transform(args, checkpoint_dtype))
+        transforms.append(
+            get_quant_embedding_transform(
+                embedding_quantize, use_shared_embedding, checkpoint_dtype
+            )
+        )

         # quantization_mode should be applied after embedding_quantize
         # to support shared_embedding
@@ -1375,30 +1363,17 @@ class Args:
         There are cases where this may be a no-op, namely, if all linears are
         quantized in the checkpoint.
         """
-
-        # Create a mock args object with the necessary attributes
-        class Args:
-            pass
-
-        args = Args()
-        args.checkpoint = checkpoint
-        args.tokenizer_path = tokenizer_path
-        args.quantization_mode = quantization_mode
-        args.group_size = group_size
-        args.use_shared_embedding = use_shared_embedding
-        args.calibration_tasks = calibration_tasks
-        args.calibration_limit = calibration_limit
-        args.calibration_seq_length = calibration_seq_length
-        args.use_shared_embedding = use_shared_embedding
-        args.use_qat = use_qat
-        args.use_lora = use_lora
-        args.preq_mode = preq_mode
-
         transforms.append(
             get_quant_weight_transform(
-                args=args,
+                quantization_mode=quantization_mode,
+                group_size=group_size,
                 computation_dtype=dtype_override,
                 checkpoint_dtype=checkpoint_dtype,
+                checkpoint_path=checkpoint,
+                tokenizer_path=tokenizer_path,
+                calibration_tasks=calibration_tasks,
+                calibration_limit=calibration_limit,
+                calibration_seq_length=calibration_seq_length,
             )
         )
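
The removed pattern built an ad-hoc namespace object just to satisfy the old args-based helper signatures. A minimal sketch of why explicit keyword arguments are safer (the misspelled name below is purely illustrative, not taken from the diff):

# Old style: a mock namespace forwards whatever attributes happen to be set,
# so a misspelled or missing attribute only surfaces later, inside quantize().
class Args:
    pass

args = Args()
args.quantization_mode = "8da4w"
# args.group_size never assigned -> silently falls back to a default downstream

# New style: the same mistake is caught immediately at the call site, e.g.
#   get_quant_weight_transform(quantization_mode="8da4w", group_sise=128)
# raises TypeError: unexpected keyword argument 'group_sise'.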

examples/models/llama/source_transformation/quantize.py

Lines changed: 26 additions & 29 deletions
@@ -41,7 +41,7 @@ def quantize( # noqa C901
     checkpoint_dtype: Optional[DType] = None,
     checkpoint_path: Optional[Path] = None,
     # following arguments only available when setting int4 or gptq quantization.
-    group_size: Optional[int] = 128,
+    group_size: Optional[int] = None,
     # following arguments are only used for GPTQ
     calibration_tasks: Optional[list] = None,
     calibration_limit: Optional[int] = None,
@@ -146,9 +146,9 @@ def quantize( # noqa C901
         print("quantized model:", model)
         return model
     elif qmode == "8da4w":
-        # Check for required args
         if group_size is None:
-            raise Exception("For 8da4w quantization, group size must be specified.")
+            # TODO: Default value for group size for 8da4w. Need this here for refactor, will clean this up.
+            group_size = 128

         from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
         from torchao.utils import unwrap_tensor_subclass
@@ -784,16 +784,20 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:
 ############################ Source Transform Start #######################


-def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):
-    if args.embedding_quantize.startswith("torchao:"):
+def get_quant_embedding_transform(
+    embedding_quantize: str,
+    use_shared_embedding: bool = False,
+    dtype_override: Optional[DType] = None,
+):
+    if embedding_quantize.startswith("torchao:"):
         from torchao.experimental.quant_api import (
             EmbeddingQuantizer,
             SharedEmbeddingQuantizer,
         )
         from torchao.quantization.granularity import PerAxis, PerGroup
         from torchao.quantization.quant_api import MappingType

-        quant_args = args.embedding_quantize.split(":")[1].split(",")
+        quant_args = embedding_quantize.split(":")[1].split(",")
         if len(quant_args) == 2:
             bitwidth, group_size = quant_args
             is_asymmetric = True
@@ -814,7 +818,7 @@ def get_quant_embedding_transform(args, dtype_override: Optional[DType] = None):

     def _torchao_embedding_quantizer(model):
         with torch.no_grad():
-            if not args.use_shared_embedding:
+            if not use_shared_embedding:
                 EmbeddingQuantizer(
                     weight_dtype=weight_dtype,
                     granularity=granularity,
@@ -831,7 +835,7 @@ def _torchao_embedding_quantizer(model):

         return _torchao_embedding_quantizer

-    bitwidth, group_size = args.embedding_quantize.split(",")
+    bitwidth, group_size = embedding_quantize.split(",")
     if group_size == "none" or group_size == "None" or group_size == "0":
         group_size = None
     else:
@@ -848,34 +852,27 @@ def _torchao_embedding_quantizer(model):


 def get_quant_weight_transform(
-    args,
+    quantization_mode: str,
+    group_size: Optional[int] = None,
     computation_dtype: Optional[DType] = None,
     checkpoint_dtype: Optional[DType] = None,
+    checkpoint_path: Optional[Path] = None,
+    tokenizer_path: Optional[Path] = None,
+    calibration_tasks: Optional[list] = None,
+    calibration_limit: Optional[int] = None,
+    calibration_seq_length: Optional[int] = None,
 ):
-    # If these optional args are None, don't provide them to quantize().
-    quant_args_str = [
-        "group_size",
-        "calibration_tasks",
-        "calibration_limit",
-        "calibration_seq_length",
-    ]
-    arg_dict = vars(args)
-    quant_args = {
-        param: val
-        for param in quant_args_str
-        if (val := arg_dict.get(param)) is not None
-    }
-
     return partial(
         quantize,
-        **quant_args,
-        qmode=args.quantization_mode,
+        qmode=quantization_mode,
         computation_dtype=computation_dtype,
         checkpoint_dtype=checkpoint_dtype,
-        checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
-        tokenizer_path=(
-            Path(path) if (path := args.tokenizer_path) is not None else None
-        ),
+        checkpoint_path=(Path(path) if (path := checkpoint_path) is not None else None),
+        group_size=group_size,
+        calibration_tasks=calibration_tasks,
+        calibration_limit=calibration_limit,
+        calibration_seq_length=calibration_seq_length,
+        tokenizer_path=(Path(path) if (path := tokenizer_path) is not None else None),
     )

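For reference, a minimal usage sketch of the two refactored helpers. The import path is assumed from the file location in this repo, and the checkpoint/tokenizer paths are placeholders:

from pathlib import Path

# Import path assumed from examples/models/llama/source_transformation/quantize.py.
from executorch.examples.models.llama.source_transformation.quantize import (
    get_quant_embedding_transform,
    get_quant_weight_transform,
)

# Both helpers now take plain keyword arguments instead of an argparse-style
# `args` object, and return callables that are applied to the eager model.
embedding_transform = get_quant_embedding_transform(
    embedding_quantize="4,32",      # "<bitwidth>,<group_size>" string, as before
    use_shared_embedding=False,
)
weight_transform = get_quant_weight_transform(
    quantization_mode="8da4w",      # forwarded to quantize() as qmode
    group_size=128,                 # optional; 8da4w falls back to 128 if omitted
    checkpoint_path=Path("model.pth"),        # placeholder path
    tokenizer_path=Path("tokenizer.model"),   # placeholder path
)

# model = weight_transform(embedding_transform(model))
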
examples/models/llava/export_llava.py

Lines changed: 10 additions & 1 deletion
@@ -107,7 +107,16 @@ def forward(self, input_pos, embeddings):
             "4,32",
         ]
     )
-    quant_transform = get_quant_weight_transform(args, dtype_override)
+    quant_transform = get_quant_weight_transform(
+        quantization_mode=args.quantization_mode,
+        group_size=args.group_size,
+        computation_dtype=dtype_override,
+        checkpoint_path=args.checkpoint_path,
+        tokenizer_path=args.tokenizer_path,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+    )
     _, quantizers, _ = get_quantizer_and_quant_params(args)
     source_transforms = []
     if llava.use_sdpa_with_kv_cache_op:

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 3 additions & 3 deletions
@@ -603,9 +603,9 @@ def permute(w, heads):

     for i in range(len(llama_instance_list)):
         if args.embedding_quantize:
-            llama_instance_list[i] = get_quant_embedding_transform(args)(
-                llama_instance_list[i]
-            )
+            llama_instance_list[i] = get_quant_embedding_transform(
+                embedding_quantize=args.embedding_quantize
+            )(llama_instance_list[i])
         llama_instance_list[i] = convert_linear_to_conv2d(llama_instance_list[i])
         llama_instance_list[i] = SingleLlama(
             llama_instance_list[i].eval(), pte_filename
