Skip to content

Commit 0cf4e99

Browse files
committed
[aoti] Add cpp packaging for aoti + loading in python
1 parent ee681bf commit 0cf4e99

File tree

9 files changed

+119
-38
lines changed

9 files changed

+119
-38
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ __pycache__/
66
# C extensions
77
*.so
88

9+
.vscode
910
.model-artifacts/
1011
.venv
1112
.torchchat
@@ -18,3 +19,7 @@ runner-aoti/cmake-out/*
1819

1920
# pte files
2021
*.pte
22+
23+
checkpoints/
24+
exportedModels/
25+
cmake-out/

README.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,13 @@ model. The first command performs the actual export, the second
150150
command loads the exported model into the Python interface to enable
151151
users to test the exported model.
152152

153-
```
153+
```bash
154154
# Compile
155-
python3 torchchat.py export llama3 --output-dso-path exportedModels/llama3.so
155+
python3 torchchat.py export llama3 --output-aoti-package-path exportedModels/llama3_artifacts --device cpu
156156

157157
# Execute the exported model using Python
158158

159-
python3 torchchat.py generate llama3 --dso-path exportedModels/llama3.so --prompt "Hello my name is"
159+
python3 torchchat.py generate llama3 --aoti-package-path exportedModels/llama3_artifacts --prompt "Hello my name is" --device cpu
160160
```
161161

162162
NOTE: If your machine has cuda add this flag for performance
@@ -172,9 +172,14 @@ To build the runner binary on your Mac or Linux:
172172
scripts/build_native.sh aoti
173173
```
174174

175+
To compile the AOTI generated artifacts into a `.so`:
176+
```bash
177+
make -C exportedModels/llama3_artifacts
178+
```
179+
175180
Execute
176181
```bash
177-
cmake-out/aoti_run exportedModels/llama3.so -z `python3 torchchat.py where llama3`/tokenizer.model -l 3 -i "Once upon a time"
182+
cmake-out/aoti_run exportedModels/llama3_artifacts/llama3_artifacts.so -z `python3 torchchat.py where llama3`/tokenizer.model -l 3 -i "Once upon a time" -d cpu
178183
```
179184

180185
## Mobile Execution

build/builder.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class BuilderArgs:
3737
gguf_path: Optional[Union[Path, str]] = None
3838
gguf_kwargs: Optional[Dict[str, Any]] = None
3939
dso_path: Optional[Union[Path, str]] = None
40+
aoti_package_path: Optional[Union[Path, str]] = None
4041
pte_path: Optional[Union[Path, str]] = None
4142
device: Optional[str] = None
4243
precision: torch.dtype = torch.float32
@@ -54,28 +55,29 @@ def __post_init__(self):
5455
or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
5556
or (self.gguf_path and self.gguf_path.is_file())
5657
or (self.dso_path and Path(self.dso_path).is_file())
58+
or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
5759
or (self.pte_path and Path(self.pte_path).is_file())
5860
):
5961
raise RuntimeError(
6062
"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
6163
)
6264

63-
if self.dso_path and self.pte_path:
64-
raise RuntimeError("specify either DSO path or PTE path, but not both")
65+
if self.pte_path and self.aoti_package_path:
66+
raise RuntimeError("specify either AOTI Package path or PTE path, but not more than one")
6567

66-
if self.checkpoint_path and (self.dso_path or self.pte_path):
68+
if self.checkpoint_path and (self.pte_path or self.aoti_package_path):
6769
print(
68-
"Warning: checkpoint path ignored because an exported DSO or PTE path specified"
70+
"Warning: checkpoint path ignored because an exported AOTI or PTE path specified"
6971
)
70-
if self.checkpoint_dir and (self.dso_path or self.pte_path):
72+
if self.checkpoint_dir and (self.pte_path or self.aoti_package_path):
7173
print(
72-
"Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
74+
"Warning: checkpoint dir ignored because an exported AOTI or PTE path specified"
7375
)
74-
if self.gguf_path and (self.dso_path or self.pte_path):
76+
if self.gguf_path and (self.pte_path or self.aoti_package_path):
7577
print(
76-
"Warning: GGUF path ignored because an exported DSO or PTE path specified"
78+
"Warning: GGUF path ignored because an exported AOTI or PTE path specified"
7779
)
78-
if not (self.dso_path) and not (self.pte_path):
80+
if not (self.dso_path) and not (self.aoti_package_path):
7981
self.prefill_possible = True
8082

8183
@classmethod
@@ -111,6 +113,7 @@ def from_args(cls, args): # -> BuilderArgs:
111113
checkpoint_path,
112114
checkpoint_dir,
113115
args.dso_path,
116+
args.aoti_package_path,
114117
args.pte_path,
115118
args.gguf_path,
116119
]:
@@ -145,10 +148,11 @@ def from_args(cls, args): # -> BuilderArgs:
145148
gguf_path=args.gguf_path,
146149
gguf_kwargs=None,
147150
dso_path=args.dso_path,
151+
aoti_package_path=args.aoti_package_path,
148152
pte_path=args.pte_path,
149153
device=args.device,
150154
precision=dtype,
151-
setup_caches=(args.output_dso_path or args.output_pte_path),
155+
setup_caches=(args.output_dso_path or args.output_pte_path or args.output_aoti_package_path),
152156
use_distributed=args.distributed,
153157
is_chat_model=is_chat_model,
154158
)
@@ -161,6 +165,7 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
161165
speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
162166
speculative_builder_args.gguf_path = None
163167
speculative_builder_args.dso_path = None
168+
speculative_builder_args.aoti_package_path = None
164169
speculative_builder_args.pte_path = None
165170
return speculative_builder_args
166171

@@ -432,11 +437,12 @@ def _initialize_model(
432437
):
433438
print("Loading model...")
434439

435-
if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
440+
if builder_args.gguf_path and (builder_args.dso_path or builder_args.aoti_package_path or builder_args.pte_path):
436441
print("Setting gguf_kwargs for generate.")
437442
is_dso = builder_args.dso_path is not None
443+
is_aoti_package = builder_args.aoti_package_path is not None
438444
is_pte = builder_args.pte_path is not None
439-
assert not (is_dso and is_pte)
445+
assert not (is_dso and is_aoti_package and is_pte)
440446
assert builder_args.gguf_kwargs is None
441447
# TODO: make GGUF load independent of backend
442448
# currently not working because AVX int_mm broken
@@ -470,6 +476,36 @@ def _initialize_model(
470476
)
471477
except:
472478
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
479+
480+
elif builder_args.aoti_package_path:
481+
if not is_cuda_or_cpu_device(builder_args.device):
482+
print(
483+
f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
484+
)
485+
builder_args.device = "cpu"
486+
487+
# assert (
488+
# quantize is None or quantize == "{ }"
489+
# ), "quantize not valid for exported PT2 model. Specify quantization during export."
490+
491+
with measure_time("Time to load model: {time:.02f} seconds"):
492+
model = _load_model(builder_args, only_config=True)
493+
device_sync(device=builder_args.device)
494+
495+
try:
496+
# Replace model forward with the AOT-compiled forward
497+
# This is a hacky way to quickly demo AOTI's capability.
498+
# model is still a Python object, and any mutation to its
499+
# attributes will NOT be seen on by AOTI-compiled forward
500+
# function, e.g. calling model.setup_cache will NOT touch
501+
# AOTI compiled and maintained model buffers such as kv_cache.
502+
from torch._inductor.package import load_package
503+
model.forward = load_package(
504+
str(builder_args.aoti_package_path.absolute()), builder_args.device
505+
)
506+
except:
507+
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")
508+
473509
elif builder_args.pte_path:
474510
if not is_cpu_device(builder_args.device):
475511
print(

build/utils.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,42 +69,46 @@ def unpack_packed_weights(
6969

7070
active_builder_args_dso = None
7171
active_builder_args_pte = None
72+
active_builder_args_aoti_package = None
7273

7374

74-
def set_backend(dso, pte):
75+
def set_backend(dso, pte, aoti_package):
7576
global active_builder_args_dso
7677
global active_builder_args_pte
7778
active_builder_args_dso = dso
79+
active_builder_args_aoti_package = aoti_package
7880
active_builder_args_pte = pte
7981

8082

8183
def use_aoti_backend() -> bool:
8284
global active_builder_args_dso
85+
global active_builder_args_aoti_package
8386
global active_builder_args_pte
8487

8588
# eager == aoti, which is when backend has not been explicitly set
86-
if (not active_builder_args_dso) and not (active_builder_args_pte):
89+
if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
8790
return True
8891

89-
if active_builder_args_pte and active_builder_args_dso:
92+
if active_builder_args_pte and active_builder_args_aoti_package:
9093
raise RuntimeError(
91-
"code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
94+
"code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
9295
)
9396

94-
return bool(active_builder_args_dso)
97+
return bool(active_builder_args_dso) or bool(active_builder_args_aoti_package)
9598

9699

97100
def use_et_backend() -> bool:
98101
global active_builder_args_dso
102+
global active_builder_args_aoti_package
99103
global active_builder_args_pte
100104

101105
# eager == aoti, which is when backend has not been explicitly set
102-
if not (active_builder_args_pte or active_builder_args_dso):
103-
return False
106+
if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
107+
return True
104108

105-
if active_builder_args_pte and active_builder_args_dso:
109+
if active_builder_args_pte and active_builder_args_aoti_package:
106110
raise RuntimeError(
107-
"code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
111+
"code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
108112
)
109113

110114
return bool(active_builder_args_pte)

cli.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,12 @@ def _add_export_output_path_args(parser) -> None:
163163
default=None,
164164
help="Output to the specified AOT Inductor .dso model file",
165165
)
166+
output_path_parser.add_argument(
167+
"--output-aoti-package-path",
168+
type=str,
169+
default=None,
170+
help="Output directory for AOTInductor compiled artifacts",
171+
)
166172

167173

168174
# Add CLI Args representing user provided exported model files
@@ -174,6 +180,12 @@ def _add_exported_input_path_args(parser) -> None:
174180
default=None,
175181
help="Use the specified AOT Inductor .dso model file",
176182
)
183+
exported_model_path_parser.add_argument(
184+
"--aoti-package-path",
185+
type=Path,
186+
default=None,
187+
help="Use the specified directory containing AOT Inductor compiled files",
188+
)
177189
exported_model_path_parser.add_argument(
178190
"--pte-path",
179191
type=Path,

eval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def main(args) -> None:
233233

234234
if compile:
235235
assert not (
236-
builder_args.dso_path or builder_args.pte_path
236+
builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path
237237
), "cannot compile exported model"
238238
global model_forward
239239
model_forward = torch.compile(
@@ -260,6 +260,8 @@ def main(args) -> None:
260260
)
261261
if builder_args.dso_path:
262262
print(f"For model {builder_args.dso_path}")
263+
if builder_args.aoti_package_path:
264+
print(f"For model {builder_args.aoti_package_path}")
263265
elif builder_args.pte_path:
264266
print(f"For model {builder_args.pte_path}")
265267
elif builder_args.checkpoint_path:

export.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@ def main(args):
3939

4040
print(f"Using device={builder_args.device}")
4141
set_precision(builder_args.precision)
42-
set_backend(dso=args.output_dso_path, pte=args.output_pte_path)
42+
set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
4343

4444
builder_args.dso_path = None
4545
builder_args.pte_path = None
46+
builder_args.aoti_package_path = None
4647
builder_args.setup_caches = True
4748

4849
output_pte_path = args.output_pte_path
4950
output_dso_path = args.output_dso_path
51+
output_aoti_package_path = args.output_aoti_package_path
5052

5153
if output_pte_path and builder_args.device != "cpu":
5254
print(
@@ -74,6 +76,7 @@ def main(args):
7476
)
7577
model_to_pte = model
7678
model_to_dso = model
79+
model_to_aoti_package = model
7780
else:
7881
if output_pte_path:
7982
_set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -83,12 +86,13 @@ def main(args):
8386
)
8487
_unset_gguf_kwargs(builder_args)
8588

86-
if output_dso_path:
89+
if output_dso_path or output_aoti_package_path:
8790
_set_gguf_kwargs(builder_args, is_et=False, context="export")
88-
model_to_dso = _initialize_model(
91+
model_to_aoti_package = _initialize_model(
8992
builder_args,
9093
quantize,
9194
)
95+
model_to_dso = model_to_aoti_package
9296
_unset_gguf_kwargs(builder_args)
9397

9498
with torch.no_grad():
@@ -104,10 +108,16 @@ def main(args):
104108
"Export with executorch requested but ExecuTorch could not be loaded"
105109
)
106110
print(executorch_exception)
111+
107112
if output_dso_path:
108113
output_dso_path = str(os.path.abspath(output_dso_path))
109114
print(f"Exporting model using AOT Inductor to {output_dso_path}")
110-
export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)
115+
export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args, False)
116+
117+
if output_aoti_package_path:
118+
output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
119+
print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
120+
export_model_aoti(model_to_aoti_package, builder_args.device, output_aoti_package_path, args, True)
111121

112122

113123
if __name__ == "__main__":

export_aoti.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
import torch
99
import torch.nn as nn
1010
from torch.export import Dim
11+
import torch._inductor.config
1112

1213
default_device = "cpu"
1314

1415

15-
def export_model(model: nn.Module, device, output_path, args=None):
16+
def export_model(model: nn.Module, device, output_path, args=None, package=True):
1617
max_seq_length = 350
1718

1819
input = (
@@ -25,11 +26,17 @@ def export_model(model: nn.Module, device, output_path, args=None):
2526
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
2627

2728
model.to(device)
28-
so = torch._export.aot_compile(
29+
30+
options = {"aot_inductor.output_path": output_path}
31+
# TODO: workaround until we update torch version
32+
if "aot_inductor.package" in torch._inductor.config._config:
33+
options["aot_inductor.package"] = package
34+
35+
path = torch._export.aot_compile(
2936
model,
3037
args=input,
31-
options={"aot_inductor.output_path": output_path},
38+
options=options,
3239
dynamic_shapes=dynamic_shapes,
3340
)
34-
print(f"The generated DSO model can be found at: {so}")
35-
return so
41+
print(f"The AOTInductor compiled files can be found at: {path}")
42+
return path

generate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ def validate_build(
9191
reason = "model compilation for prefill"
9292
if self.compile:
9393
reason = "model compilation"
94-
if builder_args.dso_path:
95-
model_type = "DSO"
94+
if builder_args.aoti_package_path:
95+
model_type = "PT2"
9696
if builder_args.pte_path:
9797
model_type = "PTE"
9898
if model_type and reason:
@@ -103,7 +103,7 @@ def validate_build(
103103
@classmethod
104104
def from_args(cls, args):
105105
sequential_prefill = (
106-
args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)
106+
args.sequential_prefill or bool(args.aoti_package_path) or bool(args.pte_path)
107107
)
108108

109109
return cls(

0 commit comments

Comments
 (0)