Commit 23d48e1

Merge branch 'main' into angelayi/aoti_api_update
2 parents 73739bd + cbc72a4

4 files changed: +77 −62 lines

.github/workflows/runner-cuda-dtype.yml

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ jobs:
 
       python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-aoti-package-path /tmp/model.pt2
 
-      ./cmake-out/aoti_run /tmp/model.pt2 -d CUDA -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
+      ./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
 
     done

README.md

Lines changed: 1 addition & 1 deletion
@@ -341,7 +341,7 @@ torchchat/utils/scripts/build_native.sh aoti
 
 Then run the compiled executable, with the pt2.
 ```bash
-cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -i "Once upon a time"
 ```
 
 ## Mobile Execution

runner/run.cpp

Lines changed: 25 additions & 36 deletions
@@ -102,6 +102,7 @@ typedef struct {
 typedef struct {
   Config config;   // the hyperparameters of the architecture (the blueprint)
   RunState state;  // buffers for the "wave" of activations in the forward pass
+  std::unordered_map<std::string, std::string> metadata;
 
 #ifdef __AOTI_MODEL__
   torch::inductor::AOTIModelPackageLoader *runner;

@@ -141,20 +142,9 @@ void read_checkpoint(char *checkpoint, Config *config) {
   config->vocab_size = abs(config->vocab_size);
 }
 
-void build_transformer(Transformer *t, char *model_path, int vocab_size,
-                       int seq_len) {
-  // read in the Config and the Weights from the model
-  // read_checkpoint(model_path, &t->config);
-  // allocate the RunState buffers
-  t->config.vocab_size = vocab_size;
-  t->config.seq_len = seq_len;
-  malloc_run_state(&t->state, &t->config);
-
+void build_transformer(Transformer *t, char *model_path) {
 #ifdef __AOTI_MODEL__
   t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
-  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu"
-      ? torch::Device(torch::kCPU)
-      : torch::Device(torch::kCUDA);
 #else //__ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,

@@ -776,9 +766,6 @@ void error_usage() {
          " -v <int> (optional) vocab size, default is model-specific.\n");
   fprintf(stderr,
          " -l <int> (optional) llama version (2 or 3), default 2.\n");
-  fprintf(
-      stderr,
-      " -d <string> (optional) device(CUDA or CPU) model was exported for\n");
   exit(EXIT_FAILURE);
 }
 
@@ -848,37 +835,35 @@ int main(int argc, char *argv[]) {
       system_prompt = argv[i + 1];
     } else if (argv[i][1] == 'l') {
       llama_ver = atoi(argv[i + 1]);
-#ifdef __AOTI_MODEL__
-    } else if (argv[i][1] == 'd') {
-#ifdef USE_CUDA
-      if (strcasecmp(argv[i + 1], "CUDA") == 0) {
-        aoti_device = torch::Device(torch::kCUDA);
-      } else
-#endif
-      if (strcasecmp(argv[i + 1], "CPU") == 0) {
-        aoti_device = torch::Device(torch::kCPU);
-      } else {
-        fprintf(stderr, "Unknown device %s", argv[i + 1]);
-        exit(1);
-      }
-#endif
     } else {
       error_usage();
     }
   }
 
+  if (model_path == NULL) {
+    fprintf(stderr, "No model_path provided.");
+    error_usage();
+  }
+
+  Transformer transformer;
+  build_transformer(&transformer, model_path);
+
+#ifdef __AOTI_MODEL__
+  auto aoti_metadata = transformer.runner->get_metadata();
+  aoti_device = aoti_metadata["AOTI_DEVICE_KEY"] == "cpu"
+      ? torch::Device(torch::kCPU)
+      : torch::Device(torch::kCUDA);
+  ModelType model_type = get_model_type(std::stoi(aoti_metadata["tokenizer_type"]));
+#else // __ET_MODEL__
   ModelType model_type = get_model_type(llama_ver);
+#endif
+
   if (model_type == UNKNOWN_MODEL) {
     fprintf(stderr, "Unknown model type passed by -l argument. Received l=%d.",
             llama_ver);
     error_usage();
   }
 
-  if (model_path == NULL) {
-    fprintf(stderr, "No model_path provided.");
-    error_usage();
-  }
-
   if (tokenizer_path == NULL) {
     fprintf(stderr, "No tokenizer_path provided.");
     error_usage();

@@ -901,8 +886,12 @@ int main(int argc, char *argv[]) {
     vocab_size = tokenizer->vocab_size();
   }
 
-  Transformer transformer;
-  build_transformer(&transformer, model_path, vocab_size, steps);
+  // read in the Config and the Weights from the model
+  // read_checkpoint(model_path, &t->config);
+  // allocate the RunState buffers
+  transformer.config.vocab_size = vocab_size;
+  transformer.config.seq_len = steps;
+  malloc_run_state(&transformer.state, &transformer.config);
 
   Sampler sampler;
   build_sampler(&sampler, vocab_size, temperature, topp, rng_seed);

torchchat/export.py

Lines changed: 50 additions & 24 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
 import torch._inductor

@@ -39,6 +39,7 @@ def export_for_server(
     output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
     package: bool = True,
+    metadata: Optional[Dict[str, str]] = None,
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.

@@ -67,8 +68,10 @@ def export_for_server(
         dynamic_shapes = None
 
     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        metadata = {}  # TODO: put more metadata here
-        options = {"aot_inductor.metadata": metadata}
+        options = {
+            "aot_inductor.package": package,
+            "aot_inductor.metadata": metadata or {},
+        }
         if not package:
             options = {"aot_inductor.output_path": output_path}
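Note (editor-added, not part of this diff): the new options logic only carries the caller-supplied metadata when packaging is enabled; the legacy non-packaged path replaces the whole dict with `aot_inductor.output_path` and therefore drops the metadata. A minimal, self-contained sketch of that behavior, using a hypothetical helper name:

```python
from typing import Dict, Optional


def build_aoti_options(
    package: bool,
    output_path: str,
    metadata: Optional[Dict[str, str]] = None,
) -> Dict[str, object]:
    """Mirrors the option-building logic added to export_for_server above."""
    # Packaged (.pt2) exports carry the caller-supplied metadata.
    options = {
        "aot_inductor.package": package,
        "aot_inductor.metadata": metadata or {},
    }
    if not package:
        # The legacy .so path only sets an output location; metadata is dropped.
        options = {"aot_inductor.output_path": output_path}
    return options


print(build_aoti_options(True, "model.pt2", {"tokenizer_type": "3"}))
# {'aot_inductor.package': True, 'aot_inductor.metadata': {'tokenizer_type': '3'}}
print(build_aoti_options(False, "model.so"))
# {'aot_inductor.output_path': 'model.so'}
```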

@@ -106,13 +109,13 @@ def export_for_server(
     from typing import Any, Dict, Tuple, Union
 
     import executorch.exir as exir
+    from executorch.backends.xnnpack._passes.convert_to_linear import (
+        ConvertToLinearPass,
+    )
 
     from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
         XnnpackDynamicallyQuantizedPartitioner,
     )
-    from executorch.backends.xnnpack._passes.convert_to_linear import (
-        ConvertToLinearPass,
-    )
     from executorch.exir import EdgeProgramManager, to_edge
 
     from executorch.exir.capture._config import (

@@ -170,18 +173,22 @@ def __init__(self, attention: Attention):
 
         self.wo = attention.wo
 
-        max_batch_size, n_heads, max_seq_length, head_dim = (
-            attention.kv_cache[0].k_cache.shape
-        )
+        max_batch_size, n_heads, max_seq_length, head_dim = attention.kv_cache[
+            0
+        ].k_cache.shape
         cache_dtype = attention.kv_cache[0].k_cache.dtype
         # The `Attention` module being replaced can have multiple KV caches
         # (denoted by `cache_lanes`). Thus we follow the same setup format
         # as in `Attention.setup_cache`.
         cache_lanes = len(attention.kv_cache)
-        self.kv_cache = nn.ModuleList([
-            CustomKVCache(max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype)
-            for _ in range(cache_lanes)
-        ])
+        self.kv_cache = nn.ModuleList(
+            [
+                CustomKVCache(
+                    max_batch_size, max_seq_length, n_heads, head_dim, cache_dtype
+                )
+                for _ in range(cache_lanes)
+            ]
+        )
 
         self.n_heads = attention.n_heads
         self.head_dim = attention.head_dim

@@ -219,9 +226,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
-    from executorch.extension.llm.custom_ops import (  # noqa
-        sdpa_with_kv_cache,
-    )
+    from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa
 
     for name, child in module.named_children():
         if isinstance(child, Attention):

@@ -242,7 +247,9 @@ def _to_core_aten(
         raise ValueError(
             f"Expected passed in model to be an instance of fx.GraphModule, got {type(model)}"
         )
-    core_aten_ep = export_for_training(model, example_inputs, dynamic_shapes=dynamic_shapes)
+    core_aten_ep = export_for_training(
+        model, example_inputs, dynamic_shapes=dynamic_shapes
+    )
     if verbose:
         logging.info(f"Core ATen graph:\n{core_aten_ep.graph}")
     return core_aten_ep

@@ -354,7 +361,11 @@ def main(args):
 
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
+    set_backend(
+        dso=args.output_dso_path,
+        pte=args.output_pte_path,
+        aoti_package=args.output_aoti_package_path,
+    )
 
     builder_args.dso_path = None
     builder_args.pte_path = None

@@ -376,6 +387,7 @@ def main(args):
 
     # TODO: clean this up
    # This mess is because ET does not support _weight_int4pack_mm right now
+    tokenizer_args = None
     if not builder_args.gguf_path:
         # tokenizer needed for quantization so get that here,
         try:

@@ -386,9 +398,8 @@ def main(args):
 
         if builder_args.max_seq_length is None:
             if (
-                (output_dso_path is not None or output_aoti_package_path is not None)
-                and not builder_args.dynamic_shapes
-            ):
+                output_dso_path is not None or output_aoti_package_path is not None
+            ) and not builder_args.dynamic_shapes:
                 print("Setting max_seq_length to 300 for DSO export.")
                 builder_args.max_seq_length = 300
             elif output_pte_path is not None:

@@ -401,7 +412,8 @@ def main(args):
             quantize,
             tokenizer,
             max_seq_length=builder_args.max_seq_length,
-            support_tensor_subclass=output_dso_path is None and output_aoti_package_path is None,
+            support_tensor_subclass=output_dso_path is None
+            and output_aoti_package_path is None,
         )
         model_to_pte = model
         model_to_dso = model

@@ -439,7 +451,9 @@ def main(args):
     if output_dso_path:
         output_dso_path = str(os.path.abspath(output_dso_path))
         print(f"Exporting model using AOT Inductor to {output_dso_path}")
-        print("WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead.")
+        print(
+            "WARNING!! The path of compiling a dso is deprecated. Please use --output-aoti-package-path to create a .pt2 artifact instead."
+        )
         export_for_server(
             model_to_dso,
             builder_args.device,

@@ -450,11 +464,23 @@ def main(args):
 
     if output_aoti_package_path:
         output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
-        print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+
+        if tokenizer_args is None:
+            tokenizer_type = "0"
+        elif tokenizer_args.is_sentencepiece:
+            tokenizer_type = "2"  # Corresponding to llama2
+        else:
+            tokenizer_type = "3"  # Corresponding to llama3
+
+        metadata = {"tokenizer_type": tokenizer_type}
+        print(
+            "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."
+        )
         export_for_server(
             model_to_aoti_package,
             builder_args.device,
             output_aoti_package_path,
             builder_args.dynamic_shapes,
             package=True,
+            metadata=metadata,
        )
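Taken together with the runner/run.cpp changes above, the export step now embeds a `tokenizer_type` entry in the .pt2 package, and the C++ runner recovers it (along with `AOTI_DEVICE_KEY`) via `AOTIModelPackageLoader::get_metadata()`, which is why the `-d` and `-l 3` flags disappear from the workflow and the README. Below is an editor-added sketch of the underlying mechanism, not part of this commit: it assumes a recent PyTorch build where `torch._inductor.aoti_compile_and_package` accepts an exported program plus `package_path`/`inductor_configs` keyword arguments, and the toy module, path, and metadata values are placeholders.

```python
# Editor-added sketch; all names below are placeholders, not torchchat code.
import torch
import torch._inductor


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2


ep = torch.export.export(Toy(), (torch.randn(4),))
pt2_path = torch._inductor.aoti_compile_and_package(
    ep,
    package_path="/tmp/toy_model.pt2",  # placeholder output location
    inductor_configs={
        # Same config key the modified export_for_server passes through;
        # "3" corresponds to a llama3-style tokenizer in this commit's convention.
        "aot_inductor.metadata": {"tokenizer_type": "3"},
    },
)
# On the C++ side, runner/run.cpp above reads this back with
#   transformer.runner->get_metadata()["tokenizer_type"]
# so the llama version no longer needs to be passed on the command line.
```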
