
Commit 2c8b7b5

Update AOTI package
1 parent: f730056

File tree

12 files changed (+171, −86 lines)


.ci/scripts/validate.sh

Lines changed: 18 additions & 18 deletions
@@ -133,60 +133,60 @@ function generate_aoti_model_output() {
 echo "******************************************"
 echo "************** non-quantized *************"
 echo "******************************************"
-python3 -W ignore torchchat.py export --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path "${MODEL_DIR}/${MODEL_NAME}.so" --device "$TARGET_DEVICE" || exit 1
-python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path "$MODEL_DIR/${MODEL_NAME}.so" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+python3 -W ignore torchchat.py export --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path "${MODEL_DIR}/${MODEL_NAME}.pt2" --device "$TARGET_DEVICE" || exit 1
+python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --aoti-package-path "$MODEL_DIR/${MODEL_NAME}.pt2" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"

 echo "******************************************"
 echo "******* Emb: channel-wise quantized ******"
 echo "******************************************"
-python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
+python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"

 echo "******************************************"
 echo "******** Emb: group-wise quantized *******"
 echo "******************************************"
-python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
+python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"

 echo "***********************************************"
 echo "******* Emb: 4bit channel-wise quantized ******"
 echo "***********************************************"
-python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 0, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 0, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
+python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"

 echo "***********************************************"
 echo "******** Emb: 4bit group-wise quantized *******"
 echo "***********************************************"
-python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"embedding" : {"bitwidth": 4, "groupsize": 8, "packed": "True"}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
+python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"

 if [ "${EXCLUDE_INT8_QUANT:-false}" == false ]; then
 echo "******************************************"
 echo "******* INT8 channel-wise quantized ******"
 echo "******************************************"
-python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 0}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
+python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"

 echo "******************************************"
 echo "******** INT8 group-wise quantized *******"
 echo "******************************************"
-python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int8" : {"bitwidth": 8, "groupsize": 8}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
+python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
 fi
 echo "******************************************"
 echo "******** INT4 group-wise quantized *******"
 echo "******************************************"
 if [[ "$TARGET_DEVICE" != "cuda" || "$DTYPE" == "bfloat16" ]]; then
 # For CUDA, only bfloat16 makes sense for int4 mm kernel
-python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" || exit 1
-python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
+python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant '{"linear:int4" : {"groupsize": 32}}' --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" || exit 1
+python3 -W ignore torchchat.py generate --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --temperature 0 --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1
 .ci/scripts/check_gibberish "$MODEL_DIR/output_aoti"
 fi
 done
@@ -285,8 +285,8 @@ function eval_model_sanity_check() {
 echo "******** INT4 group-wise quantized (AOTI) *******"
 echo "*************************************************"
 if [ "$DTYPE" != "float16" ]; then
-python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-dso-path ${MODEL_DIR}/${MODEL_NAME}.so --dynamic-shapes --device "$TARGET_DEVICE" || exit 1
-python3 -W ignore torchchat.py eval --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path ${MODEL_DIR}/${MODEL_NAME}.so --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
+python3 -W ignore torchchat.py export --dtype ${DTYPE} --quant "$QUANT_OPTIONS" --checkpoint-path "$CHECKPOINT_PATH" --output-aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --dynamic-shapes --device "$TARGET_DEVICE" || exit 1
+python3 -W ignore torchchat.py eval --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --aoti-package-path ${MODEL_DIR}/${MODEL_NAME}.pt2 --device "$TARGET_DEVICE" --limit 5 > "$MODEL_DIR/output_eval_aoti" || exit 1
 cat "$MODEL_DIR/output_eval_aoti"
 fi;
 fi;
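
Taken together, the hunks above make one mechanical substitution: at export time `--output-dso-path model.so` becomes `--output-aoti-package-path model.pt2`, and at generate/eval time `--dso-path` becomes `--aoti-package-path`. A minimal round trip distilled from the script (a sketch; the checkpoint path, device, and prompt are placeholders, not values from this commit):

```bash
# Compile the model with AOT Inductor and package it as a .pt2 archive.
python3 torchchat.py export \
  --checkpoint-path "$CHECKPOINT_PATH" \
  --output-aoti-package-path /tmp/model.pt2 \
  --device cpu || exit 1

# Reuse the packaged model for generation instead of recompiling it.
python3 torchchat.py generate \
  --checkpoint-path "$CHECKPOINT_PATH" \
  --aoti-package-path /tmp/model.pt2 \
  --prompt "Once upon a time" \
  --device cpu || exit 1
```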

.github/workflows/pull.yml

Lines changed: 4 additions & 4 deletions
@@ -378,8 +378,8 @@ jobs:

 echo "::group::Run inference with quantize file"
 if [ $(uname -s) == Darwin ]; then
-python3 torchchat.py export --output-dso-path /tmp/model.so --quantize torchchat/quant_config/cuda.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth"
-python3 torchchat.py generate --dso-path /tmp/model.so --checkpoint "./checkpoints/${REPO_NAME}/model.pth"~
+python3 torchchat.py export --output-aoti-package-path /tmp/model.pt2 --quantize torchchat/quant_config/cuda.json --checkpoint "./checkpoints/${REPO_NAME}/model.pth"
+python3 torchchat.py generate --aoti-package-path /tmp/model.pt2 --checkpoint "./checkpoints/${REPO_NAME}/model.pth"~
 fi
 echo "::endgroup::"

@@ -1003,8 +1003,8 @@ jobs:

 for dtype in fp32 fp16 bf16 fast fast16; do
 echo "Running export + runner with dtype=$dtype"
-python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --dtype $dtype --output-dso-path /tmp/model.so
-./cmake-out/aoti_run /tmp/model.so -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
+python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --dtype $dtype --output-aoti-package-path /tmp/model.pt2
+./cmake-out/aoti_run /tmp/model.pt2 -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
 done

 echo "Tests complete."

.github/workflows/runner-cuda-dtype.yml

Lines changed: 2 additions & 2 deletions
@@ -56,9 +56,9 @@ jobs:
 for DTYPE in bfloat16; do
 python torchchat.py generate --dtype ${DTYPE} --checkpoint-path ${MODEL_DIR}/stories15M.pt --temperature 0 --prompt "${PROMPT}" --device cuda

-python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-dso-path /tmp/model.so
+python torchchat.py export --checkpoint-path ${MODEL_DIR}/stories15M.pt --output-pt2-path /tmp/model.pt2

-./cmake-out/aoti_run /tmp/model.so -d CUDA -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"
+./cmake-out/aoti_run /tmp/model.pt2 -d CUDA -z ${MODEL_DIR}/tokenizer.model -i "${PROMPT}"

 done

README.md

Lines changed: 10 additions & 11 deletions
@@ -182,7 +182,7 @@ python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy an
 [skip default]: end

 ### Server
-This mode exposes a REST API for interacting with a model. 
+This mode exposes a REST API for interacting with a model.
 The server follows the [OpenAI API specification](https://platform.openai.com/docs/api-reference/chat) for chat completions.

 To test out the REST API, **you'll need 2 terminals**: one to host the server, and one to send the request.
@@ -255,13 +255,14 @@ Use the "Max Response Tokens" slider to limit the maximum number of tokens gener
 ## Desktop/Server Execution

 ### AOTI (AOT Inductor)
-[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a [DSO](https://en.wikipedia.org/wiki/Shared_library) model (represented by a file with extension `.so`)
+[AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a zipped PT2 file containing all the artifacts generated by AOTInductor, and a [.so](https://en.wikipedia.org/wiki/Shared_library) file with the runnable contents
 that is then loaded for inference. This can be done with both Python and C++ enviroments.

 The following example exports and executes the Llama3.1 8B Instruct
 model. The first command compiles and performs the actual export.
-```
-python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.so
+
+```bash
+python3 torchchat.py export llama3.1 --output-aoti-package-path exportedModels/llama3_1_artifacts.pt2
 ```

 > [!NOTE]
@@ -273,12 +274,11 @@ case visit our [customization guide](docs/model_customization.md).

 ### Run in a Python Enviroment

-To run in a python enviroment, use the generate subcommand like before, but include the dso file.
+To run in a python enviroment, use the generate subcommand like before, but include the pt2 file.

+```bash
+python3 torchchat.py generate llama3.1 --aoti-package-path exportedModels/llama3_1_artifacts.pt2 --prompt "Hello my name is"
 ```
-python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
-```
-**Note:** Depending on which accelerator is used to generate the .dso file, the command may need the device specified: `--device (cuda | cpu)`.


 ### Run using our C++ Runner
@@ -288,11 +288,10 @@ To run in a C++ enviroment, we need to build the runner binary.
 torchchat/utils/scripts/build_native.sh aoti
 ```

-Then run the compiled executable, with the exported DSO from earlier.
+Then run the compiled executable, with the pt2.
 ```bash
-cmake-out/aoti_run exportedModels/llama3.1.so -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
+cmake-out/aoti_run exportedModels/llama3_1_artifacts.pt2 -z `python3 torchchat.py where llama3.1`/tokenizer.model -l 3 -i "Once upon a time"
 ```
-**Note:** Depending on which accelerator is used to generate the .dso file, the runner may need the device specified: `-d (CUDA | CPU)`.

 ## Mobile Execution
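
Since the exported `.pt2` described in the new README text is a zip archive of AOTInductor artifacts (including the `.so` with the runnable contents), it can be inspected with ordinary zip tooling. A quick sanity check, assuming the export path used in the README example (the exact archive layout varies by PyTorch version):

```bash
# List the packaged artifacts; expect a compiled .so plus metadata files.
unzip -l exportedModels/llama3_1_artifacts.pt2
```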

install/install_requirements.sh

Lines changed: 3 additions & 3 deletions
@@ -47,10 +47,10 @@ fi
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-PYTORCH_NIGHTLY_VERSION=dev20240814
+PYTORCH_NIGHTLY_VERSION=dev20240913

 # Nightly version for torchvision
-VISION_NIGHTLY_VERSION=dev20240814
+VISION_NIGHTLY_VERSION=dev20240913

 # Nightly version for torchtune
 TUNE_NIGHTLY_VERSION=dev20240916
@@ -74,7 +74,7 @@ fi

 # pip packages needed by exir.
 REQUIREMENTS_TO_INSTALL=(
-torch=="2.5.0.${PYTORCH_NIGHTLY_VERSION}"
+torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
 torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}"
 torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}"
 )
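
For context, these pins feed a pip invocation later in the script. A sketch of the resulting install, assuming the standard PyTorch nightly channel (the index URL is an assumption, not part of this diff; the real script assembles `REQUIREMENTS_TO_INSTALL` and supplies its own index options):

```bash
# Pinned nightly versions from the hunks above.
PYTORCH_NIGHTLY_VERSION=dev20240913
VISION_NIGHTLY_VERSION=dev20240913
TUNE_NIGHTLY_VERSION=dev20240916

# Assumed install command for illustration only.
pip3 install --pre \
  torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}" \
  torchvision=="0.20.0.${VISION_NIGHTLY_VERSION}" \
  torchtune=="0.3.0.${TUNE_NIGHTLY_VERSION}" \
  --extra-index-url https://download.pytorch.org/whl/nightly/cpu
```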

runner/run.cpp

Lines changed: 4 additions & 15 deletions
@@ -31,10 +31,7 @@ LICENSE file in the root directory of this source tree.
 #endif

 #ifdef __AOTI_MODEL__
-#include <torch/csrc/inductor/aoti_runner/model_container_runner_cpu.h>
-#ifdef USE_CUDA
-#include <torch/csrc/inductor/aoti_runner/model_container_runner_cuda.h>
-#endif
+#include <torch/csrc/inductor/aoti_package/model_package_loader.h>
 torch::Device aoti_device(torch::kCPU);

 #else // __ET_MODEL__
@@ -93,7 +90,7 @@ typedef struct {
 RunState state; // buffers for the "wave" of activations in the forward pass

 #ifdef __AOTI_MODEL__
-torch::inductor::AOTIModelContainerRunner* runner;
+torch::inductor::AOTIModelPackageLoader* runner;
 #else // __ET_MODEL__
 Module* runner;
 #endif
@@ -143,16 +140,8 @@ void build_transformer(
 malloc_run_state(&t->state, &t->config);

 #ifdef __AOTI_MODEL__
-#ifdef USE_CUDA
-  if (aoti_device.type() == torch::kCUDA) {
-    t->runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path);
-    aoti_device = torch::Device(torch::kCUDA);
-  } else {
-#else
-  {
-#endif
-    t->runner = new torch::inductor::AOTIModelContainerRunnerCpu(model_path);
-  }
+  t->runner = new torch::inductor::AOTIModelPackageLoader(model_path);
+  aoti_device = t->runner->get_metadata()["AOTI_DEVICE_KEY"] == "cpu" ? torch::Device(torch::kCPU) : torch::Device(torch::kCUDA);
 #else //__ET_MODEL__
 t->runner = new Module(
 /* path to PTE model */ model_path,
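
The deleted block chose a runner class at build time (`AOTIModelContainerRunnerCuda` vs. `AOTIModelContainerRunnerCpu`) behind a `USE_CUDA` guard; the replacement asks the package itself. `AOTIModelPackageLoader` exposes the metadata recorded at export, including `AOTI_DEVICE_KEY`, so one runner binary can serve both CPU and CUDA packages. A sketch of the resulting workflow (paths and variables are illustrative, not from the commit):

```bash
# The target device is baked into the package at export time...
python3 torchchat.py export --checkpoint-path "$CHECKPOINT" --device cuda \
  --output-aoti-package-path /tmp/model_cuda.pt2

# ...so the same runner binary picks CUDA by reading AOTI_DEVICE_KEY from
# the package metadata at load time; no rebuild is required.
./cmake-out/aoti_run /tmp/model_cuda.pt2 -z "$TOKENIZER" -i "Once upon a time"
```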
