Commit cece5bc (2 parents: 40bcf6f + f3fc096)

Update
[ghstack-poisoned]

3 files changed: 9 additions, 10 deletions

.buckconfig

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@
 shim_et = shim_et
 
 [repository_aliases]
+bazel_skylib = shim
 config = prelude
 ovr_config = prelude
 toolchains = shim_et
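For reference, the [repository_aliases] section after this change reads:

[repository_aliases]
bazel_skylib = shim
config = prelude
ovr_config = prelude
toolchains = shim_et

bazel_skylib now resolves to the shim cell, alongside the existing aliases.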

.ci/scripts/unittest-buck2.sh

Lines changed: 4 additions & 5 deletions

@@ -8,11 +8,10 @@ set -eux
 
 # TODO: expand this to //...
 # TODO: can't query cadence & vulkan backends
-buck2 query "//backends/apple/... + //backends/arm/... + \
-//backends/example/... + //backends/mediatek/... + //backends/test/... + \
-//backends/transforms/... + //backends/xnnpack/... + //configurations/... + \
-//kernels/portable/cpu/... + //runtime/... + //schema/... + //test/... + \
-//util/..."
+buck2 query "//backends/apple/... + //backends/example/... + \
+//backends/mediatek/... + //backends/test/... + //backends/transforms/... + \
+//backends/xnnpack/... + //configurations/... + //kernels/portable/cpu/... + \
+//runtime/... + //schema/... + //test/... + //util/..."
 
 # TODO: expand the covered scope of Buck targets.
 buck2 build //runtime/core/portable_type/...
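The same query can be sanity-checked locally; a minimal sketch, assuming buck2 is installed and run from the repository root (//runtime/... and //schema/... stand in for any of the patterns above):

# In a buck2 query expression, `+` is set union over target patterns,
# so the script resolves all of the listed subtrees in one invocation.
buck2 query "//runtime/... + //schema/..."

Note that //backends/arm/... is dropped from the union in this commit, so the Arm backend is no longer covered by this query.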

examples/models/llama/source_transformation/quantize.py

Lines changed: 4 additions & 5 deletions

@@ -119,11 +119,10 @@ def quantize( # noqa C901
     # Check for required args
     if group_size is None:
         raise Exception("For 8da4w quantization, group size must be specified.")
-    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
 
-    model = Int8DynActInt4WeightQuantizer(
-        precision=torch_dtype, groupsize=group_size
-    ).quantize(model)
+    from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+
+    quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))
 
     if verbose:
         print("quantized model:", model)
@@ -663,7 +662,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod
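For context on assign=True: by default load_state_dict copies values into the module's existing parameters, preserving their original dtype and tensor type; with assign=True (available since PyTorch 2.1) the state-dict tensors themselves are installed on the module. A minimal sketch of the difference, using an illustrative toy module:

import torch
from torch import nn

m = nn.Linear(2, 2)
sd = {
    "weight": torch.zeros(2, 2, dtype=torch.float16),
    "bias": torch.zeros(2, dtype=torch.float16),
}

m.load_state_dict(sd)               # copy semantics: values cast into the fp32 params
print(m.weight.dtype)               # torch.float32

m.load_state_dict(sd, assign=True)  # assign semantics: params become the fp16 tensors
print(m.weight.dtype)               # torch.float16

Presumably this matters here because create_quantized_state_dict produces packed quantized tensors that should be installed as-is rather than cast into the pre-existing parameters.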
