
Commit 6ec5b35

[aoti] Add cpp packaging for aoti + loading in python

1 parent 5d5664c

File tree

8 files changed: +130 -39 lines changed

.gitignore

Lines changed: 4 additions & 0 deletions

```diff
@@ -6,6 +6,7 @@ __pycache__/
 # C extensions
 *.so
 
+.vscode
 .model-artifacts/
 .venv
 .torchchat
@@ -21,3 +22,6 @@ runner-aoti/cmake-out/*
 
 # debug / logging files
 system_info.txt
+checkpoints/
+exportedModels/
+cmake-out/
```

README.md

Lines changed: 11 additions & 5 deletions

````diff
@@ -264,8 +264,9 @@ that is then loaded for inference. This can be done with both Python and C++ env
 
 The following example exports and executes the Llama3.1 8B Instruct
 model. The first command compiles and performs the actual export.
-```
-python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.so
+
+```bash
+python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts
 ```
 
 > [!NOTE]
@@ -279,7 +280,7 @@ case visit our [customization guide](docs/model_customization.md).
 
 To run in a python enviroment, use the generate subcommand like before, but include the dso file.
 
-```
+```bash
 python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
 ```
 **Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`.
@@ -292,9 +293,14 @@ To run in a C++ enviroment, we need to build the runner binary.
 scripts/build_native.sh aoti
 ```
 
-Then run the compiled executable, with the exported DSO from earlier.
+To compile the AOTI generated artifacts into a `.so`:
+```bash
+make -C exportedModels/llama3_1_artifacts
+```
+
+Then run the compiled executable, with the compiled DSO.
 ```bash
-cmake-out/aoti_run exportedModels/llama3.1.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts/llama3_1_artifacts.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
 ```
 **Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`.
 
````
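
The new `--aoti-package-path` loading path added in `build/builder.py` (next file) boils down to roughly the sketch below; the artifact path and device strings are illustrative placeholders, not part of the commit.

```python
# Minimal sketch of the Python-side loading that the new flow relies on,
# mirroring the code added to _initialize_model() in build/builder.py.
# Path and device are placeholders.
from torch._inductor.package import load_package

aoti_package_path = "exportedModels/llama3_1_artifacts"    # produced by `torchchat.py export`
compiled_forward = load_package(aoti_package_path, "cpu")  # or "cuda", matching the export device

# torchchat then swaps this callable in for the eager forward:
# model.forward = compiled_forward
```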

build/builder.py

Lines changed: 50 additions & 12 deletions

```diff
@@ -41,6 +41,7 @@ class BuilderArgs:
     gguf_path: Optional[Union[Path, str]] = None
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
+    aoti_package_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
@@ -60,28 +61,29 @@ def __post_init__(self):
             or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
             or (self.gguf_path and self.gguf_path.is_file())
             or (self.dso_path and Path(self.dso_path).is_file())
+            or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
         ):
             raise RuntimeError(
                 "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
             )
 
-        if self.dso_path and self.pte_path:
-            raise RuntimeError("specify either DSO path or PTE path, but not both")
+        if self.pte_path and self.aoti_package_path:
+            raise RuntimeError("specify either AOTI Package path or PTE path, but not more than one")
 
-        if self.checkpoint_path and (self.dso_path or self.pte_path):
+        if self.checkpoint_path and (self.pte_path or self.aoti_package_path):
             print(
-                "Warning: checkpoint path ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint path ignored because an exported AOTI or PTE path specified"
             )
-        if self.checkpoint_dir and (self.dso_path or self.pte_path):
+        if self.checkpoint_dir and (self.pte_path or self.aoti_package_path):
             print(
-                "Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint dir ignored because an exported AOTI or PTE path specified"
             )
-        if self.gguf_path and (self.dso_path or self.pte_path):
+        if self.gguf_path and (self.pte_path or self.aoti_package_path):
             print(
-                "Warning: GGUF path ignored because an exported DSO or PTE path specified"
+                "Warning: GGUF path ignored because an exported AOTI or PTE path specified"
             )
-        if not (self.dso_path) and not (self.pte_path):
+        if not (self.dso_path) and not (self.aoti_package_path):
             self.prefill_possible = True
 
     @classmethod
@@ -111,6 +113,7 @@ def from_args(cls, args): # -> BuilderArgs:
 
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
+        aoti_package_path = getattr(args, "aoti_package_path", None)
 
         is_chat_model = False
         if args.is_chat_model:
@@ -121,6 +124,7 @@ def from_args(cls, args): # -> BuilderArgs:
             checkpoint_dir,
             dso_path,
             pte_path,
+            aoti_package_path,
             args.gguf_path,
         ]:
             if path is not None:
@@ -136,6 +140,7 @@ def from_args(cls, args): # -> BuilderArgs:
 
 
         output_pte_path = getattr(args, "output_pte_path", None)
+        output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":
@@ -157,10 +162,11 @@ def from_args(cls, args): # -> BuilderArgs:
             gguf_path=args.gguf_path,
             gguf_kwargs=None,
             dso_path=dso_path,
+            aoti_package_path=aoti_package_path,
             pte_path=pte_path,
             device=args.device,
             precision=dtype,
-            setup_caches=(output_dso_path or output_pte_path),
+            setup_caches=(output_dso_path or output_pte_path or output_aoti_package_path),
             use_distributed=args.distributed,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
@@ -175,6 +181,7 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
         speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
         speculative_builder_args.gguf_path = None
         speculative_builder_args.dso_path = None
+        speculative_builder_args.aoti_package_path = None
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
@@ -450,11 +457,12 @@ def _initialize_model(
 ):
     print("Loading model...")
 
-    if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
+    if builder_args.gguf_path and (builder_args.dso_path or builder_args.aoti_package_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")
         is_dso = builder_args.dso_path is not None
+        is_aoti_package = builder_args.aoti_package_path is not None
         is_pte = builder_args.pte_path is not None
-        assert not (is_dso and is_pte)
+        assert not (is_dso and is_aoti_package and is_pte)
         assert builder_args.gguf_kwargs is None
         # TODO: make GGUF load independent of backend
         # currently not working because AVX int_mm broken
@@ -488,6 +496,36 @@ def _initialize_model(
             )
         except:
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
+
+    elif builder_args.aoti_package_path:
+        if not is_cuda_or_cpu_device(builder_args.device):
+            print(
+                f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
+            )
+            builder_args.device = "cpu"
+
+        # assert (
+        #     quantize is None or quantize == "{ }"
+        # ), "quantize not valid for exported PT2 model. Specify quantization during export."
+
+        with measure_time("Time to load model: {time:.02f} seconds"):
+            model = _load_model(builder_args, only_config=True)
+            device_sync(device=builder_args.device)
+
+        try:
+            # Replace model forward with the AOT-compiled forward
+            # This is a hacky way to quickly demo AOTI's capability.
+            # model is still a Python object, and any mutation to its
+            # attributes will NOT be seen on by AOTI-compiled forward
+            # function, e.g. calling model.setup_cache will NOT touch
+            # AOTI compiled and maintained model buffers such as kv_cache.
+            from torch._inductor.package import load_package
+            model.forward = load_package(
+                str(builder_args.aoti_package_path.absolute()), builder_args.device
+            )
+        except:
+            raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")
+
     elif builder_args.pte_path:
         if not is_cpu_device(builder_args.device):
             print(
```
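
For orientation only, with this change `_initialize_model` distinguishes three exported-artifact branches plus eager, in the order sketched below; attribute names come from the diff, everything else is elided.

```python
# Orientation-only sketch of the branch order _initialize_model() follows
# after this change; attribute names match the diff, bodies are elided.
def _select_artifact(builder_args) -> str:
    if builder_args.dso_path:
        return "dso"           # pre-existing .so export path
    elif builder_args.aoti_package_path:
        return "aoti_package"  # new: loaded via torch._inductor.package.load_package
    elif builder_args.pte_path:
        return "pte"           # ExecuTorch program
    return "eager"             # no exported artifact supplied
```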

build/utils.py

Lines changed: 13 additions & 9 deletions

```diff
@@ -69,42 +69,46 @@ def unpack_packed_weights(
 
 active_builder_args_dso = None
 active_builder_args_pte = None
+active_builder_args_aoti_package = None
 
 
-def set_backend(dso, pte):
+def set_backend(dso, pte, aoti_package):
     global active_builder_args_dso
     global active_builder_args_pte
     active_builder_args_dso = dso
+    active_builder_args_aoti_package = aoti_package
     active_builder_args_pte = pte
 
 
 def use_aoti_backend() -> bool:
     global active_builder_args_dso
+    global active_builder_args_aoti_package
     global active_builder_args_pte
 
     # eager == aoti, which is when backend has not been explicitly set
-    if (not active_builder_args_dso) and not (active_builder_args_pte):
+    if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
         return True
 
-    if active_builder_args_pte and active_builder_args_dso:
+    if active_builder_args_pte and active_builder_args_aoti_package:
         raise RuntimeError(
-            "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
+            "code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
         )
 
-    return bool(active_builder_args_dso)
+    return bool(active_builder_args_dso) or bool(active_builder_args_aoti_package)
 
 
 def use_et_backend() -> bool:
     global active_builder_args_dso
+    global active_builder_args_aoti_package
     global active_builder_args_pte
 
     # eager == aoti, which is when backend has not been explicitly set
-    if not (active_builder_args_pte or active_builder_args_dso):
-        return False
+    if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
+        return True
 
-    if active_builder_args_pte and active_builder_args_dso:
+    if active_builder_args_pte and active_builder_args_aoti_package:
         raise RuntimeError(
-            "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
+            "code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
         )
 
     return bool(active_builder_args_pte)
```
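
A small usage sketch of the updated helpers, with placeholder values and assuming `build.utils` is importable from the repo root: `set_backend` records which exported artifact is active, and downstream code branches on `use_aoti_backend()` / `use_et_backend()`.

```python
# Usage sketch with placeholder values; set_backend() records which export
# artifact is active and the helpers consult those module-level globals.
from build.utils import set_backend, use_aoti_backend, use_et_backend

set_backend(dso=None, pte=None, aoti_package="exportedModels/llama3_1_artifacts")

if use_aoti_backend():
    print("generating against AOT Inductor artifacts")
elif use_et_backend():
    print("generating against an ExecuTorch .pte program")
```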

cli.py

Lines changed: 12 additions & 0 deletions

```diff
@@ -186,6 +186,12 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
+    output_path_parser.add_argument(
+        "--output-aoti-package-path",
+        type=str,
+        default=None,
+        help="Output directory for AOTInductor compiled artifacts",
+    )
 
 
 def _add_export_args(parser) -> None:
@@ -215,6 +221,12 @@ def _add_exported_input_path_args(parser) -> None:
         default=None,
         help="Use the specified AOT Inductor .dso model file",
     )
+    exclusive_parser.add_argument(
+        "--aoti-package-path",
+        type=Path,
+        default=None,
+        help="Use the specified directory containing AOT Inductor compiled files",
+    )
     exclusive_parser.add_argument(
         "--pte-path",
         type=Path,
```
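
In isolation, the two new flags behave like the standalone argparse sketch below; types, defaults, and help strings match the diff, while the real CLI attaches them to the export and generate argument groups.

```python
# Standalone sketch of the two flags added above; the surrounding parser
# setup is simplified relative to cli.py.
import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument(
    "--output-aoti-package-path",
    type=str,
    default=None,
    help="Output directory for AOTInductor compiled artifacts",
)
parser.add_argument(
    "--aoti-package-path",
    type=Path,
    default=None,
    help="Use the specified directory containing AOT Inductor compiled files",
)

args = parser.parse_args(["--output-aoti-package-path", "exportedModels/llama3_1_artifacts"])
print(args.output_aoti_package_path)
```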

eval.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -259,7 +259,7 @@ def main(args) -> None:
 
     if compile:
         assert not (
-            builder_args.dso_path or builder_args.pte_path
+            builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path
        ), "cannot compile exported model"
         model_forward = torch.compile(
             model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True
@@ -287,6 +287,8 @@ def main(args) -> None:
         )
     if builder_args.dso_path:
         print(f"For model {builder_args.dso_path}")
+    if builder_args.aoti_package_path:
+        print(f"For model {builder_args.aoti_package_path}")
     elif builder_args.pte_path:
         print(f"For model {builder_args.pte_path}")
     elif builder_args.checkpoint_path:
```
