Skip to content

Commit 4247992

Browse files
mikekgfb authored and malfet committed
ET or AOTI backend logic (#392)
* ET or AOTI backend logic
* use args, not builder_args
* typo
* typo
1 parent d551af6 commit 4247992

File tree

7 files changed: +59 −6 lines changed

7 files changed: +59 −6 lines changed

build/builder.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -199,9 +199,11 @@ def from_args(cls, args): # -> TokenizerArgs:
199199
def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
200200
if tokenizer_args.is_sentencepiece:
201201
from sentencepiece import SentencePieceProcessor
202+
202203
return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
203204
elif tokenizer_args.is_tiktoken:
204205
from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
206+
205207
return TiktokenTokenizer(model_path=str(tokenizer_args.tokenizer_path))
206208
else:
207209
raise RuntimeError("must specify a valid tokenizer in TokenizerArgs")

build/model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from torch import Tensor
1616
from torch.nn import functional as F
1717

18-
from build.utils import find_multiple, get_precision
18+
from build.utils import find_multiple, get_precision, use_aoti_backend
1919

2020
config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")
2121

build/utils.py

Lines changed: 49 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,53 @@
1414

1515

1616
##########################################################################
17-
### dtype name to torch.dtype mapping ###
17+
### set and get target backend is aoti or et for this model ###
18+
19+
active_builder_args_dso = None
20+
active_builder_args_pte = None
21+
22+
23+
def set_backend(dso, pte):
24+
global active_builder_args_dso
25+
global active_builder_args_pte
26+
active_builder_args_dso = dso
27+
active_builder_args_pte = pte
28+
29+
30+
def use_aoti_backend() -> bool:
31+
global active_builder_args_dso
32+
global active_builder_args_pte
33+
34+
# eager == aoti, which is when backend has not been explicitly set
35+
if (not active_builder_args_dso) and not (active_builder_args_pte):
36+
return True
37+
38+
if active_builder_args_pte and active_builder_args_dso:
39+
raise RuntimeError(
40+
"code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
41+
)
42+
43+
return bool(active_builder_args_dso)
44+
45+
46+
def use_et_backend() -> bool:
47+
global active_builder_args_dso
48+
global active_builder_args_pte
49+
50+
# eager == aoti, which is when backend has not been explicitly set
51+
if not (active_builder_args_pte or active_builder_args_dso):
52+
return False
53+
54+
if active_builder_args_pte and active_builder_args_dso:
55+
raise RuntimeError(
56+
"code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
57+
)
58+
59+
return bool(active_builder_args_pte)
60+
61+
62+
##########################################################################
63+
### set and get target precision for this model ###
1864

1965
precision = torch.float32
2066

@@ -29,6 +75,8 @@ def get_precision():
2975
return precision
3076

3177

78+
##########################################################################
79+
### dtype name to torch.dtype mapping ###
3280
def name_to_dtype(name):
3381
if name in name_to_dtype_dict:
3482
return name_to_dtype_dict[name]

download.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,6 @@
1717
)
1818

1919

20-
2120
def _download_hf_snapshot(
2221
model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
2322
):

export.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
TokenizerArgs,
1919
)
2020

21-
from build.utils import set_precision
21+
from build.utils import set_backend, set_precision, use_aoti_backend, use_et_backend
2222
from cli import add_arguments, add_arguments_for_export, arg_init, check_args
2323
from download import download_and_convert, is_model_downloaded
2424
from export_aoti import export_model as export_model_aoti
@@ -35,15 +35,17 @@
3535

3636

3737
def main(args):
38+
# THIS BELONGS INTO CLI
3839
# If a named model was provided and not downloaded, download it.
39-
if args.model and not is_model_downloaded(args.model, args.model_directory):
40-
download_and_convert(args.model, args.model_directory, args.hf_token)
40+
# if args.model and not is_model_downloaded(args.model, args.model_directory):
41+
# download_and_convert(args.model, args.model_directory, args.hf_token)
4142

4243
builder_args = BuilderArgs.from_args(args)
4344
quantize = args.quantize
4445

4546
print(f"Using device={builder_args.device}")
4647
set_precision(builder_args.precision)
48+
set_backend(dso=args.output_dso_path, pte=args.output_pte_path)
4749

4850
builder_args.dso_path = None
4951
builder_args.pte_path = None

export_et_util.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -81,6 +81,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None):
8181

8282
def replace_attention_with_custom_sdpa_attention(module: nn.Module):
8383
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # noqa
84+
8485
for name, child in module.named_children():
8586
if isinstance(child, Attention):
8687
setattr(module, name, CustomSDPAAttention(child))

quantize.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -76,6 +76,7 @@ def quantized_model(self) -> nn.Module:
7676
class Int8DynActInt4WeightQuantizer(QuantHandler):
7777
def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
7878
import torchao.quantization.quant_api as quant_api
79+
7980
self.model_ = model
8081
self.device = device
8182
self.tokenizer = tokenizer

0 commit comments

Comments
 (0)