
Commit b23d414

[aoti] Add cpp packaging for aoti + loading in python
1 parent 9a94b56 commit b23d414

9 files changed: +128 -35 lines changed

.gitignore

Lines changed: 5 additions & 0 deletions

@@ -6,6 +6,7 @@ __pycache__/
 # C extensions
 *.so
 
+.vscode
 .model-artifacts/
 .venv
 .torchchat
@@ -18,3 +19,7 @@ runner-aoti/cmake-out/*
 
 # pte files
 *.pte
+
+checkpoints/
+exportedModels/
+cmake-out/

README.md

Lines changed: 5 additions & 4 deletions

@@ -54,7 +54,7 @@ source .venv/bin/activate
 ```
 [skip default]: end
 
-[shell default]: ./install_requirements.sh
+[shell default]: ./install_requirements.sh
 
 Installations can be tested by
 
@@ -152,11 +152,11 @@ users to test the exported model.
 
 ```
 # Compile
-python3 torchchat.py export llama3 --output-dso-path exportedModels/llama3.so
+python3 torchchat.py export llama3 --output-pt2-path exportedModels/llama3_artifacts
 
 # Execute the exported model using Python
 
-python3 torchchat.py generate llama3 --dso-path exportedModels/llama3.so --prompt "Hello my name is"
+python3 torchchat.py generate llama3 --pt2-path exportedModels/llama3_artifacts --prompt "Hello my name is"
 ```
 
 NOTE: If your machine has cuda add this flag for performance
@@ -174,7 +174,8 @@ scripts/build_native.sh aoti
 
 Execute
 ```bash
-cmake-out/aoti_run exportedModels/llama3.so -z `python3 torchchat.py where llama3`/tokenizer.model -l 3 -i "Once upon a time"
+make -C exportedModels/llama3_artifacts
+cmake-out/aoti_run exportedModels/llama3_artifacts.so -z `python3 torchchat.py where llama3`/tokenizer.model -l 3 -i "Once upon a time"
 ```
 
 ## Mobile Execution

build/builder.py

Lines changed: 48 additions & 12 deletions

@@ -33,6 +33,7 @@ class BuilderArgs:
     gguf_path: Optional[Union[Path, str]] = None
     gguf_kwargs: Optional[Dict[str, Any]] = None
     dso_path: Optional[Union[Path, str]] = None
+    pt2_path: Optional[Union[Path, str]] = None
     pte_path: Optional[Union[Path, str]] = None
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
@@ -50,28 +51,29 @@ def __post_init__(self):
             or (self.checkpoint_dir and self.checkpoint_dir.is_dir())
             or (self.gguf_path and self.gguf_path.is_file())
             or (self.dso_path and Path(self.dso_path).is_file())
+            or (self.pt2_path and Path(self.pt2_path).is_file())
             or (self.pte_path and Path(self.pte_path).is_file())
         ):
             raise RuntimeError(
                 "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
             )
 
-        if self.dso_path and self.pte_path:
-            raise RuntimeError("specify either DSO path or PTE path, but not both")
+        if sum(1 for path in (self.dso_path, self.pte_path, self.pt2_path) if path is not None) > 1:
+            raise RuntimeError("specify either DSO path, PT2 path, or PTE path, but not more than one")
 
-        if self.checkpoint_path and (self.dso_path or self.pte_path):
+        if self.checkpoint_path and (self.dso_path or self.pte_path or self.pt2_path):
             print(
-                "Warning: checkpoint path ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint path ignored because an exported DSO, PT2, or PTE path specified"
             )
-        if self.checkpoint_dir and (self.dso_path or self.pte_path):
+        if self.checkpoint_dir and (self.dso_path or self.pte_path or self.pt2_path):
             print(
-                "Warning: checkpoint dir ignored because an exported DSO or PTE path specified"
+                "Warning: checkpoint dir ignored because an exported DSO, PT2, or PTE path specified"
            )
-        if self.gguf_path and (self.dso_path or self.pte_path):
+        if self.gguf_path and (self.dso_path or self.pte_path or self.pt2_path):
             print(
-                "Warning: GGUF path ignored because an exported DSO or PTE path specified"
+                "Warning: GGUF path ignored because an exported DSO, PT2, or PTE path specified"
             )
-        if not (self.dso_path) and not (self.pte_path):
+        if not (self.dso_path) and not (self.pte_path) and not (self.pt2_path):
             self.prefill_possible = True
 
     @classmethod
@@ -105,6 +107,7 @@ def from_args(cls, args): # -> BuilderArgs:
             checkpoint_path,
             checkpoint_dir,
             args.dso_path,
+            args.pt2_path,
             args.pte_path,
             args.gguf_path,
         ]:
@@ -138,10 +141,11 @@ def from_args(cls, args): # -> BuilderArgs:
             gguf_path=args.gguf_path,
             gguf_kwargs=None,
             dso_path=args.dso_path,
+            pt2_path=args.pt2_path,
             pte_path=args.pte_path,
             device=args.device,
             precision=dtype,
-            setup_caches=(args.output_dso_path or args.output_pte_path),
+            setup_caches=(args.output_dso_path or args.output_pte_path or args.output_pt2_path),
             use_distributed=args.distributed,
             is_chat_model=is_chat_model,
         )
@@ -154,6 +158,7 @@ def from_speculative_args(cls, args): # -> BuilderArgs:
         speculative_builder_args.checkpoint_path = args.draft_checkpoint_path
         speculative_builder_args.gguf_path = None
         speculative_builder_args.dso_path = None
+        speculative_builder_args.pt2_path = None
         speculative_builder_args.pte_path = None
         return speculative_builder_args
 
@@ -377,11 +382,12 @@ def _initialize_model(
 ):
     print("Loading model...")
 
-    if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
+    if builder_args.gguf_path and (builder_args.dso_path or builder_args.pt2_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")
         is_dso = builder_args.dso_path is not None
+        is_pt2 = builder_args.pt2_path is not None
         is_pte = builder_args.pte_path is not None
-        assert not (is_dso and is_pte)
+        assert not (is_dso and is_pt2 and is_pte)
         assert builder_args.gguf_kwargs is None
         # TODO: make GGUF load independent of backend
         # currently not working because AVX int_mm broken
@@ -415,6 +421,36 @@ def _initialize_model(
             )
         except:
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
+
+    elif builder_args.pt2_path:
+        if not is_cuda_or_cpu_device(builder_args.device):
+            print(
+                f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
+            )
+            builder_args.device = "cpu"
+
+        # assert (
+        #     quantize is None or quantize == "{ }"
+        # ), "quantize not valid for exported PT2 model. Specify quantization during export."
+
+        with measure_time("Time to load model: {time:.02f} seconds"):
+            model = _load_model(builder_args, only_config=True)
+            device_sync(device=builder_args.device)
+
+        try:
+            # Replace model forward with the AOT-compiled forward
+            # This is a hacky way to quickly demo AOTI's capability.
+            # model is still a Python object, and any mutation to its
+            # attributes will NOT be seen on by AOTI-compiled forward
+            # function, e.g. calling model.setup_cache will NOT touch
+            # AOTI compiled and maintained model buffers such as kv_cache.
+            from torch._inductor.package import load_package
+            model.forward = load_package(
+                str(builder_args.pt2_path.absolute()), builder_args.device
+            )
+        except:
+            raise RuntimeError(f"Failed to load AOTI compiled {builder_args.pt2_path}")
+
     elif builder_args.pte_path:
         if not is_cpu_device(builder_args.device):
             print(
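For orientation, the new pt2_path branch above replaces the eager model.forward with the callable returned by torch._inductor.package.load_package. The sketch below isolates that loading pattern outside torchchat, assuming an artifact directory already produced by the new --output-pt2-path export flow; the directory name and input tensors are illustrative, not part of this commit.

```python
# Minimal sketch of loading a packaged AOTInductor artifact in Python,
# mirroring the new pt2_path branch in build/builder.py.
# The directory below is a hypothetical example produced by
# `torchchat.py export ... --output-pt2-path exportedModels/llama3_artifacts`.
from pathlib import Path

import torch
from torch._inductor.package import load_package

pt2_dir = Path("exportedModels/llama3_artifacts")  # assumed artifact directory
device = "cpu"

# load_package returns a callable that runs the AOT-compiled forward graph.
compiled_forward = load_package(str(pt2_dir.absolute()), device)

# Token ids and positions follow the example shapes traced in export_aoti.py,
# so the call signature matches the compiled forward.
idx = torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device)
input_pos = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device)

logits = compiled_forward(idx, input_pos)
print(logits.shape)
```
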

build/utils.py

Lines changed: 13 additions & 9 deletions

@@ -69,42 +69,46 @@ def unpack_packed_weights(
 
 active_builder_args_dso = None
 active_builder_args_pte = None
+active_builder_args_pt2 = None
 
 
-def set_backend(dso, pte):
+def set_backend(dso, pte, pt2):
     global active_builder_args_dso
     global active_builder_args_pte
     active_builder_args_dso = dso
+    active_builder_args_pt2 = pt2
     active_builder_args_pte = pte
 
 
 def use_aoti_backend() -> bool:
     global active_builder_args_dso
+    global active_builder_args_pt2
     global active_builder_args_pte
 
     # eager == aoti, which is when backend has not been explicitly set
-    if (not active_builder_args_dso) and not (active_builder_args_pte):
+    if (not active_builder_args_dso) and not (active_builder_args_pte) and not (active_builder_args_pt2):
         return True
 
-    if active_builder_args_pte and active_builder_args_dso:
+    if sum(1 for builder in (active_builder_args_pte, active_builder_args_dso, active_builder_args_pt2)) > 1:
         raise RuntimeError(
-            "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
+            "code generation needs to choose different implementations for DSO, PT2, and PTE path. Please only use one export option, and call export twice if necessary!"
         )
 
-    return bool(active_builder_args_dso)
+    return bool(active_builder_args_dso) or bool(active_builder_args_pt2)
 
 
 def use_et_backend() -> bool:
     global active_builder_args_dso
+    global active_builder_args_pt2
     global active_builder_args_pte
 
     # eager == aoti, which is when backend has not been explicitly set
-    if not (active_builder_args_pte or active_builder_args_dso):
-        return False
+    if (not active_builder_args_dso) and not (active_builder_args_pte) and not (active_builder_args_pt2):
+        return True
 
-    if active_builder_args_pte and active_builder_args_dso:
+    if sum(1 for builder in (active_builder_args_pte, active_builder_args_dso, active_builder_args_pt2)) > 1:
         raise RuntimeError(
-            "code generation needs to choose different implementations for DSO and PTE path. Please only use one export option, and call export twice if necessary!"
+            "code generation needs to choose different implementations for DSO, PT2, and PTE path. Please only use one export option, and call export twice if necessary!"
         )
 
     return bool(active_builder_args_pte)

cli.py

Lines changed: 12 additions & 0 deletions

@@ -210,6 +210,12 @@ def _add_export_output_path_args(parser) -> None:
         default=None,
         help="Output to the specified AOT Inductor .dso model file",
     )
+    output_path_parser.add_argument(
+        "--output-pt2-path",
+        type=str,
+        default=None,
+        help="Output directory for AOTInductor compiled artifacts",
+    )
 
 
 # Add CLI Args representing user provided exported model files
@@ -221,6 +227,12 @@ def _add_exported_model_input_args(parser) -> None:
         default=None,
         help="Use the specified AOT Inductor .dso model file",
     )
+    exported_model_path_parser.add_argument(
+        "--pt2-path",
+        type=Path,
+        default=None,
+        help="Use the specified directory containing AOT Inductor compiled files",
+    )
     exported_model_path_parser.add_argument(
         "--pte-path",
         type=Path,

eval.py

Lines changed: 3 additions & 1 deletion

@@ -233,7 +233,7 @@ def main(args) -> None:
 
     if compile:
         assert not (
-            builder_args.dso_path or builder_args.pte_path
+            builder_args.dso_path or builder_args.pte_path or builder_args.pt2_path
         ), "cannot compile exported model"
         global model_forward
         model_forward = torch.compile(
@@ -260,6 +260,8 @@ def main(args) -> None:
     )
     if builder_args.dso_path:
         print(f"For model {builder_args.dso_path}")
+    if builder_args.pt2_path:
+        print(f"For model {builder_args.pt2_path}")
     elif builder_args.pte_path:
         print(f"For model {builder_args.pte_path}")
     elif builder_args.checkpoint_path:

export.py

Lines changed: 15 additions & 5 deletions

@@ -20,7 +20,7 @@
 
 from build.utils import set_backend, set_precision
 from cli import add_arguments_for_verb, arg_init, check_args
-from export_aoti import export_model as export_model_aoti
+from export_aoti import export_model_so, export_model_pt2
 
 try:
     executorch_export_available = True
@@ -39,14 +39,16 @@ def main(args):
 
     print(f"Using device={builder_args.device}")
     set_precision(builder_args.precision)
-    set_backend(dso=args.output_dso_path, pte=args.output_pte_path)
+    set_backend(dso=args.output_dso_path, pte=args.output_pte_path, pt2=args.output_pt2_path)
 
     builder_args.dso_path = None
     builder_args.pte_path = None
+    builder_args.pt2_path = None
     builder_args.setup_caches = True
 
     output_pte_path = args.output_pte_path
     output_dso_path = args.output_dso_path
+    output_pt2_path = args.output_pt2_path
 
     if output_pte_path and builder_args.device != "cpu":
         print(
@@ -74,6 +76,7 @@ def main(args):
         )
         model_to_pte = model
         model_to_dso = model
+        model_to_pt2 = model
     else:
         if output_pte_path:
             _set_gguf_kwargs(builder_args, is_et=True, context="export")
@@ -83,12 +86,13 @@ def main(args):
             )
             _unset_gguf_kwargs(builder_args)
 
-        if output_dso_path:
+        if output_dso_path or output_pt2_path:
             _set_gguf_kwargs(builder_args, is_et=False, context="export")
-            model_to_dso = _initialize_model(
+            model_to_pt2 = _initialize_model(
                 builder_args,
                 quantize,
             )
+            model_to_dso = model_to_pt2
             _unset_gguf_kwargs(builder_args)
 
     with torch.no_grad():
@@ -104,10 +108,16 @@ def main(args):
                     "Export with executorch requested but ExecuTorch could not be loaded"
                 )
                 print(executorch_exception)
+
         if output_dso_path:
             output_dso_path = str(os.path.abspath(output_dso_path))
             print(f"Exporting model using AOT Inductor to {output_dso_path}")
-            export_model_aoti(model_to_dso, builder_args.device, output_dso_path, args)
+            export_model_so(model_to_dso, builder_args.device, output_dso_path, args)
+
+        if output_pt2_path:
+            output_pt2_path = str(os.path.abspath(output_pt2_path))
+            print(f"Exporting model using AOT Inductor to {output_pt2_path}")
+            export_model_pt2(model_to_pt2, builder_args.device, output_pt2_path, args)
 
 
 if __name__ == "__main__":

export_aoti.py

Lines changed: 24 additions & 1 deletion

@@ -12,7 +12,30 @@
 default_device = "cpu"
 
 
-def export_model(model: nn.Module, device, output_path, args=None):
+def export_model_pt2(model: nn.Module, device, output_path, args=None):
+    max_seq_length = 350
+
+    input = (
+        torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device),
+        torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
+    )
+
+    seq = Dim("seq", min=1, max=max_seq_length)
+    # Specify that the first dimension of each input is that batch size
+    dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
+
+    model.to(device)
+    pt2_path = torch._export.aot_compile(
+        model,
+        args=input,
+        options={"aot_inductor.output_path": output_path, "aot_inductor.package": True},
+        dynamic_shapes=dynamic_shapes,
+    )
+    print(f"The AOTInductor compiled files can be found at: {pt2_path}")
+    return pt2_path
+
+
+def export_model_so(model: nn.Module, device, output_path, args=None):
     max_seq_length = 350
 
     input = (
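Together with the loader shown earlier, the packaged flow is symmetric: aot_compile with the aot_inductor.package option writes a directory of artifacts, and load_package turns that directory back into a callable. Below is a toy end-to-end sketch under those assumptions; ToyModel and the ./toy_artifacts output directory are illustrative names, and the sketch assumes the same PyTorch nightly APIs this commit relies on.

```python
# Toy end-to-end sketch of the packaged AOTInductor flow introduced here:
# compile with aot_compile(..., options={"aot_inductor.package": True}) and
# load the resulting directory with load_package, as build/builder.py does.
# ToyModel and "./toy_artifacts" are hypothetical names for illustration.
import os

import torch
import torch.nn as nn
from torch.export import Dim
from torch._inductor.package import load_package


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


model = ToyModel().eval()
example = (torch.randn(4, 16),)

# Let the batch dimension vary at runtime, analogous to the "seq" Dim above.
batch = Dim("batch", min=1, max=32)
dynamic_shapes = {"x": {0: batch}}

os.makedirs("./toy_artifacts", exist_ok=True)
torch._export.aot_compile(
    model,
    args=example,
    dynamic_shapes=dynamic_shapes,
    options={"aot_inductor.output_path": "./toy_artifacts", "aot_inductor.package": True},
)

# Load the packaged artifacts back as a callable and run with a new batch size.
compiled = load_package("./toy_artifacts", "cpu")
print(compiled(torch.randn(8, 16)).shape)
```

Note that the README change in this commit adds a make -C step only before running the C++ aoti_run binary; the Python path in build/builder.py consumes the artifact directory directly via load_package.
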

generate.py

Lines changed: 3 additions & 3 deletions

@@ -91,8 +91,8 @@ def validate_build(
             reason = "model compilation for prefill"
         if self.compile:
             reason = "model compilation"
-        if builder_args.dso_path:
-            model_type = "DSO"
+        if builder_args.pt2_path:
+            model_type = "PT2"
         if builder_args.pte_path:
             model_type = "PTE"
         if model_type and reason:
@@ -103,7 +103,7 @@
     @classmethod
     def from_args(cls, args):
         sequential_prefill = (
-            args.sequential_prefill or bool(args.dso_path) or bool(args.pte_path)
+            args.sequential_prefill or bool(args.pt2_path) or bool(args.pte_path)
         )
 
         return cls(
