Commit 937ab9a: Update AOTI package
1 parent f730056
9 files changed: +147 −62 lines

README.md

Lines changed: 10 additions & 11 deletions

@@ -182,7 +182,7 @@ python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy an
 [skip default]: end

 ### Server
-This mode exposes a REST API for interacting with a model.
+This mode exposes a REST API for interacting with a model.
 The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.

 To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.

@@ -255,13 +255,14 @@ Use the "Max Response Tokens" slider to limit the maximum number of tokens gener
 ## Desktop/Server Execution

 ### AOTI (AOT Inductor)
-[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a [DSO](https://en.wikipedia.org/wiki/Shared_library) model (represented by a file with extension `.so`)
+[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a zipped PT2 file containing all the artifacts generated by AOTInductor, and a [.so](https://en.wikipedia.org/wiki/Shared_library) file with the runnable contents
 that is then loaded for inference. This can be done with both Python and C++ environments.

 The following example exports and executes the Llama3.1 8B Instruct
 model. The first command compiles and performs the actual export.
-```
-python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.so
+
+```bash
+python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts.pt2
 ```

 > [!NOTE]

@@ -273,12 +274,11 @@ case visit our [customization guide](docs/model_customization.md).

 ### Run in a Python Environment

-To run in a Python environment, use the generate subcommand like before, but include the DSO file.
+To run in a Python environment, use the generate subcommand like before, but include the PT2 file.

+```bash
+python3 torchchat.py generate llama3.1 --aoti-package-path exportedModels/llama3_1_artifacts.pt2 --prompt "Hello my name is"
 ```
-python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
-```
-**Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`.


 ### Run using our C++ Runner

@@ -288,11 +288,10 @@ To run in a C++ environment, we need to build the runner binary.
 torchchat/utils/scripts/build_native.sh aoti
 ```

-Then run the compiled executable, with the exported DSO from earlier.
+Then run the compiled executable, with the exported PT2 from earlier.
 ```bash
-cmake-out/aoti_run exportedModels/llama3.1.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
 ```
-**Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`.

 ## Mobile Execution
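The Server hunk above points at the OpenAI chat completions spec, so a minimal client sketch for the second terminal may help. The host and port (localhost:5000) and the model name are assumptions, not something this commit pins down; the request and response shapes follow the linked spec.

```python
# Hypothetical client for the chat completions endpoint referenced in the
# README hunk above. localhost:5000 and the model name are assumptions;
# the payload/response shape follows the linked OpenAI API spec.
import json
import urllib.request

payload = {
    "model": "llama3.1",
    "messages": [{"role": "user", "content": "Hello my name is"}],
}
req = urllib.request.Request(
    "http://localhost:5000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.loads(resp.read())
print(body["choices"][0]["message"]["content"])
```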

install/install_requirements.sh

Lines changed: 3 additions & 3 deletions

@@ -47,10 +47,10 @@ fi
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-PYTORCH_NIGHTLY_VERSION=dev20240814
+PYTORCH_NIGHTLY_VERSION=dev20240913

 # Nightly version for torchvision
-VISION_NIGHTLY_VERSION=dev20240814
+VISION_NIGHTLY_VERSION=dev20240913

 # Nightly version for torchtune
 TUNE_NIGHTLY_VERSION=dev20240916

@@ -74,7 +74,7 @@ fi

 # pip packages needed by exir.
 REQUIREMENTS_TO_INSTALL=(
-  torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
+  torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
   torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
   torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
 )
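After re-running the script, one quick way to confirm the bumped pin took effect is to check the installed version. A sketch; real nightly wheel versions may carry a platform suffix such as "+cpu" or "+cu121", hence the prefix match.

```python
# Illustrative sanity check for the new pin (torch=="2.6.0.dev20240913").
# Wheel version strings may append a platform tag, so use startswith().
import torch

print(torch.__version__)
assert torch.__version__.startswith("2.6.0.dev20240913"), torch.__version__
```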

runner/run.cpp

Lines changed: 4 additions & 15 deletions

@@ -31,10 +31,7 @@ LICENSE file in the root directory of this source tree.
 #endif

 #ifdef __AOTI_MODEL__
-#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
-#ifdef USE_CUDA
-#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
-#endif
+#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
 torch::Device aoti_device(torch::kCPU);

 #else // __ET_MODEL__

@@ -93,7 +90,7 @@ typedef struct {
   RunState state; // buffers for the "wave" of activations in the forward pass

 #ifdef __AOTI_MODEL__
-  torch::inductor::AOTIModelContainerRunner* runner;
+  torch::inductor::AOTIModelPackageLoader* runner;
 #else // __ET_MODEL__
   Module* runner;
 #endif

@@ -143,16 +140,8 @@ void build_transformer(
   malloc_run_state(&t->state, &t->config);

 #ifdef __AOTI_MODEL__
-#ifdef USE_CUDA
-  if (aoti_device.type() == torch::kCUDA) {
-    t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
-    aoti_device = torch::Device(torch::kCUDA);
-  } else {
-#else
-  {
-#endif
-    t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
-  }
+  t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
+  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
 #else //__ET_MODEL__
   t->runner = new Module(
       /* path to PTE model */ model_path,
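The rewrite above drops the compile-time CPU/CUDA branching: the PT2 package records which device it was compiled for, and the runner reads it back via `get_metadata()["AOTI_DEVICE_KEY"]`. The same loader surface exists in Python, and the builder.py hunk below calls it verbatim; a minimal sketch, assuming the package exported in the README example exists on disk:

```python
# Minimal Python mirror of the C++ device pick above, using the package
# loader that torchchat/cli/builder.py (next hunk) calls verbatim.
# Assumes the PT2 from the README example has already been exported.
from torch._inductor.package import load_package

runner = load_package("exportedModels/llama3_1_artifacts.pt2")
device = "cpu" if runner.get_metadata()["AOTI_DEVICE_KEY"] == "cpu" else "cuda"
print(f"package was compiled for: {device}")
```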

torchchat/cli/builder.py

Lines changed: 53 additions & 12 deletions

@@ -57,6 +57,7 @@ class BuilderArgs:
     gguf_path: Optional[Union[Path, str]] = None
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
+    aoti_package_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32

@@ -76,28 +77,29 @@ def __post_init__(self):
             or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
             or (self.gguf_path and self.gguf_path.is_file())
             or (self.dso_path and Path(self.dso_path).is_file())
+            or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
         ):
             raise RuntimeError(
                 "need to specify a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
             )

-        if self.dso_path and self.pte_path:
-            raise RuntimeError("specify either DSO path or PTE path, but not both")
+        if self.aoti_package_path and self.pte_path:
+            raise RuntimeError("specify either AOTI package path or PTE path, but not both")

-        if self.checkpoint_path and (self.dso_path or self.pte_path):
+        if self.checkpoint_path and (self.aoti_package_path or self.pte_path):
             print(
-                "Warning: checkpoint path ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint path ignored because an exported AOTI or PTE path specified"
             )
-        if self.checkpoint_dir and (self.dso_path or self.pte_path):
+        if self.checkpoint_dir and (self.aoti_package_path or self.pte_path):
             print(
-                "Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint dir ignored because an exported AOTI or PTE path specified"
             )
-        if self.gguf_path and (self.dso_path or self.pte_path):
+        if self.gguf_path and (self.aoti_package_path or self.pte_path):
             print(
-                "Warning: GGUF path ignored because an exported DSO or PTE path specified"
+                "Warning: GGUF path ignored because an exported AOTI or PTE path specified"
             )
-        if not (self.dso_path) and not (self.pte_path):
+        if not (self.aoti_package_path) and not (self.pte_path):
             self.prefill_possible = True

@@ -127,6 +129,7 @@ def from_args(cls, args):  # -> BuilderArgs:

         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
+        aoti_package_path = getattr(args, "aoti_package_path", None)

         is_chat_model = False
         if args.is_chat_model:

@@ -137,6 +140,7 @@ def from_args(cls, args):  # -> BuilderArgs:
             checkpoint_dir,
             dso_path,
             pte_path,
+            aoti_package_path,
             args.gguf_path,
         ]:
             if path is not None:

@@ -151,6 +155,7 @@ def from_args(cls, args):  # -> BuilderArgs:
             is_chat_model = True

         output_pte_path = getattr(args, "output_pte_path", None)
+        output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":

@@ -172,10 +177,11 @@ def from_args(cls, args):  # -> BuilderArgs:
             gguf_path=args.gguf_path,
             gguf_kwargs=None,
             dso_path=dso_path,
+            aoti_package_path=aoti_package_path,
             pte_path=pte_path,
             device=args.device,
             precision=dtype,
-            setup_caches=(output_dso_path or output_pte_path),
+            setup_caches=(output_dso_path or output_pte_path or output_aoti_package_path),
             use_distributed=args.distributed,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),

@@ -190,6 +196,7 @@ def from_speculative_args(cls, args):  # -> BuilderArgs:
         speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
         speculative_builder_args.gguf_path = None
         speculative_builder_args.dso_path = None
+        speculative_builder_args.aoti_package_path = None
         speculative_builder_args.pte_path = None
         return speculative_builder_args

@@ -482,11 +489,12 @@ def _initialize_model(
 ):
     print("Loading model...")

-    if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
+    if builder_args.gguf_path and (builder_args.dso_path or builder_args.aoti_package_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")
         is_dso = builder_args.dso_path is not None
+        is_aoti_package = builder_args.aoti_package_path is not None
         is_pte = builder_args.pte_path is not None
-        assert not (is_dso and is_pte)
+        assert not (is_dso and is_aoti_package and is_pte)
         assert builder_args.gguf_kwargs is None
         # TODO: make GGUF load independent of backend
         # currently not working because AVX int_mm broken

@@ -520,6 +528,39 @@ def _initialize_model(
             )
         except:
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
+
+    elif builder_args.aoti_package_path:
+        if not is_cuda_or_cpu_device(builder_args.device):
+            print(
+                f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
+            )
+            builder_args.device = "cpu"
+
+        # assert (
+        #     quantize is None or quantize == "{ }"
+        # ), "quantize not valid for exported PT2 model. Specify quantization during export."
+
+        with measure_time("Time to load model: {time:.02f} seconds"):
+            model = _load_model(builder_args, only_config=True)
+            device_sync(device=builder_args.device)
+
+        try:
+            # Replace model forward with the AOT-compiled forward.
+            # This is a hacky way to quickly demo AOTI's capability.
+            # model is still a Python object, and any mutation to its
+            # attributes will NOT be seen by the AOTI-compiled forward
+            # function, e.g. calling model.setup_cache will NOT touch
+            # AOTI compiled and maintained model buffers such as kv_cache.
+            from torch._inductor.package import load_package
+            aoti_compiled_model = load_package(
+                str(builder_args.aoti_package_path.absolute())
+            )
+            model.forward = aoti_compiled_model
+            metadata = aoti_compiled_model.get_metadata()
+            builder_args.device = metadata["AOTI_DEVICE_KEY"]
+        except:
+            raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")
+
     elif builder_args.pte_path:
         if not is_cpu_device(builder_args.device):
             print(

torchchat/cli/cli.py

Lines changed: 12 additions & 0 deletions

@@ -191,6 +191,12 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
+    exclusive_parser.add_argument(
+        "--output-aoti-package-path",
+        type=str,
+        default=None,
+        help="Output directory for AOTInductor compiled artifacts",
+    )


 def _add_export_args(parser) -> None:

@@ -220,6 +226,12 @@ def _add_exported_input_path_args(parser) -> None:
         default=None,
         help="Use the specified AOT Inductor .dso model file",
     )
+    exclusive_parser.add_argument(
+        "--aoti-package-path",
+        type=Path,
+        default=None,
+        help="Use the specified directory containing AOT Inductor compiled files",
+    )
     exclusive_parser.add_argument(
         "--pte-path",
         type=Path,
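Judging by its name, `exclusive_parser` in these hunks is an argparse mutually exclusive group, so `--dso-path`, `--aoti-package-path`, and `--pte-path` cannot be combined in one invocation. A self-contained sketch of that pattern follows; the flag wiring is illustrative, not the actual cli.py.

```python
# Standalone sketch of the mutually-exclusive exported-input flags above;
# argparse rejects any command line that passes more than one of them.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
exclusive_parser = parser.add_mutually_exclusive_group()
exclusive_parser.add_argument("--dso-path", type=Path, default=None)
exclusive_parser.add_argument("--aoti-package-path", type=Path, default=None)
exclusive_parser.add_argument("--pte-path", type=Path, default=None)

args = parser.parse_args(["--aoti-package-path", "exportedModels/llama3_1_artifacts.pt2"])
print(args.aoti_package_path)
# Adding e.g. "--pte-path model.pte" to the argv above would exit with
# "argument --pte-path: not allowed with argument --aoti-package-path".
```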
