
Commit 33306d3

kimishpatel authored and facebook-github-bot committed
Fix dtype check for builder.py
Summary: The llama2 7b checkpoint is stored in bfloat16, so this check fails unless that dtype is included.

Created from CodeHub with https://fburl.com/edit-in-codehub
Reviewed By: larryliu0820, manuelcandales
Differential Revision: D54122475
fbshipit-source-id: 4d25833776c9ca14f009550e18c6930734630586
1 parent 6d04257 commit 33306d3

File tree: 1 file changed (+5, -1 lines)


examples/models/llama2/builder.py

Lines changed: 5 additions & 1 deletion
@@ -75,7 +75,11 @@ def load_llama_model(
     )
     state_dict = model.state_dict()
     dtype = state_dict[next(iter(state_dict))].dtype
-    assert dtype in [torch.float16, torch.float32], "Only support fp16 or fp32"
+    assert dtype in [
+        torch.bfloat16,
+        torch.float16,
+        torch.float32,
+    ], f"Only support bfloat16, fp16 or fp32 got {dtype}"
     logging.info(f"Loaded model with dtype={dtype}")

     return LlamaEdgeManager(
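
For context, a minimal standalone sketch of the check this change produces. The small bfloat16 Linear module is a stand-in assumption for illustration only; the real builder.py reads the dtype from the loaded Llama checkpoint's state_dict.

# Illustrative sketch of the dtype check added in this commit.
# A bfloat16 Linear stands in for the Llama 2 checkpoint (assumption);
# builder.py instead inspects the checkpoint's state_dict after loading.
import torch

model = torch.nn.Linear(4, 4).to(torch.bfloat16)
state_dict = model.state_dict()

# Take the dtype of the first tensor in the state dict.
dtype = state_dict[next(iter(state_dict))].dtype

# Accept bfloat16 alongside fp16/fp32, matching the change above.
assert dtype in [
    torch.bfloat16,
    torch.float16,
    torch.float32,
], f"Only support bfloat16, fp16 or fp32 got {dtype}"

print(f"Loaded model with dtype={dtype}")  # prints torch.bfloat16 here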

0 commit comments