ET or AOTI backend logic #392

Merged
merged 4 commits into from
Apr 23, 2024
2 changes: 2 additions & 0 deletions build/builder.py
@@ -199,9 +199,11 @@ def from_args(cls, args): # -> TokenizerArgs:
def _initialize_tokenizer(tokenizer_args: TokenizerArgs):
    if tokenizer_args.is_sentencepiece:
        from sentencepiece import SentencePieceProcessor

        return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path))
    elif tokenizer_args.is_tiktoken:
        from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer

        return TiktokenTokenizer(model_path=str(tokenizer_args.tokenizer_path))
    else:
        raise RuntimeError("must specify a valid tokenizer in TokenizerArgs")
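
As context for the whitespace-only change above, _initialize_tokenizer picks the tokenizer implementation from flags on TokenizerArgs. A minimal usage sketch, assuming TokenizerArgs can be constructed directly with the fields the function reads (in the repo it is normally built via TokenizerArgs.from_args):

from pathlib import Path

from build.builder import TokenizerArgs, _initialize_tokenizer

# Assumed direct construction; the field names mirror those read above.
tokenizer_args = TokenizerArgs(
    tokenizer_path=Path("checkpoints/model/tokenizer.model"),
    is_sentencepiece=True,
    is_tiktoken=False,
)

tokenizer = _initialize_tokenizer(tokenizer_args)
# Returns a SentencePieceProcessor here; with is_tiktoken=True it returns the
# tiktoken-based Tokenizer, and with neither flag set it raises RuntimeError.
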
2 changes: 1 addition & 1 deletion build/model.py
@@ -15,7 +15,7 @@
from torch import Tensor
from torch.nn import functional as F

from build.utils import find_multiple, get_precision
from build.utils import find_multiple, get_precision, use_aoti_backend

config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")

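
The added use_aoti_backend import lets model code choose backend-specific implementations at build or export time. A hedged illustration of the pattern; the _sdpa helper below is hypothetical and not part of model.py:

import torch
import torch.nn.functional as F

from build.utils import use_aoti_backend


def _sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: use the fused op on the AOTI/eager path, and a plain
    # composition of ops otherwise (e.g. when lowering for ExecuTorch).
    if use_aoti_backend():
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    scores = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
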
50 changes: 49 additions & 1 deletion build/utils.py
@@ -12,7 +12,53 @@


##########################################################################
### dtype name to torch.dtype mapping ###
### set and get the target backend (AOTI or ET) for this model ###

active_builder_args_dso = None
active_builder_args_pte = None


def set_backend(dso, pte):
    global active_builder_args_dso
    global active_builder_args_pte
    active_builder_args_dso = dso
    active_builder_args_pte = pte


def use_aoti_backend() -> bool:
    global active_builder_args_dso
    global active_builder_args_pte

    # Eager mode follows the AOTI code path, i.e. when no backend has been set explicitly.
    if not (active_builder_args_dso or active_builder_args_pte):
        return True

    if active_builder_args_pte and active_builder_args_dso:
        raise RuntimeError(
            "code generation needs to choose different implementations for the DSO and PTE paths. Please only use one export option, and call export twice if necessary!"
        )

    return bool(active_builder_args_dso)


def use_et_backend() -> bool:
    global active_builder_args_dso
    global active_builder_args_pte

    # Eager mode follows the AOTI code path, so ET is used only when a PTE target has been set.
    if not (active_builder_args_pte or active_builder_args_dso):
        return False

    if active_builder_args_pte and active_builder_args_dso:
        raise RuntimeError(
            "code generation needs to choose different implementations for the DSO and PTE paths. Please only use one export option, and call export twice if necessary!"
        )

    return bool(active_builder_args_pte)


##########################################################################
### set and get target precision for this model ###

precision = torch.float32

@@ -27,6 +73,8 @@ def get_precision():
    return precision


##########################################################################
### dtype name to torch.dtype mapping ###
def name_to_dtype(name):
    if name in name_to_dtype_dict:
        return name_to_dtype_dict[name]
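
The module-level flags above act as a small process-wide registry: the export entry point records which artifact is being produced, and downstream code queries it. A sketch of the intended flow, assuming set_backend is called once per export invocation:

from build.utils import set_backend, use_aoti_backend, use_et_backend

# Nothing set yet: eager mode follows the AOTI code path.
assert use_aoti_backend() and not use_et_backend()

# Exporting a DSO via AOT Inductor.
set_backend(dso="model.so", pte=None)
assert use_aoti_backend() and not use_et_backend()

# Exporting a PTE program for ExecuTorch.
set_backend(dso=None, pte="model.pte")
assert use_et_backend() and not use_aoti_backend()

# Requesting both artifacts in one invocation is rejected.
set_backend(dso="model.so", pte="model.pte")
try:
    use_aoti_backend()
except RuntimeError:
    pass  # "Please only use one export option..."
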
1 change: 0 additions & 1 deletion download.py
@@ -17,7 +17,6 @@
)



def _download_hf_snapshot(
    model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
):
8 changes: 5 additions & 3 deletions export.py
Contributor
Should maybe drop special logic in export.py for gguf (line 75-89) and rely on global state in builder.py alone.

Contributor Author
Eventually. But first we fix the implementation of operators.

@@ -18,7 +18,7 @@
    TokenizerArgs,
)

from build.utils import set_precision
from build.utils import set_backend, set_precision, use_aoti_backend, use_et_backend
from cli import add_arguments, add_arguments_for_export, arg_init, check_args
from download import download_and_convert, is_model_downloaded
from export_aoti import export_model as export_model_aoti
@@ -35,15 +35,17 @@


def main(args):
    # THIS BELONGS INTO CLI
Member
I'll have a PR out to move this shortly.

    # If a named model was provided and not downloaded, download it.
    if args.model and not is_model_downloaded(args.model, args.model_directory):
        download_and_convert(args.model, args.model_directory, args.hf_token)
    # if args.model and not is_model_downloaded(args.model, args.model_directory):
    #     download_and_convert(args.model, args.model_directory, args.hf_token)

    builder_args = BuilderArgs.from_args(args)
    quantize = args.quantize

    print(f"Using device={builder_args.device}")
    set_precision(builder_args.precision)
    set_backend(dso=args.output_dso_path, pte=args.output_pte_path)

    builder_args.dso_path = None
    builder_args.pte_path = None
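
With the backend recorded via set_backend, the rest of main can dispatch on whichever output path was supplied. A rough sketch of that dispatch; the export_model_et name and the helper signatures are assumptions for illustration, not the file's literal code:

# Hypothetical dispatch sketch, not the literal body of main();
# assumes `model` is the loaded nn.Module and `args` the parsed CLI arguments.
if args.output_dso_path:
    export_model_aoti(model, args.output_dso_path)  # AOT Inductor -> .so
elif args.output_pte_path:
    export_model_et(model, args.output_pte_path)  # ExecuTorch -> .pte (assumed name)
else:
    raise RuntimeError("specify a DSO or PTE output path to export")
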
1 change: 1 addition & 0 deletions export_et_util.py
@@ -81,6 +81,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None):

def replace_attention_with_custom_sdpa_attention(module: nn.Module):
    from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache  # noqa

    for name, child in module.named_children():
        if isinstance(child, Attention):
            setattr(module, name, CustomSDPAAttention(child))
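
The loop above swaps every Attention child for a CustomSDPAAttention wrapper; the portion elided by the diff presumably recurses into nested submodules. A generic, standalone sketch of that module-rewriting pattern (not the repo's exact code):

import torch.nn as nn


def replace_modules(module: nn.Module, target: type, make_replacement) -> nn.Module:
    # Recursively swap every descendant of type `target` using `make_replacement`.
    for name, child in module.named_children():
        if isinstance(child, target):
            setattr(module, name, make_replacement(child))
        else:
            replace_modules(child, target, make_replacement)
    return module


# e.g. replace_modules(model, Attention, CustomSDPAAttention) would mirror
# replace_attention_with_custom_sdpa_attention above.
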
1 change: 1 addition & 0 deletions quantize.py
@@ -76,6 +76,7 @@ def quantized_model(self) -> nn.Module:
class Int8DynActInt4WeightQuantizer(QuantHandler):
    def __init__(self, model: nn.Module, device="cpu", tokenizer=None, **kwargs):
        import torchao.quantization.quant_api as quant_api

        self.model_ = model
        self.device = device
        self.tokenizer = tokenizer
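
For reference, Int8DynActInt4WeightQuantizer follows the QuantHandler shape visible in the surrounding context: wrap a model, then ask for the quantized module. A hedged usage sketch, assuming the class implements the quantized_model() method shown in the hunk header above:

from quantize import Int8DynActInt4WeightQuantizer

# Assumes `model` is an nn.Module to quantize; the tokenizer argument is optional.
quantizer = Int8DynActInt4WeightQuantizer(model, device="cpu")
quantized = quantizer.quantized_model()  # nn.Module with int8 dynamic activations / int4 weights
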