Skip to content

Commit b20db51

Browse files
mikekgfbmalfet
authored andcommitted
fix read before write (#134)
1 parent 86e5374 commit b20db51

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

export_et.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,15 @@ def export_model(model, device, output_path, args=None) -> str: # noqa: C901
8989

9090
state_dict = model.state_dict()
9191
state_dict_dtype = state_dict[next(iter(state_dict))].dtype
92+
target_precision = get_precision()
93+
dynamic_shapes = None
9294

9395
# need to use kv sdpa?
9496
edge_config = EdgeCompileConfig(
9597
_check_ir_validity=False,
9698
_skip_type_promotion=bool(target_precision == torch.float16),
9799
)
98100

99-
dynamic_shapes = None
100-
101-
target_precision = get_precision()
102101
if target_precision == torch.float16: # or args.quantization_mode=="int4":
103102
if state_dict_dtype != torch.float16:
104103
print("model.to torch.float16")

0 commit comments

Comments
 (0)