Skip to content

Commit 25e328a

Browse files
lucylqmalfet
authored andcommitted
export bf16 (#618)
1 parent fe70ea2 commit 25e328a

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

export_et.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ def export_model(model, device, output_path, args=None) -> str: # noqa: C901
7979
if state_dict_dtype != torch.float32:
8080
print("model.to torch.float32")
8181
model = model.to(dtype=torch.float32)
82+
elif target_precision == torch.bfloat16:
83+
print("model.to torch.bfloat16")
84+
model = model.to(dtype=torch.bfloat16)
8285
else:
8386
raise ValueError(f"Unsupported dtype for ET export: {target_precision}")
8487

0 commit comments

Comments
 (0)