Commit f812d59

[ExecuTorch] Allow setting dtype to bf16 in export_llama
Pull Request resolved: #4985

Support creating bf16 PTEs.

ghstack-source-id: 240577203
@exported-using-ghexport

Differential Revision: [D61981363](https://our.internmc.facebook.com/intern/diff/D61981363/)
Parent: 248e33c

File tree: 2 files changed (+3 lines, -2 lines)

examples/models/llama2/export_llama_lib.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -256,9 +256,9 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--dtype-override",
         default="fp32",
         type=str,
-        choices=["fp32", "fp16"],
+        choices=["fp32", "fp16", "bf16"],
         help="Override the dtype of the model (default is the checkpoint dtype)."
-        "Options: fp32, fp16. Please be aware that only some backends support fp16.",
+        "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
     )
 
     parser.add_argument(
```
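To illustrate the effect of the new choice, here is a minimal, self-contained argparse sketch that mirrors only the changed lines; the rest of `build_args_parser` is not reproduced, and the parse call at the bottom is purely illustrative:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype-override",
    default="fp32",
    type=str,
    choices=["fp32", "fp16", "bf16"],
    help="Override the dtype of the model (default is the checkpoint dtype)."
    "Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
)

# "bf16" now parses cleanly; any value outside `choices` makes argparse
# exit with a usage error.
args = parser.parse_args(["--dtype-override", "bf16"])
print(args.dtype_override)  # -> bf16
```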

extension/llm/export/builder.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -46,6 +46,7 @@ def to_torch_dtype(self) -> torch.dtype:
         mapping = {
             DType.fp32: torch.float32,
             DType.fp16: torch.float16,
+            DType.bf16: torch.bfloat16,
         }
         if self not in mapping:
             raise ValueError(f"Unsupported dtype {self}")
```
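For context, a minimal sketch of how the `DType` mapping behaves after this change. Only the `to_torch_dtype` mapping lines appear in the diff; the enum definition and its member values here are assumptions for illustration, and the real `DType` in builder.py may have additional members (which is why the `ValueError` branch exists):

```python
from enum import Enum

import torch


class DType(Enum):
    # Member values are hypothetical; only the three names below are
    # confirmed by the diff.
    fp32 = "fp32"
    fp16 = "fp16"
    bf16 = "bf16"

    def to_torch_dtype(self) -> torch.dtype:
        mapping = {
            DType.fp32: torch.float32,
            DType.fp16: torch.float16,
            DType.bf16: torch.bfloat16,  # added by this commit
        }
        if self not in mapping:
            raise ValueError(f"Unsupported dtype {self}")
        return mapping[self]


assert DType.bf16.to_torch_dtype() == torch.bfloat16
```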
