Skip to content

Commit 160421a

Browse files
authored
Allow non tensor checkpoint values (#8845)
1 parent 9aca1fa commit 160421a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

examples/models/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
6464
mismatched_dtypes = [
6565
(key, value.dtype)
6666
for key, value in checkpoint.items()
67-
if value.dtype != dtype
67+
if hasattr(value, "dtype") and value.dtype != dtype
6868
]
6969
if len(mismatched_dtypes) > 0:
7070
print(

0 commit comments

Comments
 (0)