Skip to content

Commit 54455a3

Browse files
authored
Minor fixes to PT2 export path: enum typo and max_seq_len (#1343)
1 parent 9480258 commit 54455a3

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

torchchat/export.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ def export_for_server(
7878
dynamic_shapes=dynamic_shapes,
7979
options=options,
8080
)
81-
81+
8282
if package:
8383
from torch._inductor.package import package_aoti
8484
path = package_aoti(output_path, path)
85-
85+
8686
print(f"The generated packaged model can be found at: {path}")
8787
return path
8888

@@ -382,7 +382,7 @@ def main(args):
382382

383383
if builder_args.max_seq_length is None:
384384
if (
385-
output_dso_path is not None
385+
(output_dso_path is not None or output_aoti_package_path is not None)
386386
and not builder_args.dynamic_shapes
387387
):
388388
print("Setting max_seq_length to 300 for DSO export.")

torchchat/utils/build_utils.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from enum import Enum
1313
from pathlib import Path
14-
from typing import Any, Callable, Dict, List, Tuple
14+
from typing import Any, Callable, Dict, List, Optional, Tuple
1515

1616
import torch
1717

@@ -77,31 +77,39 @@ def unpack_packed_weights(
7777
def set_backend(dso, pte, aoti_package):
7878
global active_builder_args_dso
7979
global active_builder_args_pte
80+
global active_builder_args_aoti_package
8081
active_builder_args_dso = dso
8182
active_builder_args_aoti_package = aoti_package
8283
active_builder_args_pte = pte
8384

8485

8586
class _Backend(Enum):
86-
AOTI = (0,)
87+
AOTI = 0
8788
EXECUTORCH = 1
8889

8990

90-
def _active_backend() -> _Backend:
91+
def _active_backend() -> Optional[_Backend]:
9192
global active_builder_args_dso
9293
global active_builder_args_aoti_package
9394
global active_builder_args_pte
9495

95-
# eager == aoti, which is when backend has not been explicitly set
96-
if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
97-
return True
96+
args = (
97+
active_builder_args_dso,
98+
active_builder_args_pte,
99+
active_builder_args_aoti_package,
100+
)
101+
102+
# Return None, as default
103+
if not any(args):
104+
return None
98105

99-
if active_builder_args_pte and active_builder_args_aoti_package:
106+
# Catch more than one arg
107+
if sum(map(bool, args)) > 1:
100108
raise RuntimeError(
101-
"code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
109+
"Code generation needs to choose different implementations. Please only use one export option, and call export twice if necessary!"
102110
)
103111

104-
return _Backend.AOTI if active_builder_args_pte else _Backend.EXECUTORCH
112+
return _Backend.EXECUTORCH if active_builder_args_pte else _Backend.AOTI
105113

106114

107115
def use_aoti_backend() -> bool:

0 commit comments

Comments
 (0)