Skip to content

Commit dba4a9e

Browse files
Martin Yuanfacebook-github-bot
authored andcommitted
Add torch.no_grad() guard on export
Summary: For some models where the grad is enabled (like LLava model), we need to disable grad before exporting. This PR is to add this guard when using export script. It would not affect the models that grad is not enabled. Reviewed By: cccclai Differential Revision: D54812718
1 parent 75284d2 commit dba4a9e

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

examples/portable/scripts/export.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import argparse
1010
import logging
1111

12+
import torch
13+
1214
from executorch.exir.capture import EdgeCompileConfig, ExecutorchBackendConfig
1315

1416
from ...models import MODEL_NAME_TO_MODEL
@@ -75,4 +77,5 @@ def main() -> None:
7577

7678

7779
if __name__ == "__main__":
78-
main() # pragma: no cover
80+
with torch.no_grad():
81+
main() # pragma: no cover

0 commit comments

Comments
 (0)