Skip to content

Commit b88f71e

Browse files
committed
Fix export_for_train
1 parent 56aa9ef commit b88f71e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchchat/export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def export_for_et(model, device, output_path) -> str:
315315
with torch.nn.attention.sdpa_kernel(
316316
[torch.nn.attention.SDPBackend.MATH]
317317
), torch.no_grad():
318-
m = export_for_training(model, input, dynamic_shapes=dynamic_shapes)
318+
m = export_for_training(model, input, dynamic_shapes=dynamic_shapes).module()
319319

320320
edge_manager = export_to_edge(
321321
m,

0 commit comments

Comments
 (0)