Skip to content

Commit 03c056b

Browse files
Michael Gschwind authored and facebook-github-bot committed
Update int 4 flow for consistency
Summary: Update int 4 flow for consistency Reviewed By: kimishpatel Differential Revision: D54350318 fbshipit-source-id: 7d04c46f2fc78e277b515828f2a47f4acb5c7e86
1 parent a81bfe2 commit 03c056b

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -192,11 +192,8 @@ def quantize(
192192
elif qmode == "int4":
193193
model_int4 = Int8DynActInt4WeightQuantHandler(
194194
model, activation_precision=torch_dtype
195-
)
196-
model_int4_state_dict = model_int4.create_quantized_state_dict()
197-
model_int4 = model_int4.convert_for_runtime()
195+
).quantized_model()
198196
print("quantized model:", model_int4)
199-
model_int4.load_state_dict(model_int4_state_dict)
200197
return model_int4
201198
else:
202199
raise Exception(f"Unrecognized quantize mode: {qmode}")

examples/models/llama2/quantize.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1002,6 +1002,12 @@ def convert_for_runtime(self):
10021002
)
10031003
return self.mod
10041004

1005+
def quantized_model(self) -> nn.Module:
    """Quantize the wrapped model and return it ready for runtime use.

    Convenience one-shot entry point: builds the quantized state dict,
    swaps the module graph to its runtime (quantized) form via
    ``convert_for_runtime``, loads the quantized weights into it, and
    returns the resulting module.

    Returns:
        nn.Module: ``self.mod`` converted for runtime with quantized
        weights loaded.
    """
    # Capture the quantized weights first, then restructure the module.
    quantized_state = self.create_quantized_state_dict()
    self.convert_for_runtime()
    self.mod.load_state_dict(quantized_state)
    return self.mod
10051011

10061012
class Int8DynActInt4WeightLinear(torch.nn.Module):
10071013
__constants__ = ["in_features", "out_features"]

0 commit comments

Comments (0)