
Commit 0c1c362

Refactor dtype handling in export_llama

Differential Revision: D71515138
Pull Request resolved: #9430
1 parent 6b573af

2 files changed: +46, -74 lines

examples/models/llama/export_llama_lib.py (39 additions, 63 deletions)
@@ -322,8 +322,8 @@ def build_args_parser() -> argparse.ArgumentParser:
         default="fp32",
         type=str,
         choices=["fp32", "fp16", "bf16"],
-        help="Override the dtype of the model (default is the checkpoint dtype)."
-        "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
+        help="Provide the dtype of the model. This must match up with the supported dtypes of the backends that you are using."
+        "Please be aware that only some backends support fp16 and bf16.",
     )
 
     parser.add_argument(
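
For context (illustrative, not part of this commit): a minimal sketch of how the argument above behaves. The flag name --dtype_override is an assumption inferred from args.dtype_override used later in this file; only the default/choices behavior is the point.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype_override",  # assumed flag name; the diff only shows the keyword arguments
    default="fp32",
    type=str,
    choices=["fp32", "fp16", "bf16"],
    help="Provide the dtype of the model. This must match up with the supported "
    "dtypes of the backends that you are using.",
)

print(parser.parse_args([]).dtype_override)                            # fp32 (the default)
print(parser.parse_args(["--dtype_override", "bf16"]).dtype_override)  # bf16
# Any value outside choices (e.g. "int8") is rejected by argparse with an
# "invalid choice" error before any export code runs.

Because the argument defaults to "fp32" instead of falling back to the checkpoint dtype, the string is always present, which is what lets the code below drop its None handling.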
@@ -565,43 +565,42 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
-        dtype_override = DType["fp16"]
-    else:
-        dtype_override = None
+    # Convert dtype override string arg to actual type.
+    dtype_override = DType[args.dtype_override]
+
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
+    )
 
-    return (
-        _load_llama_model(
-            args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
-        )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
+    # At this point, the model is loaded in the default fp32.
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(args.model, dtype_override, args)
     )
 
+    return edge_manager
+
 
 def get_quantizer_and_quant_params(args):
     pt2e_quant_params = get_pt2e_quantization_params(
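
Illustrative sketch (not from the commit) of the conversion pattern the new code relies on: the CLI string is looked up as a DType member by name, then mapped to a torch.dtype via to_torch_dtype(). The enum below is an assumed stand-in for the DType imported by export_llama_lib.py; only the lookup-and-convert flow is the point.

from enum import Enum

import torch


class DType(Enum):  # assumed stand-in for the real DType used by export_llama_lib.py
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

    def to_torch_dtype(self) -> torch.dtype:
        # Map each member to the corresponding torch dtype.
        return {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,
        }[self]


# DType["bf16"] is a lookup by member *name*, which is why the argparse choices
# must stay in sync with the enum's member names.
dtype_override = DType["bf16"]
print(dtype_override.to_torch_dtype())  # torch.bfloat16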
@@ -1006,6 +1005,8 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
+
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
@@ -1022,41 +1023,16 @@ def _load_llama_model(
             enable_dynamic_shape=enable_dynamic_shape,
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
+            dtype=torch_dtype,
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-        logging.info(f"Loaded model with dtype={dtype}")
-
-        if dtype == torch.bfloat16:
-            dtype = DType.bf16
-        elif dtype == torch.float16:
-            dtype = DType.fp16
-        elif dtype == torch.float32:
-            dtype = DType.fp32
-        else:
-            raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
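
With the checkpoint-dtype inference above removed, the eager model now starts in the default fp32 and is converted once via edge_manager.model.to(dtype=...). A small sketch (illustrative toy module, not the Llama model) of why that single conversion is safe: Module.to(dtype=...) casts only floating-point parameters and buffers, leaving integer tensors such as token-id buffers untouched.

import torch
import torch.nn as nn

# A module created normally starts in torch's default dtype, fp32.
layer = nn.Linear(4, 4)
layer.register_buffer("token_ids", torch.zeros(4, dtype=torch.long))  # an integer buffer

layer = layer.to(dtype=torch.bfloat16)  # mirrors edge_manager.model.to(dtype=...)
print(layer.weight.dtype)     # torch.bfloat16 -- floating-point params are converted
print(layer.token_ids.dtype)  # torch.int64   -- integer buffers are left as-is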

examples/models/llama/model.py (7 additions, 11 deletions)
@@ -122,9 +122,6 @@ def __init__(self, **kwargs):
             """
         )
 
-        # Get checkpoint dtype.
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
@@ -171,7 +168,9 @@ def __init__(self, **kwargs):
         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
+            # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
+            self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -241,6 +240,10 @@ def __init__(self, **kwargs):
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
         missing, unexpected = self.model_.load_state_dict(
             checkpoint,
             strict=False,
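
A minimal sketch (toy module and checkpoint, not the real Llama state dict) of the assign=True pattern the comments above describe: because meta tensors have no storage, an in-place copy would be a no-op, so load_state_dict assigns the checkpoint tensors into the module instead.

import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(8, 8, bias=False)  # meta parameters: metadata only, no storage

# A toy "checkpoint" holding real data.
checkpoint = {"weight": torch.randn(8, 8)}

# assign=True swaps the meta parameters for the checkpoint tensors rather than
# attempting an in-place copy (which would silently do nothing on meta tensors).
missing, unexpected = layer.load_state_dict(checkpoint, strict=False, assign=True)
print(layer.weight.device)  # cpu -- the parameter now holds the checkpoint data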
@@ -277,14 +280,7 @@ def __init__(self, **kwargs):
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            # convert to the type of the provided checkpoint
-            # input and output are torch.long, so signature unchanged
-            return self.model_.to(self.dtype)
-        else:
-            # int8 quantization code has some bf16,
-            # switch all to FP32
-            return self.model_.to(torch.float32)
+        return self.model_
 
     def get_example_inputs(self):
         if self.use_kv_cache:
