
Commit 386d1b3

mcr229 authored and kirklandsign committed
Default Uninitialized Llama2 Weights to Zeros, and Provide Better Quantization for Example Models (#9634)
### Summary

Changes:

1. When initializing Llama2 for aot_compiler: since checkpoints can only be downloaded from Hugging Face, we initialize Llama2 with uninitialized weights when no checkpoint is given. The problem is that quantization can then fail with histogram errors if the uninitialized values are NaN. We fix this by initializing the weights to zeros when no checkpoint is provided, which guarantees the quantization step can still run.

2. Quant type in the AoT compiler: among the model options available to XNNPACK, everything is quantized with per-tensor static quantization. This isn't the best option for all of the models. Transformer-based models like Llama and MobileBert generally prefer dynamically quantized per-channel weights, whereas CNNs like MobileNet prefer statically quantized per-channel weights. We add this quant type to the existing model options. This also helps with test timeouts: per-tensor static quantization on a model like Llama can take a long time because it introduces many q/dq nodes and creates complex partitions, so proposing partitions is slow due to the repeated BFS used to find the largest possible partition. Specifying the more apt scheme, such as dynamic per-channel quantization, avoids this complexity.

Overall, this should fix the flaky [nan, nan] errors in the quantization histogram, and it should also help with CI timing out.

### Test plan

OSS XNNPACK CI for all model delegation

cc @digantdesai @cbilgin
1 parent 0b1c29b commit 386d1b3
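For context on the "[nan, nan]" histogram errors the summary mentions, here is a minimal standalone sketch (not part of this commit, plain PyTorch only): histogram-based observers need a finite min/max range over the observed values, and uninitialized weights that happen to contain NaN make that range non-finite.

```python
import torch

# Uninitialized storage can contain NaN; a histogram over such values
# has no finite range, which surfaces as the flaky "[nan, nan]" error.
weights = torch.full((8,), float("nan"))
try:
    torch.histc(weights, bins=10)
except RuntimeError as e:
    print(e)  # e.g. a "range of [nan, nan] is not finite"-style error

# Zero-filled weights give a finite (if degenerate) range, so observers run.
torch.histc(torch.zeros(8), bins=10)
```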

File tree

5 files changed: +60 -28 lines changed


.ci/scripts/gather_test_models.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -14,7 +14,7 @@
 from typing import Any
 
 from examples.models import MODEL_NAME_TO_MODEL
-from examples.xnnpack import MODEL_NAME_TO_OPTIONS
+from examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType
 
 DEFAULT_RUNNERS = {
     "linux": "linux.2xlarge",
@@ -154,7 +154,7 @@ def export_models_for_ci() -> dict[str, dict]:
         if backend == "xnnpack":
             if name not in MODEL_NAME_TO_OPTIONS:
                 continue
-            if MODEL_NAME_TO_OPTIONS[name].quantization:
+            if MODEL_NAME_TO_OPTIONS[name].quantization != QuantType.NONE:
                 backend += "-quantization"
 
             if MODEL_NAME_TO_OPTIONS[name].delegation:
```
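For illustration, a hypothetical condensed version of the updated CI check; the `backend_label` function and the `-delegation` suffix are assumptions for this sketch, and only the quantization test is taken verbatim from the diff:

```python
from examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType

def backend_label(name: str) -> str:
    # A model is now tagged "-quantization" for any quant type other
    # than NONE, rather than for a boolean True as before.
    backend = "xnnpack"
    options = MODEL_NAME_TO_OPTIONS[name]
    if options.quantization != QuantType.NONE:
        backend += "-quantization"
    if options.delegation:
        backend += "-delegation"  # assumed suffix, for illustration
    return backend
```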

examples/models/llama/model.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -259,15 +259,22 @@ def __init__(self, **kwargs):
                 assign=True,
             )  # self.model_ = Transformer(gptconf)
         else:
-            print("Checkpoint not provided, defaulting to uninitialized weights.")
+            print("Checkpoint not provided, defaulting weights to zeros.")
             self.model_.to_empty(device="cpu")
+            for p in self.model_.parameters():
+                p.data.fill_(0)
+            for b in self.model_.buffers():
+                b.data.fill_(0)
     except RuntimeError as e:
         print(
-            f"Could not load checkpoint into mode and will default to uninitialized weights due to error: {e}."
+            f"Could not load checkpoint into model and will default weights to zeros due to error: {e}."
         )
         # Need to provide concrete (empty) values for meta-initialized tensors for quantization.
         self.model_.to_empty(device="cpu")
-
+        for p in self.model_.parameters():
+            p.data.fill_(0)
+        for b in self.model_.buffers():
+            b.data.fill_(0)
     if missing:
         missing_weights = [fqn for fqn in missing if fqn.endswith(".weight")]
         if missing_weights:
```
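As a standalone illustration of the pattern added above (a plain `nn.Linear` stands in for the Llama `Transformer`): `to_empty()` gives meta-initialized parameters real but uninitialized CPU storage, and the explicit zero-fill makes those values deterministic so later quantization observers see finite numbers.

```python
import torch
import torch.nn as nn

# Build on the meta device: parameters have shapes but no real storage.
with torch.device("meta"):
    model = nn.Linear(16, 4)

# Allocate real (but uninitialized) CPU storage for every tensor.
model.to_empty(device="cpu")

# Uninitialized memory may hold anything, including NaN; zero-fill it.
for p in model.parameters():
    p.data.fill_(0)
for b in model.buffers():
    b.data.fill_(0)

assert torch.isfinite(model.weight).all()
```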

examples/xnnpack/__init__.py

Lines changed: 30 additions & 19 deletions
```diff
@@ -7,33 +7,44 @@
 # pyre-unsafe
 
 from dataclasses import dataclass
+from enum import Enum
+
+
+class QuantType(Enum):
+    NONE = 1
+    # Used for operations that don't have weights
+    STATIC_PER_TENSOR = 2
+    # Best for CNN/RNN models with conv layers
+    STATIC_PER_CHANNEL = 3
+    # Used for linear layers and transformer-based models
+    DYNAMIC_PER_CHANNEL = 4
 
 
 @dataclass
 class XNNPACKOptions(object):
-    quantization: bool
+    quantization: QuantType
     delegation: bool
 
 
 MODEL_NAME_TO_OPTIONS = {
-    "linear": XNNPACKOptions(True, True),
-    "add": XNNPACKOptions(True, True),
-    "add_mul": XNNPACKOptions(True, True),
-    "dl3": XNNPACKOptions(True, True),
-    "ic3": XNNPACKOptions(True, True),
-    "ic4": XNNPACKOptions(True, True),
-    "mv2": XNNPACKOptions(True, True),
-    "mv3": XNNPACKOptions(True, True),
-    "resnet18": XNNPACKOptions(True, True),
-    "resnet50": XNNPACKOptions(True, True),
-    "vit": XNNPACKOptions(True, True),
-    "w2l": XNNPACKOptions(True, True),
-    "edsr": XNNPACKOptions(True, True),
-    "mobilebert": XNNPACKOptions(True, True),
-    "llama2": XNNPACKOptions(False, True),
-    "emformer_join": XNNPACKOptions(True, True),
-    "emformer_predict": XNNPACKOptions(True, True),
-    "emformer_transcribe": XNNPACKOptions(True, True),
+    "linear": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "add": XNNPACKOptions(QuantType.STATIC_PER_TENSOR, True),
+    "add_mul": XNNPACKOptions(QuantType.STATIC_PER_TENSOR, True),
+    "dl3": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "ic3": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "ic4": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "mv2": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "mv3": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "resnet18": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "resnet50": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "vit": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "w2l": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "edsr": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
+    "mobilebert": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "llama2": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "emformer_join": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "emformer_predict": XNNPACKOptions(QuantType.DYNAMIC_PER_CHANNEL, True),
+    "emformer_transcribe": XNNPACKOptions(QuantType.STATIC_PER_CHANNEL, True),
 }
```
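A quick sanity-check sketch against the new options table (imports as in the diff above):

```python
from examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType

# Transformer-style models now request dynamic per-channel weights,
# while CNNs such as MobileNetV2 stay statically quantized per channel.
assert MODEL_NAME_TO_OPTIONS["mobilebert"].quantization == QuantType.DYNAMIC_PER_CHANNEL
assert MODEL_NAME_TO_OPTIONS["mv2"].quantization == QuantType.STATIC_PER_CHANNEL
assert MODEL_NAME_TO_OPTIONS["mv2"].delegation is True
```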
examples/xnnpack/aot_compiler.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -66,7 +66,7 @@
 
 args = parser.parse_args()
 
-if not args.delegate:
+if not args.delegate and args.quantize:
     raise NotImplementedError(
         "T161880157: Quantization-only without delegation is not supported yet"
     )
@@ -79,6 +79,8 @@
         f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
     )
 
+quant_type = MODEL_NAME_TO_OPTIONS[args.model_name].quantization
+
 model, example_inputs, _, _ = EagerModelFactory.create_model(
     *MODEL_NAME_TO_MODEL[args.model_name]
 )
@@ -91,7 +93,7 @@
 if args.quantize:
     logging.info("Quantizing Model...")
     # TODO(T165162973): This pass shall eventually be folded into quantizer
-    model = quantize(model, example_inputs)
+    model = quantize(model, example_inputs, quant_type)
 ep = torch.export.export_for_training(model, example_inputs)
 
 edge = to_edge_transform_and_lower(
```
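With this change the quant type is looked up from the model options rather than chosen by the user, so an invocation like `python -m examples.xnnpack.aot_compiler --model_name=llama2 --quantize --delegate` (flag names inferred from the `args.*` references above) picks dynamic per-channel quantization for llama2 automatically. A sketch of the resulting flow; the `EagerModelFactory` import path is an assumption, since it sits outside the hunks shown:

```python
from examples.models import MODEL_NAME_TO_MODEL
from examples.models.model_factory import EagerModelFactory  # assumed path
from examples.xnnpack import MODEL_NAME_TO_OPTIONS
from examples.xnnpack.quantization.utils import quantize

name = "llama2"
# New: the quant type comes from the options table, not a user flag.
quant_type = MODEL_NAME_TO_OPTIONS[name].quantization  # DYNAMIC_PER_CHANNEL

model, example_inputs, _, _ = EagerModelFactory.create_model(
    *MODEL_NAME_TO_MODEL[name]
)
model = quantize(model, example_inputs, quant_type)
```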

examples/xnnpack/quantization/utils.py

Lines changed: 14 additions & 2 deletions
```diff
@@ -13,13 +13,25 @@
 
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 
+from .. import QuantType
 
-def quantize(model, example_inputs):
+
+def quantize(
+    model, example_inputs, quant_type: QuantType = QuantType.STATIC_PER_TENSOR
+):
     """This is the official recommended flow for quantization in pytorch 2.0 export"""
     logging.info(f"Original model: {model}")
     quantizer = XNNPACKQuantizer()
     # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
-    operator_config = get_symmetric_quantization_config(is_per_channel=False)
+    is_per_channel = (
+        quant_type == QuantType.STATIC_PER_CHANNEL
+        or quant_type == QuantType.DYNAMIC_PER_CHANNEL
+    )
+    is_dynamic = quant_type == QuantType.DYNAMIC_PER_CHANNEL
+    operator_config = get_symmetric_quantization_config(
+        is_per_channel=is_per_channel,
+        is_dynamic=is_dynamic,
+    )
     quantizer.set_global(operator_config)
     m = prepare_pt2e(model, quantizer)
     # calibration
```
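A usage sketch for the extended helper; the model and inputs are placeholders, and the upfront `export_for_training(...).module()` capture is an assumption, since `prepare_pt2e` expects a captured graph module and that step sits outside this hunk:

```python
import torch
import torch.nn as nn

from examples.xnnpack import QuantType
from examples.xnnpack.quantization.utils import quantize

# QuantType maps onto the quantizer config computed above:
#   STATIC_PER_TENSOR   -> is_per_channel=False, is_dynamic=False
#   STATIC_PER_CHANNEL  -> is_per_channel=True,  is_dynamic=False
#   DYNAMIC_PER_CHANNEL -> is_per_channel=True,  is_dynamic=True
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())  # placeholder model
example_inputs = (torch.randn(1, 8),)

# Assumed prerequisite: capture the eager model before quantizing.
model = torch.export.export_for_training(model, example_inputs).module()
quantized = quantize(model, example_inputs, QuantType.DYNAMIC_PER_CHANNEL)
```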
