
Commit dd34475: [aoti] Add cpp packaging for aoti + loading in python

Parent: 6fae164

7 files changed: 136 additions, 38 deletions

README.md

Lines changed: 11 additions & 5 deletions
@@ -260,8 +260,9 @@ that is then loaded for inference. This can be done with both Python and C++ env
 
 The following example exports and executes the Llama3.1 8B Instruct
 model. The first command compiles and performs the actual export.
-```
-python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.so
+
+```bash
+python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts
 ```
 
 > [!NOTE]

@@ -275,7 +276,7 @@ case visit our [customization guide](docs/model_customization.md).
 
 To run in a python enviroment, use the generate subcommand like before, but include the dso file.
 
-```
+```bash
 python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
 ```
 **Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`.

@@ -288,9 +289,14 @@ To run in a C++ enviroment, we need to build the runner binary.
 torchchat/utils/scripts/build_native.sh aoti
 ```
 
-Then run the compiled executable, with the exported DSO from earlier.
+To compile the AOTI generated artifacts into a `.so`:
+```bash
+make -C exportedModels/llama3_1_artifacts
+```
+
+Then run the compiled executable, with the compiled DSO.
 ```bash
-cmake-out/aoti_run exportedModels/llama3.1.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts/llama3_1_artifacts.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
 ```
 **Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`.
 
torchchat/cli/builder.py

Lines changed: 50 additions & 12 deletions
@@ -50,6 +50,7 @@ class BuilderArgs:
     gguf_path: Optional[Union[Path, str]] = None
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
+    aoti_package_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32

@@ -69,28 +70,29 @@ def __post_init__(self):
             or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
             or (self.gguf_path and self.gguf_path.is_file())
             or (self.dso_path and Path(self.dso_path).is_file())
+            or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
         ):
             raise RuntimeError(
                 "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
             )
 
-        if self.dso_path and self.pte_path:
-            raise RuntimeError("specify either DSO path or PTE path, but not both")
+        if self.pte_path and self.aoti_package_path:
+            raise RuntimeError("specify either AOTI Package path or PTE path, but not more than one")
 
-        if self.checkpoint_path and (self.dso_path or self.pte_path):
+        if self.checkpoint_path and (self.pte_path or self.aoti_package_path):
             print(
-                "Warning: checkpoint path ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint path ignored because an exported AOTI or PTE path specified"
             )
-        if self.checkpoint_dir and (self.dso_path or self.pte_path):
+        if self.checkpoint_dir and (self.pte_path or self.aoti_package_path):
             print(
-                "Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint dir ignored because an exported AOTI or PTE path specified"
             )
-        if self.gguf_path and (self.dso_path or self.pte_path):
+        if self.gguf_path and (self.pte_path or self.aoti_package_path):
             print(
-                "Warning: GGUF path ignored because an exported DSO or PTE path specified"
+                "Warning: GGUF path ignored because an exported AOTI or PTE path specified"
             )
-        if not (self.dso_path) and not (self.pte_path):
+        if not (self.dso_path) and not (self.aoti_package_path):
             self.prefill_possible = True
 
     @classmethod

@@ -120,6 +122,7 @@ def from_args(cls, args): # -> BuilderArgs:
 
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
+        aoti_package_path = getattr(args, "aoti_package_path", None)
 
         is_chat_model = False
         if args.is_chat_model:

@@ -130,6 +133,7 @@ def from_args(cls, args): # -> BuilderArgs:
             checkpoint_dir,
             dso_path,
             pte_path,
+            aoti_package_path,
             args.gguf_path,
         ]:
             if path is not None:

@@ -145,6 +149,7 @@ def from_args(cls, args): # -> BuilderArgs:
 
 
         output_pte_path = getattr(args, "output_pte_path", None)
+        output_aoti_package_path = getattr(args, "output_aoti_package_path", None)
         output_dso_path = getattr(args, "output_dso_path", None)
         if output_pte_path and args.dtype.startswith("fast"):
             if args.dtype == "fast":

@@ -166,10 +171,11 @@ def from_args(cls, args): # -> BuilderArgs:
             gguf_path=args.gguf_path,
             gguf_kwargs=None,
             dso_path=dso_path,
+            aoti_package_path=aoti_package_path,
             pte_path=pte_path,
             device=args.device,
             precision=dtype,
-            setup_caches=(output_dso_path or output_pte_path),
+            setup_caches=(output_dso_path or output_pte_path or output_aoti_package_path),
             use_distributed=args.distributed,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),

@@ -184,6 +190,7 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
         speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
         speculative_builder_args.gguf_path = None
         speculative_builder_args.dso_path = None
+        speculative_builder_args.aoti_package_path = None
         speculative_builder_args.pte_path = None
         return speculative_builder_args

@@ -463,11 +470,12 @@ def _initialize_model(
 ):
     print("Loading model...")
 
-    if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
+    if builder_args.gguf_path and (builder_args.dso_path or builder_args.aoti_package_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")
         is_dso = builder_args.dso_path is not None
+        is_aoti_package = builder_args.aoti_package_path is not None
         is_pte = builder_args.pte_path is not None
-        assert not (is_dso and is_pte)
+        assert not (is_dso and is_aoti_package and is_pte)
         assert builder_args.gguf_kwargs is None
         # TODO: make GGUF load independent of backend
         # currently not working because AVX int_mm broken

@@ -501,6 +509,36 @@ def _initialize_model(
             )
         except:
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
+
+    elif builder_args.aoti_package_path:
+        if not is_cuda_or_cpu_device(builder_args.device):
+            print(
+                f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
+            )
+            builder_args.device = "cpu"
+
+        # assert (
+        #     quantize is None or quantize == "{ }"
+        # ), "quantize not valid for exported PT2 model. Specify quantization during export."
+
+        with measure_time("Time to load model: {time:.02f} seconds"):
+            model = _load_model(builder_args, only_config=True)
+            device_sync(device=builder_args.device)
+
+        try:
+            # Replace model forward with the AOT-compiled forward
+            # This is a hacky way to quickly demo AOTI's capability.
+            # model is still a Python object, and any mutation to its
+            # attributes will NOT be seen on by AOTI-compiled forward
+            # function, e.g. calling model.setup_cache will NOT touch
+            # AOTI compiled and maintained model buffers such as kv_cache.
+            from torch._inductor.package import load_package
+            model.forward = load_package(
+                str(builder_args.aoti_package_path.absolute()), builder_args.device
+            )
+        except:
+            raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")
+
     elif builder_args.pte_path:
         if not is_cpu_device(builder_args.device):
             print(
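
The Python-loading half of this commit is the new `elif builder_args.aoti_package_path` branch above; stripped of the builder plumbing, it comes down to one call into `torch._inductor.package`. A minimal sketch of that pattern, assuming a PyTorch build with AOTInductor packaging support and an artifact directory produced by `--output-aoti-package-path` (path and device are illustrative):

```python
# Minimal sketch of the loading pattern used in _initialize_model above.
from torch._inductor.package import load_package

artifact_path = "exportedModels/llama3_1_artifacts"  # illustrative path
device = "cpu"  # or "cuda", matching the device used at export time

# load_package returns a callable that runs the AOT-compiled graph.
# builder.py assigns it over model.forward, so later mutations of the
# Python module (e.g. setup_cache) do not reach the compiled buffers.
compiled_forward = load_package(artifact_path, device)
```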

torchchat/cli/cli.py

Lines changed: 12 additions & 0 deletions
@@ -191,6 +191,12 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
+    output_path_parser.add_argument(
+        "--output-aoti-package-path",
+        type=str,
+        default=None,
+        help="Output directory for AOTInductor compiled artifacts",
+    )
 
 
 def _add_export_args(parser) -> None:

@@ -220,6 +226,12 @@ def _add_exported_input_path_args(parser) -> None:
         default=None,
         help="Use the specified AOT Inductor .dso model file",
     )
+    exclusive_parser.add_argument(
+        "--aoti-package-path",
+        type=Path,
+        default=None,
+        help="Use the specified directory containing AOT Inductor compiled files",
+    )
     exclusive_parser.add_argument(
         "--pte-path",
         type=Path,
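
For reference, argparse stores these flags under underscored names (`args.output_aoti_package_path`, `args.aoti_package_path`), which is exactly what `builder.py` reads via `getattr`. A standalone sketch of just the two new arguments, outside torchchat's parser groups:

```python
import argparse
from pathlib import Path

# Standalone sketch of the two arguments added above; the real ones hang
# off torchchat's output_path_parser / exclusive_parser groups.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--output-aoti-package-path",
    type=str,
    default=None,
    help="Output directory for AOTInductor compiled artifacts",
)
parser.add_argument(
    "--aoti-package-path",
    type=Path,
    default=None,
    help="Use the specified directory containing AOT Inductor compiled files",
)

args = parser.parse_args(["--aoti-package-path", "exportedModels/llama3_1_artifacts"])
# type=Path yields a pathlib.Path, which is why builder.py can call
# .absolute() on it before handing it to load_package.
assert isinstance(args.aoti_package_path, Path)
```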

torchchat/export.py

Lines changed: 31 additions & 8 deletions
@@ -35,8 +35,10 @@
 def export_for_server(
     model: nn.Module,
     device: Optional[str] = "cpu",
-    output_path: str = "model.dso",
+    output_path: str = "model.pt2",
     dynamic_shapes: bool = False,
+    package: bool = True,
+    model_key: str = "",
 ) -> str:
     """
     Export the model using AOT Compile to get a .dso for server use cases.

@@ -65,14 +67,17 @@ def export_for_server(
         dynamic_shapes = None
 
     with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
-        so = torch._export.aot_compile(
+        path = torch._export.aot_compile(
             model,
             args=input,
-            options={"aot_inductor.output_path": output_path},
+            options={
+                "aot_inductor.output_path": output_path,
+                "aot_inductor.package": package,
+            },
             dynamic_shapes=dynamic_shapes,
         )
-    print(f"The generated DSO model can be found at: {so}")
-    return so
+    print(f"The generated DSO model can be found at: {path}")
+    return path
 
 
 """

@@ -335,14 +340,16 @@ def main(args):
 
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path)
+    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
 
     builder_args.dso_path = None
     builder_args.pte_path = None
+    builder_args.aoti_package_path = None
     builder_args.setup_caches = True
 
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
+    output_aoti_package_path = args.output_aoti_package_path
 
     if output_pte_path and builder_args.device != "cpu":
         print(

@@ -380,6 +387,7 @@ def main(args):
         )
         model_to_pte = model
         model_to_dso = model
+        model_to_aoti_package = model
     else:
         if output_pte_path:
             _set_gguf_kwargs(builder_args, is_et=True, context="export")

@@ -389,13 +397,14 @@ def main(args):
             )
             _unset_gguf_kwargs(builder_args)
 
-        if output_dso_path:
+        if output_dso_path or output_aoti_package_path:
             _set_gguf_kwargs(builder_args, is_et=False, context="export")
-            model_to_dso = _initialize_model(
+            model_to_aoti_package = _initialize_model(
                 builder_args,
                 quantize,
                 support_tensor_subclass=False,
             )
+            model_to_dso = model_to_aoti_package
             _unset_gguf_kwargs(builder_args)
 
     with torch.no_grad():

@@ -409,6 +418,7 @@ def main(args):
                 "Export with executorch requested but ExecuTorch could not be loaded"
             )
             print(executorch_exception)
+
         if output_dso_path:
             output_dso_path = str(os.path.abspath(output_dso_path))
             print(f"Exporting model using AOT Inductor to {output_dso_path}")

@@ -417,4 +427,17 @@ def main(args):
                 builder_args.device,
                 output_dso_path,
                 builder_args.dynamic_shapes,
+                package=False,
+            )
+
+        if output_aoti_package_path:
+            output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
+            print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
+            export_for_server(
+                model_to_aoti_package,
+                builder_args.device,
+                output_aoti_package_path,
+                builder_args.dynamic_shapes,
+                package=True,
+                model_key=builder_args.params_table,
             )
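
With these changes, `main` routes `--output-aoti-package-path` through `export_for_server` with `package=True`, which in turn passes `"aot_inductor.package"` to `torch._export.aot_compile`. A toy sketch of that underlying call, under the assumption that the installed PyTorch build supports AOTInductor packaging (the model and output path are placeholders, not torchchat's):

```python
import torch
import torch.nn as nn

# Toy stand-in model; a real export goes through `torchchat.py export`.
class Tiny(nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

model = Tiny().eval()
example_inputs = (torch.randn(4, 8),)

# Mirrors the options export_for_server now passes when package=True.
path = torch._export.aot_compile(
    model,
    args=example_inputs,
    options={
        "aot_inductor.output_path": "exportedModels/tiny_artifacts",
        "aot_inductor.package": True,
    },
)
print(f"The generated artifacts can be found at: {path}")
```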

torchchat/generate.py

Lines changed: 16 additions & 3 deletions
@@ -133,8 +133,8 @@ def validate_build(
             reason = "model compilation for prefill"
         if self.compile:
             reason = "model compilation"
-        if builder_args.dso_path:
-            model_type = "DSO"
+        if builder_args.aoti_package_path:
+            model_type = "PT2"
         if builder_args.pte_path:
             model_type = "PTE"
         if model_type and reason:

@@ -146,7 +146,10 @@
     def from_args(cls, args):
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
-        sequential_prefill = args.sequential_prefill or bool(dso_path) or bool(pte_path)
+        aoti_package_path = getattr(args, "aoti_package_path", None)
+        sequential_prefill = (
+            args.sequential_prefill or bool(aoti_package_path) or bool(pte_path)
+        )
 
         return cls(
             prompt=getattr(args, "prompt", ""),

@@ -948,3 +951,13 @@ def main(args):
         torch.cuda.reset_peak_memory_stats()
     for _ in gen.chat(generator_args):
         pass
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="torchchat generate CLI")
+    verb = "generate"
+    add_arguments_for_verb(parser, verb)
+    args = parser.parse_args()
+    check_args(args, verb)
+    args = arg_init(args)
+    main(args)
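
One behavioral consequence of the `from_args` change: an AOTI package now forces sequential prefill, just as a `.pte` file does. A standalone illustration of the new condition, with a stand-in namespace in place of torchchat's parsed args:

```python
from types import SimpleNamespace

# Stand-in for the parsed CLI namespace; field names match the diff.
args = SimpleNamespace(sequential_prefill=False)
pte_path = None
aoti_package_path = "exportedModels/llama3_1_artifacts"  # illustrative

# The condition introduced above: any exported artifact
# (AOTI package or PTE) implies sequential prefill.
sequential_prefill = (
    args.sequential_prefill or bool(aoti_package_path) or bool(pte_path)
)
print(sequential_prefill)  # True, because an AOTI package path is set
```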

torchchat/usages/eval.py

Lines changed: 3 additions & 1 deletion
@@ -260,7 +260,7 @@ def main(args) -> None:
 
     if compile:
         assert not (
-            builder_args.dso_path or builder_args.pte_path
+            builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path
         ), "cannot compile exported model"
         model_forward = torch.compile(
             model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True

@@ -288,6 +288,8 @@ def main(args) -> None:
         )
     if builder_args.dso_path:
         print(f"For model {builder_args.dso_path}")
+    if builder_args.aoti_package_path:
+        print(f"For model {builder_args.aoti_package_path}")
     elif builder_args.pte_path:
         print(f"For model {builder_args.pte_path}")
     elif builder_args.checkpoint_path:
