3 files changed (+9 lines added, -10 removed)
First file (Buck repository aliases config):

 shim_et = shim_et

 [repository_aliases]
+bazel_skylib = shim
 config = prelude
 ovr_config = prelude
 toolchains = shim_et
Second file (CI script running buck2 query / buck2 build):

@@ -8,11 +8,10 @@ set -eux

 # TODO: expand this to //...
 # TODO: can't query cadence & vulkan backends
-buck2 query "//backends/apple/... + //backends/arm/... + \
-  //backends/example/... + //backends/mediatek/... + //backends/test/... + \
-  //backends/transforms/... + //backends/xnnpack/... + //configurations/... + \
-  //kernels/portable/cpu/... + //runtime/... + //schema/... + //test/... + \
-  //util/..."
+buck2 query "//backends/apple/... + //backends/example/... + \
+  //backends/mediatek/... + //backends/test/... + //backends/transforms/... + \
+  //backends/xnnpack/... + //configurations/... + //kernels/portable/cpu/... + \
+  //runtime/... + //schema/... + //test/... + //util/..."

 # TODO: expand the covered scope of Buck targets.
 buck2 build //runtime/core/portable_type/...
Third file (examples/models/llama/source_transformation, quantization):

@@ -119,11 +119,10 @@ def quantize(  # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

-        model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
-        ).quantize(model)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))

         if verbose:
             print("quantized model:", model)
@@ -663,7 +662,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod
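For reference, passing assign=True (supported in recent PyTorch releases) makes load_state_dict keep the tensors from the state dict instead of copying their values into the module's existing parameters, which preserves the dtypes and tensor types produced by quantization. A small illustrative sketch with a toy module, not code from this change:

# Sketch only: contrast copy vs. assign semantics of load_state_dict.
import torch
import torch.nn as nn

m = nn.Linear(4, 4)  # parameters start as float32
half_sd = {k: v.half() for k, v in m.state_dict().items()}

m.load_state_dict(half_sd)               # default: values copied into float32 params
print(m.weight.dtype)                    # torch.float32

m.load_state_dict(half_sd, assign=True)  # params replaced by the fp16 tensors
print(m.weight.dtype)                    # torch.float16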