Skip to content

Commit db26ea0

Browse files
committed
[aoti] Add cpp packaging for aoti + loading in python
1 parent 9a94b56 commit db26ea0

File tree

9 files changed

+109
-39
lines changed

9 files changed

+109
-39
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ __pycache__/
66
# C extensions
77
*.so
88

9+
.vscode
910
.model-artifacts/
1011
.venv
1112
.torchchat
@@ -18,3 +19,7 @@ runner-aoti/cmake-out/*
1819

1920
# pte files
2021
*.pte
22+
23+
checkpoints/
24+
exportedModels/
25+
cmake-out/

README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ source .venv/bin/activate
5454
```
5555
[skip default]: end
5656

57-
[shell default]: ./install_requirements.sh
57+
[shell default]: ./install_requirements.sh
5858

5959
Installations can be tested by
6060

@@ -150,13 +150,13 @@ model. The first command performs the actual export, the second
150150
command loads the exported model into the Python interface to enable
151151
users to test the exported model.
152152

153-
```
153+
```bash
154154
# Compile
155-
python3 torchchat.py export llama3 --output-dso-path exportedModels/llama3.so
155+
python3 torchchat.py export llama3 --output-aoti-package-path exportedModels/llama3_artifacts
156156

157157
# Execute the exported model using Python
158158

159-
python3 torchchat.py generate llama3 --dso-path exportedModels/llama3.so --prompt "Hello my name is"
159+
python3 torchchat.py generate llama3 --aoti-package-path exportedModels/llama3_artifacts --prompt "Hello my name is"
160160
```
161161

162162
NOTE: If your machine has cuda add this flag for performance
@@ -174,7 +174,8 @@ scripts/build_native.sh aoti
174174

175175
Execute
176176
```bash
177-
cmake-out/aoti_run exportedModels/llama3.so -z `python3 torchchat.py where llama3`/tokenizer.model -l 3 -i "Once upon a time"
177+
make -C exportedModels/llama3_artifacts
178+
cmake-out/aoti_run exportedModels/llama3_artifacts/llama3_artifacts.so -z `python3 torchchat.py where llama3`/tokenizer.model -l 3 -i "Once upon a time" -d cpu
178179
```
179180

180181
## Mobile Execution

build/builder.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class BuilderArgs:
3333
gguf_path: Optional[Union[Path, str]] = None
3434
gguf_kwargs: Optional[Dict[str, Any]] = None
3535
dso_path: Optional[Union[Path, str]] = None
36+
aoti_package_path: Optional[Union[Path, str]] = None
3637
pte_path: Optional[Union[Path, str]] = None
3738
device: Optional[str] = None
3839
precision: torch.dtype = torch.float32
@@ -50,28 +51,29 @@ def __post_init__(self):
5051
or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
5152
or (self.gguf_path and self.gguf_path.is_file())
5253
or (self.dso_path and Path(self.dso_path).is_file())
54+
or (self.aoti_package_path and Path(self.aoti_package_path).is_file())
5355
or (self.pte_path and Path(self.pte_path).is_file())
5456
):
5557
raise RuntimeError(
5658
"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
5759
)
5860

59-
if self.dso_path and self.pte_path:
60-
raise RuntimeError("specify either DSO path or PTE path, but not both")
61+
if self.pte_path and self.aoti_package_path:
62+
raise RuntimeError("specify either AOTI Package path or PTE path, but not more than one")
6163

62-
if self.checkpoint_path and (self.dso_path or self.pte_path):
64+
if self.checkpoint_path and (self.pte_path or self.aoti_package_path):
6365
print(
64-
"Warning: checkpoint path ignored because an exported DSO or PTE path specified"
66+
"Warning: checkpoint path ignored because an exported AOTI or PTE path specified"
6567
)
66-
if self.checkpoint_dir and (self.dso_path or self.pte_path):
68+
if self.checkpoint_dir and (self.pte_path or self.aoti_package_path):
6769
print(
68-
"Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
70+
"Warning: checkpoint dir ignored because an exported AOTI or PTE path specified"
6971
)
70-
if self.gguf_path and (self.dso_path or self.pte_path):
72+
if self.gguf_path and (self.pte_path or self.aoti_package_path):
7173
print(
72-
"Warning: GGUF path ignored because an exported DSO or PTE path specified"
74+
"Warning: GGUF path ignored because an exported AOTI or PTE path specified"
7375
)
74-
if not (self.dso_path) and not (self.pte_path):
76+
if not (self.dso_path) and not (self.aoti_package_path):
7577
self.prefill_possible = True
7678

7779
@classmethod
@@ -105,6 +107,7 @@ def from_args(cls, args): # -> BuilderArgs:
105107
checkpoint_path,
106108
checkpoint_dir,
107109
args.dso_path,
110+
args.aoti_package_path,
108111
args.pte_path,
109112
args.gguf_path,
110113
]:
@@ -138,10 +141,11 @@ def from_args(cls, args): # -> BuilderArgs:
138141
gguf_path=args.gguf_path,
139142
gguf_kwargs=None,
140143
dso_path=args.dso_path,
144+
aoti_package_path=args.aoti_package_path,
141145
pte_path=args.pte_path,
142146
device=args.device,
143147
precision=dtype,
144-
setup_caches=(args.output_dso_path or args.output_pte_path),
148+
setup_caches=(args.output_dso_path or args.output_pte_path or args.output_aoti_package_path),
145149
use_distributed=args.distributed,
146150
is_chat_model=is_chat_model,
147151
)
@@ -154,6 +158,7 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
154158
speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
155159
speculative_builder_args.gguf_path = None
156160
speculative_builder_args.dso_path = None
161+
speculative_builder_args.aoti_package_path = None
157162
speculative_builder_args.pte_path = None
158163
return speculative_builder_args
159164

@@ -377,11 +382,12 @@ def _initialize_model(
377382
):
378383
print("Loading model...")
379384

380-
if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
385+
if builder_args.gguf_path and (builder_args.dso_path or builder_args.aoti_package_path or builder_args.pte_path):
381386
print("Setting gguf_kwargs for generate.")
382387
is_dso = builder_args.dso_path is not None
388+
is_aoti_package = builder_args.aoti_package_path is not None
383389
is_pte = builder_args.pte_path is not None
384-
assert not (is_dso and is_pte)
390+
assert not (is_dso and is_aoti_package and is_pte)
385391
assert builder_args.gguf_kwargs is None
386392
# TODO: make GGUF load independent of backend
387393
# currently not working because AVX int_mm broken
@@ -415,6 +421,36 @@ def _initialize_model(
415421
)
416422
except:
417423
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
424+
425+
elif builder_args.aoti_package_path:
426+
if not is_cuda_or_cpu_device(builder_args.device):
427+
print(
428+
f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
429+
)
430+
builder_args.device = "cpu"
431+
432+
# assert (
433+
# quantize is None or quantize == "{ }"
434+
# ), "quantize not valid for exported PT2 model. Specify quantization during export."
435+
436+
with measure_time("Time to load model: {time:.02f} seconds"):
437+
model = _load_model(builder_args, only_config=True)
438+
device_sync(device=builder_args.device)
439+
440+
try:
441+
# Replace model forward with the AOT-compiled forward
442+
# This is a hacky way to quickly demo AOTI's capability.
443+
# model is still a Python object, and any mutation to its
444+
# attributes will NOT be seen on by AOTI-compiled forward
445+
# function, e.g. calling model.setup_cache will NOT touch
446+
# AOTI compiled and maintained model buffers such as kv_cache.
447+
from torch._inductor.package import load_package
448+
model.forward = load_package(
449+
str(builder_args.aoti_package_path.absolute()), builder_args.device
450+
)
451+
except:
452+
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.aoti_package_path}")
453+
418454
elif builder_args.pte_path:
419455
if not is_cpu_device(builder_args.device):
420456
print(

build/utils.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,42 +69,46 @@ def unpack_packed_weights(
6969

7070
active_builder_args_dso = None
7171
active_builder_args_pte = None
72+
active_builder_args_aoti_package = None
7273

7374

74-
def set_backend(dso, pte):
75+
def set_backend(dso, pte, aoti_package):
7576
global active_builder_args_dso
7677
global active_builder_args_pte
7778
active_builder_args_dso = dso
79+
active_builder_args_aoti_package = aoti_package
7880
active_builder_args_pte = pte
7981

8082

8183
def use_aoti_backend() -> bool:
8284
global active_builder_args_dso
85+
global active_builder_args_aoti_package
8386
global active_builder_args_pte
8487

8588
# eager == aoti, which is when backend has not been explicitly set
86-
if (not active_builder_args_dso) and not (active_builder_args_pte):
89+
if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
8790
return True
8891

89-
if active_builder_args_pte and active_builder_args_dso:
92+
if active_builder_args_pte and active_builder_args_aoti_package:
9093
raise RuntimeError(
91-
"code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
94+
"code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
9295
)
9396

94-
return bool(active_builder_args_dso)
97+
return bool(active_builder_args_dso) or bool(active_builder_args_aoti_package)
9598

9699

97100
def use_et_backend() -> bool:
98101
global active_builder_args_dso
102+
global active_builder_args_aoti_package
99103
global active_builder_args_pte
100104

101105
# eager == aoti, which is when backend has not been explicitly set
102-
if not (active_builder_args_pte or active_builder_args_dso):
103-
return False
106+
if (not active_builder_args_pte) and (not active_builder_args_aoti_package):
107+
return True
104108

105-
if active_builder_args_pte and active_builder_args_dso:
109+
if active_builder_args_pte and active_builder_args_aoti_package:
106110
raise RuntimeError(
107-
"code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
111+
"code generation needs to choose different implementations for AOTI and PTE path. Please only use one export option, and call export twice if necessary!"
108112
)
109113

110114
return bool(active_builder_args_pte)

cli.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,12 @@ def _add_export_output_path_args(parser) -> None:
210210
default=None,
211211
help="Output to the specified AOT Inductor .dso model file",
212212
)
213+
output_path_parser.add_argument(
214+
"--output-aoti-package-path",
215+
type=str,
216+
default=None,
217+
help="Output directory for AOTInductor compiled artifacts",
218+
)
213219

214220

215221
# Add CLI Args representing user provided exported model files
@@ -221,6 +227,12 @@ def _add_exported_model_input_args(parser) -> None:
221227
default=None,
222228
help="Use the specified AOT Inductor .dso model file",
223229
)
230+
exported_model_path_parser.add_argument(
231+
"--aoti-package-path",
232+
type=Path,
233+
default=None,
234+
help="Use the specified directory containing AOT Inductor compiled files",
235+
)
224236
exported_model_path_parser.add_argument(
225237
"--pte-path",
226238
type=Path,

eval.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def main(args) -> None:
233233

234234
if compile:
235235
assert not (
236-
builder_args.dso_path or builder_args.pte_path
236+
builder_args.dso_path or builder_args.pte_path or builder_args.aoti_package_path
237237
), "cannot compile exported model"
238238
global model_forward
239239
model_forward = torch.compile(
@@ -260,6 +260,8 @@ def main(args) -> None:
260260
)
261261
if builder_args.dso_path:
262262
print(f"For model {builder_args.dso_path}")
263+
if builder_args.aoti_package_path:
264+
print(f"For model {builder_args.aoti_package_path}")
263265
elif builder_args.pte_path:
264266
print(f"For model {builder_args.pte_path}")
265267
elif builder_args.checkpoint_path:

export.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@ def main(args):
3939

4040
print(f"Using device={builder_args.device}")
4141
set_precision(builder_args.precision)
42-
set_backend(dso=args.output_dso_path, pte=args.output_pte_path)
42+
set_backend(dso=args.output_dso_path, pte=args.output_pte_path, aoti_package=args.output_aoti_package_path)
4343

4444
builder_args.dso_path = None
4545
builder_args.pte_path = None
46+
builder_args.aoti_package_path = None
4647
builder_args.setup_caches = True
4748

4849
output_pte_path = args.output_pte_path
4950
output_dso_path = args.output_dso_path
51+
output_aoti_package_path = args.output_aoti_package_path
5052

5153
if output_pte_path and builder_args.device != "cpu":
5254
print(
@@ -74,6 +76,7 @@ def main(args):
7476
)
7577
model_to_pte = model
7678
model_to_dso = model
79+
model_to_aoti_package = model
7780
else:
7881
if output_pte_path:
7982
_set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -83,12 +86,13 @@ def main(args):
8386
)
8487
_unset_gguf_kwargs(builder_args)
8588

86-
if output_dso_path:
89+
if output_dso_path or output_aoti_package_path:
8790
_set_gguf_kwargs(builder_args, is_et=False, context="export")
88-
model_to_dso = _initialize_model(
91+
model_to_aoti_package = _initialize_model(
8992
builder_args,
9093
quantize,
9194
)
95+
model_to_dso = model_to_aoti_package
9296
_unset_gguf_kwargs(builder_args)
9397

9498
with torch.no_grad():
@@ -104,10 +108,16 @@ def main(args):
104108
"Export with executorch requested but ExecuTorch could not be loaded"
105109
)
106110
print(executorch_exception)
111+
107112
if output_dso_path:
108113
output_dso_path = str(os.path.abspath(output_dso_path))
109114
print(f"Exporting model using AOT Inductor to {output_dso_path}")
110-
export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)
115+
export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args, False)
116+
117+
if output_aoti_package_path:
118+
output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))
119+
print(f"Exporting model using AOT Inductor to {output_aoti_package_path}")
120+
export_model_aoti(model_to_aoti_package, builder_args.device, output_aoti_package_path, args, True)
111121

112122

113123
if __name__ == "__main__":

export_aoti.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
default_device = "cpu"
1313

1414

15-
def export_model(model: nn.Module, device, output_path, args=None):
15+
def export_model(model: nn.Module, device, output_path, args=None, package=True):
1616
max_seq_length = 350
1717

1818
input = (
@@ -25,11 +25,11 @@ def export_model(model: nn.Module, device, output_path, args=None):
2525
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
2626

2727
model.to(device)
28-
so = torch._export.aot_compile(
28+
path = torch._export.aot_compile(
2929
model,
3030
args=input,
31-
options={"aot_inductor.output_path": output_path},
31+
options={"aot_inductor.output_path": output_path, "aot_inductor.package": package},
3232
dynamic_shapes=dynamic_shapes,
3333
)
34-
print(f"The generated DSO model can be found at: {so}")
35-
return so
34+
print(f"The AOTInductor compiled files can be found at: {path}")
35+
return path

generate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,8 @@ def validate_build(
9191
reason = "model compilation for prefill"
9292
if self.compile:
9393
reason = "model compilation"
94-
if builder_args.dso_path:
95-
model_type = "DSO"
94+
if builder_args.aoti_package_path:
95+
model_type = "PT2"
9696
if builder_args.pte_path:
9797
model_type = "PTE"
9898
if model_type and reason:
@@ -103,7 +103,7 @@ def validate_build(
103103
@classmethod
104104
def from_args(cls, args):
105105
sequential_prefill = (
106-
args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)
106+
args.sequential_prefill or bool(args.aoti_package_path) or bool(args.pte_path)
107107
)
108108

109109
return cls(

0 commit comments

Comments
 (0)