|
11 | 11 |
|
12 | 12 | from enum import Enum
|
13 | 13 | from pathlib import Path
|
14 |
| -from typing import Any, Callable, Dict, List, Tuple |
| 14 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
15 | 15 |
|
16 | 16 | import torch
|
17 | 17 |
|
@@ -77,31 +77,39 @@ def unpack_packed_weights(
|
77 | 77 | def set_backend(dso, pte, aoti_package):
|
78 | 78 | global active_builder_args_dso
|
79 | 79 | global active_builder_args_pte
|
| 80 | + global active_builder_args_aoti_package |
80 | 81 | active_builder_args_dso = dso
|
81 | 82 | active_builder_args_aoti_package = aoti_package
|
82 | 83 | active_builder_args_pte = pte
|
83 | 84 |
|
84 | 85 |
|
85 | 86 | class _Backend(Enum):
|
86 |
| - AOTI = (0,) |
| 87 | + AOTI = 0 |
87 | 88 | EXECUTORCH = 1
|
88 | 89 |
|
89 | 90 |
|
90 |
| -def _active_backend() -> _Backend: |
| 91 | +def _active_backend() -> Optional[_Backend]: |
91 | 92 | global active_builder_args_dso
|
92 | 93 | global active_builder_args_aoti_package
|
93 | 94 | global active_builder_args_pte
|
94 | 95 |
|
95 |
| - # eager == aoti, which is when backend has not been explicitly set |
96 |
| - if (not active_builder_args_pte) and (not active_builder_args_aoti_package): |
97 |
| - return True |
| 96 | + args = ( |
| 97 | + active_builder_args_dso, |
| 98 | + active_builder_args_pte, |
| 99 | + active_builder_args_aoti_package, |
| 100 | + ) |
| 101 | + |
| 102 | + # Return None, as default |
| 103 | + if not any(args): |
| 104 | + return None |
98 | 105 |
|
99 |
| - if active_builder_args_pte and active_builder_args_aoti_package: |
| 106 | + # Catch more than one arg |
| 107 | + if sum(map(bool, args)) > 1: |
100 | 108 | raise RuntimeError(
|
101 |
| - "code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!" |
| 109 | + "Code generation needs to choose different implementations. Please only use one export option, and call export twice if necessary!" |
102 | 110 | )
|
103 | 111 |
|
104 |
| - return _Backend.AOTI if active_builder_args_pte else _Backend.EXECUTORCH |
| 112 | + return _Backend.EXECUTORCH if active_builder_args_pte else _Backend.AOTI |
105 | 113 |
|
106 | 114 |
|
107 | 115 | def use_aoti_backend() -> bool:
|
|
0 commit comments