
Commit 62f1e9d

jackzhxng authored and facebook-github-bot committed
Refactor dtype handling in export_llama (#9430)
Summary: While it might make sense intuitively for the dtype of the model to be the dtype of the checkpoint, that isn't possible for every backend, since some backends support only certain dtypes. For this reason we need to be explicit about the dtype of the model. There is also no longer an intermediate conversion into the checkpoint dtype, which could cause precision loss in situations like fp32 -> checkpoint dtype (fp16 or lower) -> back to the dtype override (fp32), where we lose precision on buffers that are instantiated in fp32 and downcast to fp16.

Reviewed By: kimishpatel

Differential Revision: D71515138
1 parent a828307 commit 62f1e9d
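To make the precision-loss point concrete, here is a minimal sketch (not part of the commit, plain PyTorch) of the fp32 -> fp16 -> fp32 round trip the summary describes: values that fp16 cannot represent exactly do not come back intact.

import torch

# Buffer instantiated in fp32; the middle value needs more precision than fp16 has,
# and the last value overflows fp16's range entirely.
buf_fp32 = torch.tensor([1e-5, 1.0001, 131008.0], dtype=torch.float32)
roundtrip = buf_fp32.to(torch.float16).to(torch.float32)

print(buf_fp32)
print(roundtrip)                             # 1.0001 rounds to 1.0, 131008.0 becomes inf
print(torch.allclose(buf_fp32, roundtrip))   # False: precision was lost in the middle hop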

File tree: 2 files changed (+46, -74 lines)


examples/models/llama/export_llama_lib.py

Lines changed: 39 additions & 63 deletions
@@ -322,8 +322,8 @@ def build_args_parser() -> argparse.ArgumentParser:
         default="fp32",
         type=str,
         choices=["fp32", "fp16", "bf16"],
-        help="Override the dtype of the model (default is the checkpoint dtype)."
-        "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
+        help="Provide the dtype of the model. This must match up with the supported dtypes of the backends that you are using."
+        "Please be aware that only some backends support fp16 and bf16.",
     )
 
     parser.add_argument(
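The reworded help text stresses that the chosen dtype has to be one the target backend actually supports. A hedged illustration of that constraint follows; the backend names, the --backend flag, and the supported-dtype table are made up for this example and are not taken from the commit or from ExecuTorch.

import argparse

# Hypothetical support table, for illustration only; consult each backend's docs.
SUPPORTED_DTYPES = {
    "backend_a": {"fp32", "fp16", "bf16"},
    "backend_b": {"fp32"},
}

parser = argparse.ArgumentParser()
parser.add_argument("--dtype-override", default="fp32", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--backend", default="backend_a", choices=sorted(SUPPORTED_DTYPES))
args = parser.parse_args(["--dtype-override", "bf16", "--backend", "backend_a"])

# The export flow can now fail fast if the explicit dtype and the backend disagree.
if args.dtype_override not in SUPPORTED_DTYPES[args.backend]:
    raise ValueError(
        f"{args.backend} does not support {args.dtype_override}; "
        f"choose one of {sorted(SUPPORTED_DTYPES[args.backend])}"
    )
print(f"exporting with dtype {args.dtype_override} for {args.backend}")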
@@ -565,43 +565,42 @@ def _prepare_for_llama_export(args) -> LLMEdgeManager:
     output_dir_path = canonical_path(args.output_dir, dir=True)
     weight_type = WeightType.FAIRSEQ2 if args.fairseq2 else WeightType.LLAMA
 
-    # dtype override
-    if args.dtype_override is not None:
-        dtype_override = DType[args.dtype_override]
-    elif args.quantization_mode in ["8da4w", "8da4w-gptq"]:
-        dtype_override = DType["fp16"]
-    else:
-        dtype_override = None
+    # Convert dtype override string arg to actual type.
+    dtype_override = DType[args.dtype_override]
+
+    edge_manager = _load_llama_model(
+        args.model,
+        checkpoint=checkpoint_path,
+        checkpoint_dir=checkpoint_dir,
+        params_path=params_path,
+        use_kv_cache=args.use_kv_cache,
+        use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
+        generate_full_logits=args.generate_full_logits,
+        weight_type=weight_type,
+        enable_dynamic_shape=args.enable_dynamic_shape,
+        calibration_tasks=args.calibration_tasks,
+        calibration_limit=args.calibration_limit,
+        calibration_seq_length=args.calibration_seq_length,
+        calibration_data=args.calibration_data,
+        tokenizer_path=args.tokenizer_path,
+        verbose=args.verbose,
+        max_seq_len=args.max_seq_length,
+        max_context_len=args.max_context_length,
+        input_prune_map_path=args.input_prune_map,
+        output_prune_map_path=args.output_prune_map,
+        metadata_str=args.metadata,
+        dtype_override=dtype_override,
+        args=args,
+    )
 
-    return (
-        _load_llama_model(
-            args.model,
-            checkpoint=checkpoint_path,
-            checkpoint_dir=checkpoint_dir,
-            params_path=params_path,
-            use_kv_cache=args.use_kv_cache,
-            use_sdpa_with_kv_cache=args.use_sdpa_with_kv_cache,
-            generate_full_logits=args.generate_full_logits,
-            weight_type=weight_type,
-            enable_dynamic_shape=args.enable_dynamic_shape,
-            calibration_tasks=args.calibration_tasks,
-            calibration_limit=args.calibration_limit,
-            calibration_seq_length=args.calibration_seq_length,
-            calibration_data=args.calibration_data,
-            tokenizer_path=args.tokenizer_path,
-            verbose=args.verbose,
-            max_seq_len=args.max_seq_length,
-            max_context_len=args.max_context_length,
-            input_prune_map_path=args.input_prune_map,
-            output_prune_map_path=args.output_prune_map,
-            metadata_str=args.metadata,
-            dtype_override=dtype_override,
-            args=args,
-        )
-        .set_output_dir(output_dir_path)
-        .source_transform(_get_source_transforms(args.model, dtype_override, args))
+    # At this point, the model is loaded in the default fp32.
+    edge_manager.model = edge_manager.model.to(dtype=dtype_override.to_torch_dtype())
+    edge_manager.set_output_dir(output_dir_path).source_transform(
+        _get_source_transforms(args.model, dtype_override, args)
     )
 
+    return edge_manager
+
 
 def get_quantizer_and_quant_params(args):
     pt2e_quant_params = get_pt2e_quantization_params(
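The behavioral change in this hunk: the model is always materialized in the default fp32 and then cast exactly once to the explicit dtype override, rather than passing through the checkpoint dtype on the way. Below is a self-contained sketch of that string -> enum -> torch dtype -> single cast flow; the SimpleDType enum is a stand-in written for this example, not ExecuTorch's actual DType definition.

from enum import Enum

import torch


class SimpleDType(Enum):
    # Stand-in mirroring the CLI choices; DType["bf16"]-style lookup works the same way.
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

    def to_torch_dtype(self) -> torch.dtype:
        return {
            SimpleDType.fp32: torch.float32,
            SimpleDType.fp16: torch.float16,
            SimpleDType.bf16: torch.bfloat16,
        }[self]


dtype_override = SimpleDType["bf16"]      # e.g. args.dtype_override == "bf16"
model = torch.nn.Linear(8, 8)             # loaded in the default fp32
model = model.to(dtype=dtype_override.to_torch_dtype())  # one explicit cast, no checkpoint-dtype hop
print(next(model.parameters()).dtype)     # torch.bfloat16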
@@ -1006,6 +1005,8 @@ def _load_llama_model(
     else:
         raise ValueError(f"{modelname} is not a valid Llama model.")
 
+    torch_dtype = dtype_override.to_torch_dtype() if dtype_override else None
+
     model, example_inputs, example_kwarg_inputs, dynamic_shapes = (
         EagerModelFactory.create_model(
             module_name,
@@ -1022,41 +1023,16 @@
             enable_dynamic_shape=enable_dynamic_shape,
             input_prune_map_path=input_prune_map_path,
             output_prune_map_path=output_prune_map_path,
+            dtype=torch_dtype,
             args=args,
         )
     )
-    if dtype_override:
-        assert isinstance(
-            dtype_override, DType
-        ), "Override dtype needs to be of type <DType>"
-        torch_dtype = dtype_override.to_torch_dtype()
-        logging.info(f"model.to {torch_dtype}")
-        model = model.to(dtype=torch_dtype)
-        dtype = dtype_override
-    else:
-        state_dict = model.state_dict()
-        dtype = state_dict[next(iter(state_dict))].dtype
-        assert dtype in [
-            torch.bfloat16,
-            torch.float16,
-            torch.float32,
-        ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
-        logging.info(f"Loaded model with dtype={dtype}")
-
-        if dtype == torch.bfloat16:
-            dtype = DType.bf16
-        elif dtype == torch.float16:
-            dtype = DType.fp16
-        elif dtype == torch.float32:
-            dtype = DType.fp32
-        else:
-            raise ValueError(f"Unsupported dtype {dtype}")
 
     return LLMEdgeManager(
         model=model,
         modelname=modelname,
         max_seq_len=model.max_seq_len,
-        dtype=dtype,
+        dtype=dtype_override,
         use_kv_cache=use_kv_cache,
         generate_full_logits=generate_full_logits,
         example_inputs=example_inputs,
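For contrast, the deleted branch above inferred a dtype by peeking at the first tensor of the model's state dict and mapping it back to a DType member, roughly as in the sketch below (a toy module stands in for the Llama model). The new code skips that inference and hands dtype_override to LLMEdgeManager directly.

import torch

model = torch.nn.Linear(4, 4).to(torch.bfloat16)   # toy stand-in for the loaded model

# What the removed else-branch did, in spirit: treat the dtype of the first
# state-dict entry as "the" model dtype.
state_dict = model.state_dict()
inferred = state_dict[next(iter(state_dict))].dtype
print(inferred)   # torch.bfloat16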

examples/models/llama/model.py

Lines changed: 7 additions & 11 deletions
@@ -122,9 +122,6 @@ def __init__(self, **kwargs):
                 """
             )
 
-        # Get checkpoint dtype.
-        self.dtype = get_checkpoint_dtype(checkpoint)
-
         with open(params_path, "r") as f:
             params = json.loads(f.read())
         output_prune_map = None
@@ -171,7 +168,9 @@ def __init__(self, **kwargs):
         # Within the device="meta" context, tensors that are created do not carry data.
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
+            # Model itself is loaded in default dtype, fp32.
             self.model_ = Transformer(model_args)
+            self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
 
         if "int8" in str(checkpoint_path):
             print("Using int8 weight-only quantization!")
@@ -241,6 +240,10 @@ def __init__(self, **kwargs):
         # assign=True: load params/buffers by assignment instead of performing an in-place copy.
         # Because we are using device="meta", tensors do not have memory associated with them
         # and an in-place copy is a no-op. Use assign=True in load_state_dict for this scenario.
+
+        # Also, the checkpoint is loaded and dtype promoted to the transformer's dtype, which is
+        # by default initialized to fp32. This is fine because every other supported type
+        # losslessly converts to fp32, so we don't lose precision here.
         missing, unexpected = self.model_.load_state_dict(
             checkpoint,
             strict=False,
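A minimal hedged illustration of the assign=True pattern the comments above describe, using a toy module and checkpoint rather than the Llama ones: assign=True swaps the meta parameters for the checkpoint tensors instead of copying into storage that does not exist.

import torch

with torch.device("meta"):
    layer = torch.nn.Linear(4, 4)          # parameters are meta tensors with no storage

checkpoint = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}

# assign=True replaces the meta parameters outright; an in-place copy_ would be a
# no-op here because meta tensors own no memory.
missing, unexpected = layer.load_state_dict(checkpoint, assign=True)
print(missing, unexpected)       # no missing or unexpected keys for this toy checkpoint
print(layer.weight.is_meta)      # False: the parameter now holds the checkpoint data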
@@ -277,14 +280,7 @@
             self.model_ = prune_output_vocab(self.model_, output_prune_map)
 
     def get_eager_model(self) -> torch.nn.Module:
-        if self.dtype:
-            # convert to the type of the provided checkpoint
-            # input and output are torch.long, so signature unchanged
-            return self.model_.to(self.dtype)
-        else:
-            # int8 quantization code has some bf16,
-            # switch all to FP32
-            return self.model_.to(torch.float32)
+        return self.model_
 
     def get_example_inputs(self):
         if self.use_kv_cache: