Commit a12a27e

Merge branch 'main' into main
2 parents: d237ba8 + 70260eb

File tree

16 files changed (+1179, -170 lines)

.github/workflows/pull.yml

Lines changed: 6 additions & 27 deletions

@@ -1092,32 +1092,11 @@ jobs:
         id: install-torchao-ops
         run: |
           bash torchchat/utils/scripts/build_torchao_ops.sh
-      - name: Set git shas
-        id: setup-hash
-        run: |
-          export TORCHCHAT_ROOT=${PWD}
-          echo "et-git-hash=$(cat ${TORCHCHAT_ROOT}/install/.pins/et-pin.txt)" >> "$GITHUB_ENV"
-      - name: Load or install ET
-        id: install-et
-        uses: actions/cache@v4
-        with:
-          path: |
-            ./et-build
-            ./torchchat/utils/scripts/install_et.sh
-          key: et-build-${{runner.os}}-${{runner.arch}}-${{env.et-git-hash}}-${{ hashFiles('**/install_et.sh') }}
-      - if: ${{ steps.install-et.outputs.cache-hit != 'true' }}
-        continue-on-error: true
+      - name: Install ET
         run: |
           echo "Installing ExecuTorch"
+          export TORCHCHAT_ROOT=${PWD}
           bash torchchat/utils/scripts/install_et.sh
-      - name: Install ExecuTorch python
-        run: |
-          echo "Install ExecuTorch python"
-          export TORCHCHAT_ROOT=$PWD
-          export ET_BUILD_DIR="et-build"
-          ENABLE_ET_PYBIND="${1:-true}"
-          source "torchchat/utils/scripts/install_utils.sh"
-          install_executorch_python_libs $ENABLE_ET_PYBIND
       - name: Install runner
         run: |
           echo "Installing runner"
@@ -1132,14 +1111,14 @@ jobs:
           wget -O ./tokenizer.model https://github.com/karpathy/llama2.c/raw/master/tokenizer.model
           export PRMT="Once upon a time in a land far away"
           echo "Generate eager"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           echo "Generate compile"
-          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile
+          python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile
           echo "Export and run ET (C++ runner)"
-          python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          python torchchat.py export stories110M --output-pte-path ./model.pte --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/et_run ./model.pte -z ./tokenizer.model -t 0 -i "${PRMT}"
           echo "Export and run AOTI (C++ runner)"
-          python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
+          python torchchat.py export stories110M --output-dso-path ./model.so --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}'
           ./cmake-out/aoti_run ./model.so -z ./tokenizer.model -t 0 -i "${PRMT}"
           echo "Generate AOTI"
           python torchchat.py generate stories110M --dso-path ./model.so --prompt "${PRMT}"
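
The updated CI steps pass both experimental schemes in a single --quantize JSON string. As a rough, standalone illustration (not part of this commit), the same config can be assembled in Python and handed to the torchchat CLI; the scheme names and fields below are copied verbatim from the workflow lines above, everything else is a sketch:

```python
import json
import shlex
import subprocess  # only needed if you uncomment the run at the bottom

# Quantization config matching the workflow above: 2-bit groupwise embeddings
# plus 8-bit-activation / 3-bit-weight dynamically quantized linear layers.
quantize_cfg = {
    "embedding:wx": {"bitwidth": 2, "groupsize": 32},
    "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": False},
}

cmd = [
    "python", "torchchat.py", "generate", "stories110M",
    "--temperature", "0",
    "--prompt", "Once upon a time in a land far away",
    "--device", "cpu",
    "--dtype", "float32",
    "--quantize", json.dumps(quantize_cfg),
]
print(shlex.join(cmd))             # inspect the exact command line
# subprocess.run(cmd, check=True)  # uncomment to actually run it
```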

README.md

Lines changed: 8 additions & 7 deletions

@@ -171,7 +171,7 @@ python3 torchchat.py download llama3.1
 <summary>Additional Model Inventory Management Commands</summary>
 
 ### Where
-This subcommand shows location of a particular model.
+This subcommand shows the location of a particular model.
 ```bash
 python3 torchchat.py where llama3.1
 ```
@@ -216,7 +216,6 @@ This mode generates text based on an input prompt.
 python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy and his bear"
 ```
 
-[skip default]: end
 
 ### Server
 This mode exposes a REST API for interacting with a model.
@@ -286,14 +285,16 @@ First, follow the steps in the Server section above to start a local server. The
 streamlit run torchchat/usages/browser.py
 ```
 
+[skip default]: end
+
 Use the "Max Response Tokens" slider to limit the maximum number of tokens generated by the model for each response. Click the "Reset Chat" button to remove the message history and start a fresh chat.
 
 
 ## Desktop/Server Execution
 
 ### AOTI (AOT Inductor)
 [AOTI](https://pytorch.org/blog/pytorch2-2/) compiles models before execution for faster inference. The process creates a [DSO](https://en.wikipedia.org/wiki/Shared_library) model (represented by a file with extension `.so`)
-that is then loaded for inference. This can be done with both Python and C++ enviroments.
+that is then loaded for inference. This can be done with both Python and C++ environments.
 
 The following example exports and executes the Llama3.1 8B Instruct
 model. The first command compiles and performs the actual export.
@@ -308,9 +309,9 @@ python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.s
 For more details on quantization and what settings to use for your use
 case visit our [customization guide](docs/model_customization.md).
 
-### Run in a Python Enviroment
+### Run in a Python Environment
 
-To run in a python enviroment, use the generate subcommand like before, but include the dso file.
+To run in a python environment, use the generate subcommand like before, but include the dso file.
 
 ```
 python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is"
@@ -377,7 +378,7 @@ While ExecuTorch does not focus on desktop inference, it is capable
 of doing so. This is handy for testing out PTE
 models without sending them to a physical device.
 
-Specifically there are 2 ways of doing so: Pure Python and via a Runner
+Specifically, there are 2 ways of doing so: Pure Python and via a Runner
 
 <details>
 <summary>Deploying via Python</summary>
@@ -501,7 +502,7 @@ The following assumes you've completed the steps for [Setting up ExecuTorch](#se
 and use [this script](https://github.com/pytorch/executorch/blob/main/build/build_android_llm_demo.sh) to build the AAR library.
 
 <p align="center">
-<img src="https://pytorch.org/executorch/main/_static/img/android_llama_app.png" width="600" alt="Android app running a LlaMA model">
+<img src="https://pytorch.org/executorch/main/_static/img/chat.png" width="600" alt="Android app running a LlaMA model">
 </p>
 

assets/view.jpg

93.3 KB

dist_run.py

Lines changed: 7 additions & 3 deletions

@@ -20,14 +20,14 @@
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs
 
-from torchchat.distributed.logging_utils import SingletonLogger
-
 # TODO - these are not distributed specific, consider moving to new package
 from torchchat.distributed.checkpoint_utils import (
     get_hf_config_file,
     load_weights_from_hf_format,
     load_weights_from_torchchat_format,
 )
+
+from torchchat.distributed.logging_utils import SingletonLogger
 from torchchat.distributed.utils import (
     bytes_to_readable,
     Color as color,
@@ -153,7 +153,9 @@ def _load_model_weights(
         # This format stands for:
         # single binary file, OR
         # multiple binary files without index files.
-        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+        load_weights_from_torchchat_format(
+            stage_module, distribution, device, model_config
+        )
     else:
         raise ValueError(f"Unknown checkpoint format: {chpt_from}")
 
@@ -593,9 +595,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     parser.add_argument(
         "model_name",
         type=str,
+        default="llama3",
         help="Name of the model to load",
         choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
     )
+
     parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
     parser.add_argument(
         "--ntokens",

docs/quantization.md

Lines changed: 17 additions & 8 deletions

@@ -120,23 +120,32 @@ python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my n
 
 ## Experimental TorchAO lowbit kernels
 
+WARNING: These kernels only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+
 ### Use
-The quantization scheme a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
+
+#### linear:a8wxdq
+The quantization scheme linear:a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
 It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7), groupsize, and has_weight_zeros (true, false).
 The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
 Roughly speaking, {bitwidth: 4, groupsize: 32, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.
 
-You should expect high performance on ARM CPU if bitwidth is 1, 2, 3, 4, or 5 and groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+You should expect high performance on ARM CPU if groupsize is divisible by 16. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
+
+#### embedding:wx
+The quantization scheme embedding:wx quantizes embeddings in a groupwise manner with the specified bitwidth and groupsize. It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize. Unlike linear:a8wxdq, embedding:wx always quantizes with scales and zeros.
+
+You should expect high performance on ARM CPU if groupsize is divisible by 32. With other platforms and argument choices, a slow fallback kernel will be used. You will see warnings about this during quantization.
 
 ### Setup
-To use a8wxdq, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
+To use linear:a8wxdq and embedding:wx, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.
 
 From the torchchat root directory, run
 ```
 sh torchchat/utils/scripts/build_torchao_ops.sh
 ```
 
-This should take about 10 seconds to complete. Once finished, you can use a8wxdq in torchchat.
+This should take about 10 seconds to complete.
 
 Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts the build the runners.
 
@@ -156,17 +165,17 @@ Below we show how to use the new kernels. Except for ExecuTorch, you can specif
 
 #### Eager mode
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
 ```
 
 #### torch.compile
 ```
-OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
+OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
 ```
 
 #### AOTI
 ```
-OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-dso llama3_1.so
+OMP_NUM_THREADS=6 python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-dso llama3_1.so
 OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
 ```
 
@@ -178,7 +187,7 @@ OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cac
 
 #### ExecuTorch
 ```
-python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-pte llama3_1.pte
+python torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"embedding:wx": {"bitwidth": 2, "groupsize": 32}, "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": false}}' --output-pte llama3_1.pte
 ```
 
 Note: only the ExecuTorch C++ runner in torchchat when built using the instructions in the setup can run the exported *.pte file. It will not work with the `python torchchat.py generate` command.
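
The two schemes documented above have different fast-path requirements (bitwidth 1-7 for both; groupsize divisible by 16 for linear:a8wxdq and by 32 for embedding:wx). A small, hypothetical helper like the following can sanity-check a --quantize config against those stated rules before a long export; it is only an illustration of the doc text, not part of torchchat:

```python
def check_lowbit_config(quantize_cfg: dict) -> list[str]:
    """Return warnings for settings the docs say will hit the slow fallback kernel."""
    warnings = []
    fast_group_multiple = {"linear:a8wxdq": 16, "embedding:wx": 32}
    for scheme, opts in quantize_cfg.items():
        if scheme not in fast_group_multiple:
            continue  # not one of the experimental lowbit schemes
        bitwidth, groupsize = opts["bitwidth"], opts["groupsize"]
        if bitwidth not in range(1, 8):
            warnings.append(f"{scheme}: bitwidth must be 1-7, got {bitwidth}")
        if groupsize % fast_group_multiple[scheme] != 0:
            warnings.append(
                f"{scheme}: groupsize {groupsize} is not a multiple of "
                f"{fast_group_multiple[scheme]}; expect the slow fallback kernel"
            )
    return warnings


# The config used throughout this commit passes both checks.
cfg = {
    "embedding:wx": {"bitwidth": 2, "groupsize": 32},
    "linear:a8wxdq": {"bitwidth": 3, "groupsize": 128, "has_weight_zeros": False},
}
print(check_lowbit_config(cfg))  # []
```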

install/.pins/torchao-pin.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-49b1fb61c8b8eceda755579a2fd92c756d822de2
+c8f1174a06dcc0102849c8348ca6573bde8847a9

install/requirements.txt

Lines changed: 3 additions & 0 deletions

@@ -30,3 +30,6 @@ streamlit
 
 # Server mode
 flask
+
+# eval
+lm_eval==0.4.2

torchchat/cli/builder.py

Lines changed: 36 additions & 22 deletions

@@ -16,20 +16,14 @@
 import torch._inductor.config
 import torch.nn as nn
 
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.elastic.utils.distributed import get_free_port
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.training import set_default_dtype
+from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torchchat.model import Model, ModelArgs, ModelType
 
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -40,6 +34,14 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
 
 @dataclass
 class BuilderArgs:
@@ -55,7 +57,10 @@ class BuilderArgs:
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
-    use_distributed: bool = False
+    distributed: bool = False
+    pp: int = 1
+    tp: int = 1
+    chpt_from: str = "hf"
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -87,7 +92,9 @@ def __post_init__(self):
         ]
         for param, param_msg in ignored_params:
             if param:
-                print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified")
+                print(
+                    f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified"
+                )
             else:
                 self.prefill_possible = True
 
@@ -153,7 +160,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             dtype = torch.float16
         else:
            dtype = name_to_dtype(args.dtype, args.device)
-
+        # distributed args
+        distributed = getattr(args, "distributed", False)
+        pp = getattr(args, "pp", 1)
+        tp = getattr(args, "tp", 1)
+        chpt_from = getattr(args, "chpt_from", "hf")
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -167,7 +178,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             device=args.device,
             precision=dtype,
             setup_caches=(output_dso_path or output_pte_path),
-            use_distributed=args.distributed,
+            distributed=distributed,
+            pp=pp,
+            tp=tp,
+            chpt_from=chpt_from,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
@@ -397,10 +411,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         # does not host any actual values, need to reinitialize them in the actual
         # device. Only do those buffer initialization, without initializing the entire
         # model.
-        decoder_config = model.config.transformer_args['decoder']
-        head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
-        max_seq_len = decoder_config['max_seq_len']
-        rope_base = decoder_config['rope_base']
+        decoder_config = model.config.transformer_args["decoder"]
+        head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"]
+        max_seq_len = decoder_config["max_seq_len"]
+        rope_base = decoder_config["rope_base"]
         for submodule in model.modules():
             if isinstance(submodule, Llama3ScaledRoPE):
                 submodule.__init__(head_dim, max_seq_len, rope_base)
@@ -476,18 +490,19 @@ def _maybe_parallelize_model(
 
 
 def _load_model(builder_args: BuilderArgs) -> Model:
-    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
+    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    elif builder_args.use_distributed:
-        model = _init_model_on_meta_device(builder_args)
+    # elif builder_args.use_distributed:
+    #     model = _init_model_on_meta_device(builder_args)
     else:
        model = _load_model_default(builder_args)
-    model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
+    # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
 
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
+
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
@@ -496,7 +511,6 @@ def _initialize_model(
     support_tensor_subclass: bool = True,
 ) -> Model:
     print("Loading model...")
-
     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")
         is_dso = builder_args.dso_path is not None
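
The new BuilderArgs fields are read with getattr and explicit defaults, so an argparse namespace that predates the distributed flags still builds cleanly. A minimal standalone sketch of that pattern (field names copied from the diff; the class name and everything else here are simplified stand-ins):

```python
import argparse
from dataclasses import dataclass


@dataclass
class DistributedArgs:
    distributed: bool = False
    pp: int = 1            # pipeline-parallel degree
    tp: int = 1            # tensor-parallel degree
    chpt_from: str = "hf"  # checkpoint format to load from

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "DistributedArgs":
        # getattr with a default keeps this working for callers whose
        # parser never defined the distributed flags.
        return cls(
            distributed=getattr(args, "distributed", False),
            pp=getattr(args, "pp", 1),
            tp=getattr(args, "tp", 1),
            chpt_from=getattr(args, "chpt_from", "hf"),
        )


# An old-style namespace without any distributed flags still works:
print(DistributedArgs.from_args(argparse.Namespace()))
# DistributedArgs(distributed=False, pp=1, tp=1, chpt_from='hf')
```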
