
Commit 94c0079

[aoti] Add cpp loader changes

1 parent 9a53478

File tree: 6 files changed (+41 −42 lines)

README.md

Lines changed: 7 additions & 14 deletions

@@ -182,7 +182,7 @@ python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy an
 [skip default]: end

 ### Server
-This mode exposes a REST API for interacting with a model.
+This mode exposes a REST API for interacting with a model.
 The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.

 To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.
@@ -255,14 +255,14 @@ Use the "Max Response Tokens" slider to limit the maximum number of tokens gener
 ## Desktop/Server Execution

 ### AOTI (AOT Inductor)
-[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a [DSO](https://en.wikipedia.org/wiki/Shared_library) model (represented by a file with extension `.so`)
+[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a zipped PT2 file containing all the artifacts generated by AOTInductor, and a [.so](https://en.wikipedia.org/wiki/Shared_library) file with the runnable contents
 that is then loaded for inference. This can be done with both Python and C++ enviroments.

 The following example exports and executes the Llama3.1 8B Instruct
 model. The first command compiles and performs the actual export.

 ```bash
-python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts
+python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts.pt2
 ```

 > [!NOTE]
@@ -274,12 +274,11 @@ case visit our [customization guide](docs/model_customization.md).

 ### Run in a Python Enviroment

-To run in a python enviroment, use the generate subcommand like before, but include the dso file.
+To run in a python enviroment, use the generate subcommand like before, but include the pt2 file.

 ```bash
-python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
+python3 torchchat.py generate llama3.1 --aoti-package-path exportedModels/llama3_1_artifacts.pt2 --prompt "Hello my name is"
 ```
-**Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`.


 ### Run using our C++ Runner
@@ -289,17 +288,11 @@ To run in a C++ enviroment, we need to build the runner binary.
 torchchat/utils/scripts/build_native.sh aoti
 ```

-To compile the AOTI generated artifacts into a `.so`:
+Then run the compiled executable, with the pt2.
 ```bash
-make -C exportedModels/llama3_1_artifacts
+cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
 ```

-Then run the compiled executable, with the compiled DSO.
-```bash
-cmake-out/aoti_run exportedModels/llama3_1_artifacts/llama3_1_artifacts.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
-```
-**Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`.
-
 ## Mobile Execution

 [ExecuTorch](https://github.com/pytorch/executorch) enables you to optimize your model for execution on a
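Under the hood, `generate --aoti-package-path` loads the zipped PT2 with `torch._inductor.package.load_package` (see torchchat/cli/builder.py below). A minimal sketch of calling the package directly follows, reusing the README's export path; the forward signature and token shapes are illustrative assumptions, not torchchat's actual interface.

```python
import torch
from torch._inductor.package import load_package

# Load the zipped PT2 package produced by `torchchat.py export`.
compiled = load_package("exportedModels/llama3_1_artifacts.pt2")

# Assumed forward signature (tokens, input_pos) for illustration only;
# torchchat installs the compiled callable as model.forward internally.
tokens = torch.randint(0, 128, (1, 8))
input_pos = torch.arange(8)
logits = compiled(tokens, input_pos)
```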

install/install_requirements.sh

Lines changed: 3 additions & 3 deletions

@@ -47,10 +47,10 @@ fi
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-PYTORCH_NIGHTLY_VERSION=dev20240814
+PYTORCH_NIGHTLY_VERSION=dev20240913

 # Nightly version for torchvision
-VISION_NIGHTLY_VERSION=dev20240814
+VISION_NIGHTLY_VERSION=dev20240913

 # Nightly version for torchao
 AO_NIGHTLY_VERSION=dev20240905
@@ -77,7 +77,7 @@ fi

 # pip packages needed by exir.
 REQUIREMENTS_TO_INSTALL=(
-  torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
+  torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
   torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
 )

runner/run.cpp

Lines changed: 4 additions & 15 deletions

@@ -31,10 +31,7 @@ LICENSE file in the root directory of this source tree.
 #endif

 #ifdef __AOTI_MODEL__
-#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
-#ifdef USE_CUDA
-#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
-#endif
+#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
 torch::Device aoti_device(torch::kCPU);

 #else // __ET_MODEL__
@@ -93,7 +90,7 @@ typedef struct {
   RunState state; // buffers for the "wave" of activations in the forward pass

 #ifdef __AOTI_MODEL__
-  torch::inductor::AOTIModelContainerRunner* runner;
+  torch::inductor::AOTIModelPackageLoader* runner;
 #else // __ET_MODEL__
   Module* runner;
 #endif
@@ -143,16 +140,8 @@ void build_transformer(
   malloc_run_state(&t->state, &t->config);

 #ifdef __AOTI_MODEL__
-#ifdef USE_CUDA
-  if (aoti_device.type() == torch::kCUDA) {
-    t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
-    aoti_device = torch::Device(torch::kCUDA);
-  } else {
-#else
-  {
-#endif
-    t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
-  }
+  t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
+  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
 #else //__ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,

torchchat/cli/builder.py

Lines changed: 5 additions & 2 deletions

@@ -536,9 +536,12 @@ def _initialize_model(
             # function, e.g. calling model.setup_cache will NOT touch
             # AOTI compiled and maintained model buffers such as kv_cache.
             from torch._inductor.package import load_package
-            model.forward = load_package(
-                str(builder_args.aoti_package_path.absolute()), builder_args.device
+            aoti_compiled_model = load_package(
+                str(builder_args.aoti_package_path.absolute())
             )
+            model.forward = aoti_compiled_model
+            metadata = aoti_compiled_model.get_metadata()
+            builder_args.device = metadata["AOTI_DEVICE_KEY"]
         except:
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")
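A condensed sketch of the loading flow this hunk implements: `load_package` no longer receives a device, because the package's metadata records the device it was compiled for (the same `AOTI_DEVICE_KEY` the C++ runner reads). The package path here is an assumption reused from the README example.

```python
from pathlib import Path

from torch._inductor.package import load_package

# Assumed package path, mirroring the README example.
aoti_package_path = Path("exportedModels/llama3_1_artifacts.pt2")

# Load the compiled model; note that no device argument is passed.
aoti_compiled_model = load_package(str(aoti_package_path.absolute()))

# The device comes from the package itself, e.g. "cpu" or "cuda".
device = aoti_compiled_model.get_metadata()["AOTI_DEVICE_KEY"]
print(f"Package was compiled for: {device}")
```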

torchchat/export.py

Lines changed: 19 additions & 7 deletions

@@ -11,6 +11,7 @@
 import torch.nn as nn

 from torch.export import Dim
+import torch._inductor

 from torchchat.cli.builder import (
     _initialize_model,
@@ -67,16 +68,27 @@ def export_for_server(
         dynamic_shapes = None

     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        path = torch._export.aot_compile(
+        metadata = {"model_key": model_key}
+        options = {"aot_inductor.package": package, "aot_inductor.metadata": metadata}
+        if not package:
+            options = {"aot_inductor.output_path": output_path}
+
+        ep = torch.export.export(
             model,
-            args=input,
-            options={
-                "aot_inductor.output_path": output_path,
-                "aot_inductor.package": package,
-            },
+            input,
             dynamic_shapes=dynamic_shapes,
         )
-    print(f"The generated DSO model can be found at: {path}")
+        path = torch._inductor.aot_compile(
+            ep.module(),
+            input,
+            options=options,
+        )
+
+        if package:
+            from torch._inductor.package import package_aoti
+            path = package_aoti(output_path, path)
+
+    print(f"The generated packaged model can be found at: {path}")
     return path
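For context, a condensed sketch of the two-step export pipeline this hunk switches to, using a toy module; the module, inputs, and output filename are assumptions, while torchchat passes the real model and the options shown above.

```python
import torch
import torch._inductor
from torch._inductor.package import package_aoti

class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2

example_inputs = (torch.randn(4),)

# Step 1: capture an ExportedProgram (replaces torch._export.aot_compile).
ep = torch.export.export(Toy(), example_inputs)

# Step 2: AOT-compile the exported module. With "aot_inductor.package" set,
# the result is a set of AOTInductor artifacts rather than a bare .so path.
artifacts = torch._inductor.aot_compile(
    ep.module(),
    example_inputs,
    options={
        "aot_inductor.package": True,
        "aot_inductor.metadata": {"model_key": "toy"},
    },
)

# Step 3: zip the artifacts into a single .pt2 package for the loaders above.
path = package_aoti("toy_artifacts.pt2", artifacts)
print(f"The generated packaged model can be found at: {path}")
```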

torchchat/generate.py

Lines changed: 3 additions & 1 deletion

@@ -135,6 +135,8 @@ def validate_build(
             reason = "model compilation"
         if builder_args.aoti_package_path:
             model_type = "PT2"
+        if builder_args.dso_path:
+            model_type = "DSO"
         if builder_args.pte_path:
             model_type = "PTE"
         if model_type and reason:
@@ -148,7 +150,7 @@ def from_args(cls, args):
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)
         sequential_prefill = (
-            args.sequential_prefill or bool(aoti_package_path) or bool(pte_path)
+            args.sequential_prefill or bool(aoti_package_path) or bool(pte_path) or bool(dso_path)
         )

         return cls(
