We’re happy to officially launch torchao, a PyTorch native library that makes your models faster and smaller by leveraging low bit dtypes, quantization and sparsity.
We benchmarked our techniques on popular GenAI models like Llama 3 and diffusion models and saw minimal drops in accuracy. Unless otherwise noted, the baselines are bf16 run on an A100 80GB GPU.
Our topline metrics for Llama 3 are:

* 97% speedup for Llama 3 8B inference using autoquant with int4 weight only quantization and hqq
* 73% peak VRAM reduction for Llama 3.1 8B inference at 128K context length with a quantized KV cache
* 50% speedup for Llama 3 70B pretraining using float8 training on H100
* 30% peak VRAM reduction for Llama 3 8B using 4 bit quantized optimizers

Our topline metrics for diffusion model inference:

* 53% speedup using float8 dynamic quantization inference with float8 row-wise scaling on flux1.dev on H100
* 50% reduction in model VRAM for CogVideoX using int8 dynamic quantization

Below we'll walk through some of the techniques available in torchao that you can apply to your models for inference and training.
## Inference
[Our inference quantization algorithms](https://github.com/pytorch/ao/tree/main/torchao/quantization) work over arbitrary PyTorch models that contain nn.Linear layers. Weight only and dynamic activation quantization for various dtypes and sparse layouts can be chosen using our top level `quantize_` API.

```python
from torchao.quantization import (
    quantize_,
    int4_weight_only,
)
quantize_(model, int4_weight_only())
```

Sometimes quantizing a layer can make it slower because of overhead, so if you'd rather have us pick how to quantize each layer in a model for you, you can instead run
```python
model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
```
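
Autoquant defers the actual per-layer benchmarking until the model sees real inputs, so the wrapping step is typically followed by a forward pass on representative data. Below is a minimal sketch, assuming a toy model and dummy input on a CUDA device; the exact trigger mechanics may vary across torchao versions:

```python
import torch
import torchao

# toy model and example input, purely for illustration
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).cuda().to(torch.bfloat16)
example_input = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")

model = torchao.autoquant(torch.compile(model, mode='max-autotune'))
model(example_input)  # first call benchmarks candidate quantization schemes per layer
```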
The `quantize_` API has a few different options depending on whether your model is compute bound or memory bound; a couple of examples are sketched below.
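
As a rough illustration, weight only quantization tends to help memory bound workloads such as small-batch LLM decoding, while dynamic activation quantization targets compute bound ones. This is a hedged sketch, not an exhaustive list: it assumes the `int8_dynamic_activation_int8_weight`, `float8_dynamic_activation_float8_weight` and `PerRow` exports in `torchao.quantization`, which may differ across versions.

```python
import torch
from torchao.quantization import (
    quantize_,
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    float8_dynamic_activation_float8_weight,
    PerRow,
)

# toy stand-in for a real model (illustration only)
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().to(torch.bfloat16)

# pick ONE option; quantize_ rewrites the model's nn.Linear weights in place

# memory bound (e.g. batch size 1 decoding): weight only quantization
quantize_(model, int4_weight_only())

# compute bound (e.g. large batches): dynamic activation + weight quantization
# quantize_(model, int8_dynamic_activation_int8_weight())

# H100 class hardware: float8 dynamic quantization with row-wise scaling,
# as in the flux1.dev result mentioned above
# quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
```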
Post training quantization, especially at less than 4 bit, can suffer from serious accuracy degradation.
## Training

torchao provides easy to use e2e workflows for reducing the precision of training compute and distributed communications, starting with float8 for `torch.nn.Linear` layers. Here is a one-liner to convert the compute gemms of your training run to float8:
```python
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(model)
```
For an e2e example of how to speed up Llama 3 70B pretraining by up to **1.5x** with float8, see our [README](https://github.com/pytorch/ao/tree/main/torchao/float8), and torchtitan's [blog](https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359) and [float8 recipe](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md).
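
In practice you often want to skip layers whose shapes don't map well onto float8 gemms. Here is a hedged sketch, assuming `convert_to_float8_training` accepts a `module_filter_fn` callback as in recent torchao versions:

```python
import torch
from torchao.float8 import convert_to_float8_training

# toy model standing in for a real transformer (illustration only)
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.Linear(4096, 100),  # a head whose dims aren't float8 friendly
)

def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    # skip linears whose dimensions aren't multiples of 16,
    # since float8 gemms need aligned shapes
    if isinstance(mod, torch.nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True

convert_to_float8_training(model, module_filter_fn=module_filter_fn)
```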
We are expanding our training workflows to more dtypes and layouts.
Inspired by Bits and Bytes, we've also added prototype support for 8 and 4 bit optimizers as a drop in replacement for AdamW.
```python
from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit
optim = AdamW8bit(model.parameters())
```
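
Because these are meant as drop-in replacements, the rest of the training loop stays the same. A minimal sketch with a hypothetical toy model and random data on a CUDA device:

```python
import torch
from torchao.prototype.low_bit_optim import AdamW8bit

# toy model and data, purely for illustration
model = torch.nn.Linear(1024, 1024).cuda()
optim = AdamW8bit(model.parameters())  # swap in AdamW4bit to go even lower

for _ in range(10):
    x = torch.randn(8, 1024, device="cuda")
    loss = model(x).pow(2).mean()  # dummy loss
    loss.backward()
    optim.step()
    optim.zero_grad()
```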
There are a lot of things we're excited about next, from going lower than 4 bit and performant kernels for high-throughput inference to expanding to more layers, scaling types or granularities, MX hardware support, and supporting more hardware backends. If any of the above sounds exciting, you can follow our progress at [https://github.com/pytorch/ao](https://github.com/pytorch/ao).
If you’re interested in working on torchao, we’ve created a [contributors guide](https://github.com/pytorch/ao/issues/391), and if you have any questions we hang out on the `#torchao` channel on [discord.gg/gpumode](http://discord.gg/gpumode).
## Acknowledgements
We are fortunate to stand on the shoulders of giants and collaborate with some of the best people in open source. Thank you!
1. Bits and Bytes for pioneering work in low bit optimizers and QLoRA
2. Answer.ai for their engineering work to get FSDP and QLoRA composing
3. Mobius Labs for the lovely back and forths on quantization algorithms and low bit kernels
4. HuggingFace transformers for their help in battle testing and integrating our work
5. HuggingFace diffusers for our collaboration on extensive benchmarks and best practices
6. torch.compile so we could write our algorithms in pure PyTorch